diffusers 0.27.0__py3-none-any.whl → 0.32.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +233 -6
- diffusers/callbacks.py +209 -0
- diffusers/commands/env.py +102 -6
- diffusers/configuration_utils.py +45 -16
- diffusers/dependency_versions_table.py +4 -3
- diffusers/image_processor.py +434 -110
- diffusers/loaders/__init__.py +42 -9
- diffusers/loaders/ip_adapter.py +626 -36
- diffusers/loaders/lora_base.py +900 -0
- diffusers/loaders/lora_conversion_utils.py +991 -125
- diffusers/loaders/lora_pipeline.py +3812 -0
- diffusers/loaders/peft.py +571 -7
- diffusers/loaders/single_file.py +405 -173
- diffusers/loaders/single_file_model.py +385 -0
- diffusers/loaders/single_file_utils.py +1783 -713
- diffusers/loaders/textual_inversion.py +41 -23
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +464 -540
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +76 -7
- diffusers/models/activations.py +65 -10
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +605 -18
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +4304 -687
- diffusers/models/autoencoders/__init__.py +8 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +110 -28
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
- diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
- diffusers/models/autoencoders/vae.py +41 -29
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet.py +47 -800
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +68 -0
- diffusers/models/controlnet_sparsectrl.py +116 -0
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/controlnets/controlnet_xs.py +1946 -0
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/downsampling.py +85 -18
- diffusers/models/embeddings.py +1856 -158
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +480 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +2 -7
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +611 -146
- diffusers/models/normalization.py +361 -20
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformers/__init__.py +16 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +9 -8
- diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +445 -0
- diffusers/models/transformers/prior_transformer.py +13 -13
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +297 -187
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +593 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +461 -0
- diffusers/models/transformers/transformer_temporal.py +21 -19
- diffusers/models/unets/unet_1d.py +8 -8
- diffusers/models/unets/unet_1d_blocks.py +31 -31
- diffusers/models/unets/unet_2d.py +17 -10
- diffusers/models/unets/unet_2d_blocks.py +225 -149
- diffusers/models/unets/unet_2d_condition.py +50 -53
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +192 -1057
- diffusers/models/unets/unet_3d_condition.py +22 -27
- diffusers/models/unets/unet_i2vgen_xl.py +22 -18
- diffusers/models/unets/unet_kandinsky3.py +2 -2
- diffusers/models/unets/unet_motion_model.py +1413 -89
- diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
- diffusers/models/unets/unet_stable_cascade.py +19 -18
- diffusers/models/unets/uvit_2d.py +2 -2
- diffusers/models/upsampling.py +95 -26
- diffusers/models/vq_model.py +12 -164
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +202 -3
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +8 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
- diffusers/pipelines/auto_pipeline.py +196 -28
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/cogvideo/__init__.py +54 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
- diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
- diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/flux/__init__.py +69 -0
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +957 -0
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +37 -0
- diffusers/pipelines/free_init_utils.py +41 -38
- diffusers/pipelines/free_noise_utils.py +596 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +338 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/pag/__init__.py +80 -0
- diffusers/pipelines/pag/pag_utils.py +243 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +74 -164
- diffusers/pipelines/pipeline_flax_utils.py +5 -10
- diffusers/pipelines/pipeline_loading_utils.py +515 -53
- diffusers/pipelines/pipeline_utils.py +411 -222
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
- diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/__init__.py +12 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +27 -26
- diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
- diffusers/schedulers/scheduling_ddpm.py +27 -30
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +150 -50
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
- diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
- diffusers/schedulers/scheduling_edm_euler.py +62 -39
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
- diffusers/schedulers/scheduling_euler_discrete.py +255 -74
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
- diffusers/schedulers/scheduling_heun_discrete.py +174 -46
- diffusers/schedulers/scheduling_ipndm.py +9 -9
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +23 -29
- diffusers/schedulers/scheduling_lms_discrete.py +105 -28
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +21 -21
- diffusers/schedulers/scheduling_sasolver.py +157 -60
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +41 -36
- diffusers/schedulers/scheduling_unclip.py +19 -16
- diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
- diffusers/schedulers/scheduling_utils.py +12 -5
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +214 -30
- diffusers/utils/__init__.py +17 -1
- diffusers/utils/constants.py +3 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +592 -7
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
- diffusers/utils/dynamic_modules_utils.py +34 -29
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +131 -17
- diffusers/utils/import_utils.py +210 -8
- diffusers/utils/loading_utils.py +118 -5
- diffusers/utils/logging.py +4 -2
- diffusers/utils/peft_utils.py +37 -7
- diffusers/utils/state_dict_utils.py +13 -2
- diffusers/utils/testing_utils.py +193 -11
- diffusers/utils/torch_utils.py +4 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
- diffusers-0.32.2.dist-info/RECORD +550 -0
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1349
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- diffusers-0.27.0.dist-info/RECORD +0 -399
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
@@ -22,15 +22,20 @@ from pathlib import Path
|
|
22
22
|
from typing import Any, Dict, List, Optional, Union
|
23
23
|
|
24
24
|
import torch
|
25
|
-
from huggingface_hub import
|
26
|
-
|
27
|
-
)
|
25
|
+
from huggingface_hub import ModelCard, model_info
|
26
|
+
from huggingface_hub.utils import validate_hf_hub_args
|
28
27
|
from packaging import version
|
29
28
|
|
29
|
+
from .. import __version__
|
30
30
|
from ..utils import (
|
31
|
+
FLAX_WEIGHTS_NAME,
|
32
|
+
ONNX_EXTERNAL_WEIGHTS_NAME,
|
33
|
+
ONNX_WEIGHTS_NAME,
|
31
34
|
SAFETENSORS_WEIGHTS_NAME,
|
32
35
|
WEIGHTS_NAME,
|
36
|
+
deprecate,
|
33
37
|
get_class_from_dynamic_module,
|
38
|
+
is_accelerate_available,
|
34
39
|
is_peft_available,
|
35
40
|
is_transformers_available,
|
36
41
|
logging,
|
@@ -44,9 +49,12 @@ if is_transformers_available():
|
|
44
49
|
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
45
50
|
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
46
51
|
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
47
|
-
from huggingface_hub.utils import validate_hf_hub_args
|
48
52
|
|
49
|
-
|
53
|
+
if is_accelerate_available():
|
54
|
+
import accelerate
|
55
|
+
from accelerate import dispatch_model
|
56
|
+
from accelerate.hooks import remove_hook_from_module
|
57
|
+
from accelerate.utils import compute_module_sizes, get_max_memory
|
50
58
|
|
51
59
|
|
52
60
|
INDEX_FILE = "diffusion_pytorch_model.bin"
|
@@ -82,49 +90,50 @@ for library in LOADABLE_CLASSES:
|
|
82
90
|
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
83
91
|
|
84
92
|
|
85
|
-
def is_safetensors_compatible(filenames,
|
93
|
+
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
|
86
94
|
"""
|
87
95
|
Checking for safetensors compatibility:
|
88
|
-
-
|
89
|
-
|
90
|
-
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
|
96
|
+
- The model is safetensors compatible only if there is a safetensors file for each model component present in
|
97
|
+
filenames.
|
91
98
|
|
92
99
|
Converting default pytorch serialized filenames to safetensors serialized filenames:
|
93
100
|
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
|
94
101
|
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
|
95
102
|
extension is replaced with ".safetensors"
|
96
103
|
"""
|
97
|
-
pt_filenames = []
|
98
|
-
|
99
|
-
sf_filenames = set()
|
100
|
-
|
101
104
|
passed_components = passed_components or []
|
105
|
+
if folder_names is not None:
|
106
|
+
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
|
102
107
|
|
108
|
+
# extract all components of the pipeline and their associated files
|
109
|
+
components = {}
|
103
110
|
for filename in filenames:
|
104
|
-
|
111
|
+
if not len(filename.split("/")) == 2:
|
112
|
+
continue
|
105
113
|
|
106
|
-
|
114
|
+
component, component_filename = filename.split("/")
|
115
|
+
if component in passed_components:
|
107
116
|
continue
|
108
117
|
|
109
|
-
|
110
|
-
|
111
|
-
elif extension == ".safetensors":
|
112
|
-
sf_filenames.add(os.path.normpath(filename))
|
118
|
+
components.setdefault(component, [])
|
119
|
+
components[component].append(component_filename)
|
113
120
|
|
114
|
-
for
|
115
|
-
|
116
|
-
|
117
|
-
filename, extension = os.path.splitext(filename)
|
121
|
+
# If there are no component folders check the main directory for safetensors files
|
122
|
+
if not components:
|
123
|
+
return any(".safetensors" in filename for filename in filenames)
|
118
124
|
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
125
|
+
# iterate over all files of a component
|
126
|
+
# check if safetensor files exist for that component
|
127
|
+
# if variant is provided check if the variant of the safetensors exists
|
128
|
+
for component, component_filenames in components.items():
|
129
|
+
matches = []
|
130
|
+
for component_filename in component_filenames:
|
131
|
+
filename, extension = os.path.splitext(component_filename)
|
132
|
+
|
133
|
+
match_exists = extension == ".safetensors"
|
134
|
+
matches.append(match_exists)
|
123
135
|
|
124
|
-
|
125
|
-
expected_sf_filename = f"{expected_sf_filename}.safetensors"
|
126
|
-
if expected_sf_filename not in sf_filenames:
|
127
|
-
logger.warning(f"{expected_sf_filename} not found")
|
136
|
+
if not any(matches):
|
128
137
|
return False
|
129
138
|
|
130
139
|
return True
|
@@ -189,10 +198,31 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
|
|
189
198
|
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
|
190
199
|
return variant_filename
|
191
200
|
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
201
|
+
def find_component(filename):
    """Return the component folder of a ``<component>/<file>`` path.

    Paths that do not consist of exactly two ``/``-separated parts (root-level
    files or deeper nesting) yield ``None``.
    """
    parts = filename.split("/")
    # Only a path with exactly one separator names a component.
    return parts[0] if len(parts) == 2 else None
|
206
|
+
|
207
|
+
def has_sharded_variant(component, variant, variant_filenames):
    """Tell whether a sharded-checkpoint index file exists for ``variant``.

    When ``component`` is truthy the index is looked for under ``<component>/``;
    otherwise it is looked for without any folder prefix.
    """
    folder_prefix = f"{component}/" if component else ""
    prefix_alternatives = "|".join(weight_prefixes)
    suffix_alternatives = "|".join(weight_suffixs)
    index_pattern = re.compile(
        rf"{folder_prefix}({prefix_alternatives})\.({suffix_alternatives})\.index\.{variant}\.json$"
    )
    return any(index_pattern.match(candidate) is not None for candidate in variant_filenames)
|
215
|
+
|
216
|
+
for filename in non_variant_filenames:
|
217
|
+
if convert_to_variant(filename) in variant_filenames:
|
218
|
+
continue
|
219
|
+
|
220
|
+
component = find_component(filename)
|
221
|
+
# If a sharded variant exists skip adding to allowed patterns
|
222
|
+
if has_sharded_variant(component, variant, variant_filenames):
|
223
|
+
continue
|
224
|
+
|
225
|
+
usable_filenames.add(filename)
|
196
226
|
|
197
227
|
return usable_filenames, variant_filenames
|
198
228
|
|
@@ -292,6 +322,39 @@ def get_class_obj_and_candidates(
|
|
292
322
|
return class_obj, class_candidates
|
293
323
|
|
294
324
|
|
325
|
+
def _get_custom_pipeline_class(
    custom_pipeline,
    repo_id=None,
    hub_revision=None,
    class_name=None,
    cache_dir=None,
    revision=None,
):
    """Locate the module that defines a custom pipeline and load its class.

    Three sources are distinguished:

    * ``custom_pipeline`` ends with ``.py``: it is split into its parent folder
      (made absolute) and its file name.
    * otherwise, if ``repo_id`` is given: the module ``<custom_pipeline>.py`` is
      looked up in ``repo_id``.
    * otherwise: ``custom_pipeline`` is used as-is together with
      ``CUSTOM_PIPELINE_FILE_NAME``.

    Returns whatever ``get_class_from_dynamic_module`` returns.
    """
    module_source = custom_pipeline
    if custom_pipeline.endswith(".py"):
        # decompose into folder & file
        script_path = Path(custom_pipeline)
        module_file = script_path.name
        module_source = script_path.parent.absolute()
    elif repo_id is not None:
        module_file = f"{custom_pipeline}.py"
        module_source = repo_id
    else:
        module_file = CUSTOM_PIPELINE_FILE_NAME

    # When the pipeline code comes from the Hub, the Hub revision takes
    # precedence over the caller-supplied `revision`.
    loads_from_hub = repo_id is not None and hub_revision is not None
    effective_revision = hub_revision if loads_from_hub else revision

    return get_class_from_dynamic_module(
        module_source,
        module_file=module_file,
        class_name=class_name,
        cache_dir=cache_dir,
        revision=effective_revision,
    )
|
356
|
+
|
357
|
+
|
295
358
|
def _get_pipeline_class(
|
296
359
|
class_obj,
|
297
360
|
config=None,
|
@@ -304,25 +367,10 @@ def _get_pipeline_class(
|
|
304
367
|
revision=None,
|
305
368
|
):
|
306
369
|
if custom_pipeline is not None:
|
307
|
-
|
308
|
-
path = Path(custom_pipeline)
|
309
|
-
# decompose into folder & file
|
310
|
-
file_name = path.name
|
311
|
-
custom_pipeline = path.parent.absolute()
|
312
|
-
elif repo_id is not None:
|
313
|
-
file_name = f"{custom_pipeline}.py"
|
314
|
-
custom_pipeline = repo_id
|
315
|
-
else:
|
316
|
-
file_name = CUSTOM_PIPELINE_FILE_NAME
|
317
|
-
|
318
|
-
if repo_id is not None and hub_revision is not None:
|
319
|
-
# if we load the pipeline code from the Hub
|
320
|
-
# make sure to overwrite the `revision`
|
321
|
-
revision = hub_revision
|
322
|
-
|
323
|
-
return get_class_from_dynamic_module(
|
370
|
+
return _get_custom_pipeline_class(
|
324
371
|
custom_pipeline,
|
325
|
-
|
372
|
+
repo_id=repo_id,
|
373
|
+
hub_revision=hub_revision,
|
326
374
|
class_name=class_name,
|
327
375
|
cache_dir=cache_dir,
|
328
376
|
revision=revision,
|
@@ -358,6 +406,206 @@ def _get_pipeline_class(
|
|
358
406
|
return pipeline_cls
|
359
407
|
|
360
408
|
|
409
|
+
def _load_empty_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    name: str,
    torch_dtype: Union[str, torch.dtype],
    cached_folder: Union[str, os.PathLike],
    **kwargs,
):
    """Instantiate component ``name`` from its config under ``accelerate.init_empty_weights()``.

    Only diffusers ``ModelMixin`` subclasses and transformers ``PreTrainedModel``
    subclasses (transformers >= 4.20.0) are handled; for any other class
    ``None`` is returned. Download-related options (``force_download``,
    ``proxies``, ``local_files_only``, ``token``, ``revision`` and, for
    diffusers models, ``subfolder``) are read from ``kwargs``.

    Raises:
        ValueError: if a transformers model class has no ``config_class``.
    """
    # retrieve class objects.
    class_obj, _ = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    # Determine library. The transformers checks are only evaluated when the
    # package is importable at all.
    is_transformers_model = False
    if is_transformers_available():
        installed_transformers = version.parse(version.parse(transformers.__version__).base_version)
        is_transformers_model = issubclass(class_obj, PreTrainedModel) and installed_transformers >= version.parse(
            "4.20.0"
        )

    diffusers_module = importlib.import_module(__name__.split(".")[0])
    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    user_agent = {
        "diffusers": __version__,
        "file_type": "model",
        "framework": "pytorch",
    }

    empty_model = None
    if is_diffusers_model:
        # Load config and then the model on meta.
        config, unused_kwargs, commit_hash = class_obj.load_config(
            os.path.join(cached_folder, name),
            cache_dir=cached_folder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=kwargs.pop("force_download", False),
            proxies=kwargs.pop("proxies", None),
            local_files_only=kwargs.pop("local_files_only", False),
            token=kwargs.pop("token", None),
            revision=kwargs.pop("revision", None),
            subfolder=kwargs.pop("subfolder", None),
            user_agent=user_agent,
        )
        with accelerate.init_empty_weights():
            empty_model = class_obj.from_config(config, **unused_kwargs)
    elif is_transformers_model:
        config_class = getattr(class_obj, "config_class", None)
        if config_class is None:
            raise ValueError("`config_class` cannot be None. Please double-check the model.")

        config = config_class.from_pretrained(
            cached_folder,
            subfolder=name,
            force_download=kwargs.pop("force_download", False),
            proxies=kwargs.pop("proxies", None),
            local_files_only=kwargs.pop("local_files_only", False),
            token=kwargs.pop("token", None),
            revision=kwargs.pop("revision", None),
            user_agent=user_agent,
        )
        with accelerate.init_empty_weights():
            empty_model = class_obj(config)

    if empty_model is None:
        return None
    return empty_model.to(dtype=torch_dtype)
|
491
|
+
|
492
|
+
|
493
|
+
def _assign_components_to_devices(
|
494
|
+
module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced"
|
495
|
+
):
|
496
|
+
device_ids = list(device_memory.keys())
|
497
|
+
device_cycle = device_ids + device_ids[::-1]
|
498
|
+
device_memory = device_memory.copy()
|
499
|
+
|
500
|
+
device_id_component_mapping = {}
|
501
|
+
current_device_index = 0
|
502
|
+
for component in module_sizes:
|
503
|
+
device_id = device_cycle[current_device_index % len(device_cycle)]
|
504
|
+
component_memory = module_sizes[component]
|
505
|
+
curr_device_memory = device_memory[device_id]
|
506
|
+
|
507
|
+
# If the GPU doesn't fit the current component offload to the CPU.
|
508
|
+
if component_memory > curr_device_memory:
|
509
|
+
device_id_component_mapping["cpu"] = [component]
|
510
|
+
else:
|
511
|
+
if device_id not in device_id_component_mapping:
|
512
|
+
device_id_component_mapping[device_id] = [component]
|
513
|
+
else:
|
514
|
+
device_id_component_mapping[device_id].append(component)
|
515
|
+
|
516
|
+
# Update the device memory.
|
517
|
+
device_memory[device_id] -= component_memory
|
518
|
+
current_device_index += 1
|
519
|
+
|
520
|
+
return device_id_component_mapping
|
521
|
+
|
522
|
+
|
523
|
+
def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
    """Derive a component -> device mapping for a pipeline.

    Every component listed in ``init_dict`` is either taken from
    ``passed_class_obj`` or created via ``_load_empty_model``; the sizes of the
    resulting ``torch.nn.Module`` components are then matched against the
    per-device memory reported by ``get_max_memory`` (the ``"cpu"`` entry is
    excluded).

    Returns:
        A dict such as ``{"unet": 0, "text_encoder": 1, "vae": 1, ...}``, or
        ``None`` when no non-CPU device is available.

    Raises:
        ValueError: if any component class name starts with ``"Flax"``.
    """
    # To avoid circular import problem.
    from diffusers import pipelines

    torch_dtype = kwargs.get("torch_dtype", torch.float32)

    # Load each module in the pipeline on a meta device so that we can derive the device map.
    init_empty_modules = {}
    for name, (library_name, class_name) in init_dict.items():
        if class_name.startswith("Flax"):
            raise ValueError("Flax pipelines are not supported with `device_map`.")

        # Define all importable classes
        is_pipeline_module = hasattr(pipelines, library_name)
        importable_classes = ALL_IMPORTABLE_CLASSES

        # Use passed sub model or load class_name from library_name
        if name in passed_class_obj:
            # check that passed_class_obj has correct parent class
            maybe_raise_or_warn(
                library_name,
                library,
                class_name,
                importable_classes,
                passed_class_obj,
                name,
                is_pipeline_module,
            )
            with accelerate.init_empty_weights():
                sub_model = passed_class_obj[name]
        else:
            sub_model = _load_empty_model(
                library_name=library_name,
                class_name=class_name,
                importable_classes=importable_classes,
                pipelines=pipelines,
                is_pipeline_module=is_pipeline_module,
                pipeline_class=pipeline_class,
                name=name,
                torch_dtype=torch_dtype,
                cached_folder=kwargs.get("cached_folder", None),
                force_download=kwargs.get("force_download", None),
                proxies=kwargs.get("proxies", None),
                local_files_only=kwargs.get("local_files_only", None),
                token=kwargs.get("token", None),
                revision=kwargs.get("revision", None),
            )

        if sub_model is not None:
            init_empty_modules[name] = sub_model

    # Component sizes, largest first.
    unsorted_sizes = {
        module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
        for module_name, module in init_empty_modules.items()
        if isinstance(module, torch.nn.Module)
    }
    module_sizes = dict(sorted(unsorted_sizes.items(), key=lambda item: item[1], reverse=True))

    # Maximum memory available per device, largest first, GPUs only.
    max_memory = get_max_memory(max_memory)
    ranked_memory = sorted(max_memory.items(), key=lambda item: item[1], reverse=True)
    max_memory = {device: amount for device, amount in ranked_memory if device != "cpu"}

    if len(max_memory) == 0:
        return None

    device_id_component_mapping = _assign_components_to_devices(
        module_sizes, max_memory, device_mapping_strategy=device_map
    )

    # Invert {device: [components]} into {component: device}.
    return {
        component: device_id
        for device_id, components in device_id_component_mapping.items()
        for component in components
    }
|
607
|
+
|
608
|
+
|
361
609
|
def load_sub_model(
|
362
610
|
library_name: str,
|
363
611
|
class_name: str,
|
@@ -378,9 +626,12 @@ def load_sub_model(
|
|
378
626
|
variant: str,
|
379
627
|
low_cpu_mem_usage: bool,
|
380
628
|
cached_folder: Union[str, os.PathLike],
|
629
|
+
use_safetensors: bool,
|
381
630
|
):
|
382
631
|
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
632
|
+
|
383
633
|
# retrieve class candidates
|
634
|
+
|
384
635
|
class_obj, class_candidates = get_class_obj_and_candidates(
|
385
636
|
library_name,
|
386
637
|
class_name,
|
@@ -445,6 +696,7 @@ def load_sub_model(
|
|
445
696
|
loading_kwargs["offload_folder"] = offload_folder
|
446
697
|
loading_kwargs["offload_state_dict"] = offload_state_dict
|
447
698
|
loading_kwargs["variant"] = model_variants.pop(name, None)
|
699
|
+
loading_kwargs["use_safetensors"] = use_safetensors
|
448
700
|
|
449
701
|
if from_flax:
|
450
702
|
loading_kwargs["from_flax"] = True
|
@@ -475,6 +727,22 @@ def load_sub_model(
|
|
475
727
|
# else load from the root directory
|
476
728
|
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
477
729
|
|
730
|
+
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
|
731
|
+
# remove hooks
|
732
|
+
remove_hook_from_module(loaded_sub_model, recurse=True)
|
733
|
+
needs_offloading_to_cpu = device_map[""] == "cpu"
|
734
|
+
|
735
|
+
if needs_offloading_to_cpu:
|
736
|
+
dispatch_model(
|
737
|
+
loaded_sub_model,
|
738
|
+
state_dict=loaded_sub_model.state_dict(),
|
739
|
+
device_map=device_map,
|
740
|
+
force_hooks=True,
|
741
|
+
main_device=0,
|
742
|
+
)
|
743
|
+
else:
|
744
|
+
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
|
745
|
+
|
478
746
|
return loaded_sub_model
|
479
747
|
|
480
748
|
|
@@ -506,3 +774,197 @@ def _fetch_class_library_tuple(module):
|
|
506
774
|
class_name = not_compiled_module.__class__.__name__
|
507
775
|
|
508
776
|
return (library, class_name)
|
777
|
+
|
778
|
+
|
779
|
+
def _identify_model_variants(folder: str, variant: str, config: dict) -> dict:
|
780
|
+
model_variants = {}
|
781
|
+
if variant is not None:
|
782
|
+
for sub_folder in os.listdir(folder):
|
783
|
+
folder_path = os.path.join(folder, sub_folder)
|
784
|
+
is_folder = os.path.isdir(folder_path) and sub_folder in config
|
785
|
+
variant_exists = is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
|
786
|
+
if variant_exists:
|
787
|
+
model_variants[sub_folder] = variant
|
788
|
+
return model_variants
|
789
|
+
|
790
|
+
|
791
|
+
def _resolve_custom_pipeline_and_cls(folder, config, custom_pipeline):
|
792
|
+
custom_class_name = None
|
793
|
+
if os.path.isfile(os.path.join(folder, f"{custom_pipeline}.py")):
|
794
|
+
custom_pipeline = os.path.join(folder, f"{custom_pipeline}.py")
|
795
|
+
elif isinstance(config["_class_name"], (list, tuple)) and os.path.isfile(
|
796
|
+
os.path.join(folder, f"{config['_class_name'][0]}.py")
|
797
|
+
):
|
798
|
+
custom_pipeline = os.path.join(folder, f"{config['_class_name'][0]}.py")
|
799
|
+
custom_class_name = config["_class_name"][1]
|
800
|
+
|
801
|
+
return custom_pipeline, custom_class_name
|
802
|
+
|
803
|
+
|
804
|
+
def _maybe_raise_warning_for_inpainting(pipeline_class, pretrained_model_name_or_path: str, config: dict):
    """Emit a deprecation notice for legacy Stable Diffusion inpainting checkpoints.

    The notice is issued when ``pipeline_class`` is named
    ``"StableDiffusionInpaintPipeline"`` and the base version stored under
    ``config["_diffusers_version"]`` is ``<= 0.5.1``. The function only warns:
    it returns ``None`` and does not replace the caller's pipeline class.

    Args:
        pipeline_class: pipeline class about to be loaded.
        pretrained_model_name_or_path: checkpoint identifier used in the message.
        config: pipeline config; ``"_diffusers_version"`` is read only when the
            class name matches.
    """
    if pipeline_class.__name__ != "StableDiffusionInpaintPipeline":
        return

    checkpoint_version = version.parse(version.parse(config["_diffusers_version"]).base_version)
    if checkpoint_version > version.parse("0.5.1"):
        return

    from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy

    deprecation_message = (
        "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
        f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
        " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
        " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
        f" checkpoint {pretrained_model_name_or_path} to the format of"
        " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
        # This fragment needs the f-prefix, otherwise the placeholder is printed verbatim.
        f" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
    )
    deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
|
822
|
+
|
823
|
+
|
824
|
+
def _update_init_kwargs_with_connected_pipeline(
    init_kwargs: dict, passed_pipe_kwargs: dict, passed_class_objs: dict, folder: str, **pipeline_loading_kwargs
) -> dict:
    """Load the pipelines referenced by the model card and merge their components into ``init_kwargs``.

    The model card at ``<folder>/README.md`` is read; for every key in
    ``CONNECTED_PIPES_KEYS`` that names a repository, that pipeline is loaded
    with ``DiffusionPipeline.from_pretrained`` and each of its components is
    stored in ``init_kwargs`` as ``<prefix>_<component_name>``
    (e.g. ``"prior_text_encoder"``).

    Returns:
        The same ``init_kwargs`` dict, updated in place.
    """
    from .pipeline_utils import DiffusionPipeline

    modelcard = ModelCard.load(os.path.join(folder, "README.md"))
    connected_repo_ids = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}

    # Scheduler-related arguments are not forwarded, to match the existing logic:
    # https://github.com/huggingface/diffusers/blob/867e0c919e1aa7ef8b03c8eb1460f4f875a683ae/src/diffusers/pipelines/pipeline_utils.py#L906C13-L925C14
    forwarded_loading_kwargs = {
        key: value for key, value in pipeline_loading_kwargs.items() if "scheduler" not in key
    }

    def get_connected_passed_kwargs(prefix):
        """Collect the user-passed objects/kwargs addressed to ``prefix``, with the prefix removed."""

        def _for_prefix(source):
            return {k.replace(f"{prefix}_", ""): w for k, w in source.items() if k.split("_")[0] == prefix}

        # Plain pipeline kwargs win over class objects on a key clash.
        return {**_for_prefix(passed_class_objs), **_for_prefix(passed_pipe_kwargs)}

    # Load every connected pipeline first; `init_kwargs` is only touched afterwards.
    loaded_pipes = {}
    for prefix, repo_id in connected_repo_ids.items():
        if repo_id is None:
            continue
        loaded_pipes[prefix] = DiffusionPipeline.from_pretrained(
            repo_id, **forwarded_loading_kwargs, **get_connected_passed_kwargs(prefix)
        )

    for prefix, connected_pipe in loaded_pipes.items():
        # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
        for name, component in connected_pipe.components.items():
            init_kwargs["_".join([prefix, name])] = component

    return init_kwargs
|
866
|
+
|
867
|
+
|
868
|
+
def _get_custom_components_and_folders(
    pretrained_model_name: str,
    config_dict: Dict[str, Any],
    filenames: Optional[List[str]] = None,
    variant_filenames: Optional[List[str]] = None,
    variant: Optional[str] = None,
):
    """Determine the component folders of a pipeline and which of them use custom code.

    Args:
        pretrained_model_name: repository/model name, used in error messages only.
        config_dict: pipeline config; every key other than ``"_class_name"``
            whose value is a list is treated as a component folder.
        filenames: file paths available in the repository. ``None`` is treated
            as "no files".
        variant_filenames: file paths belonging to the requested variant.
            ``None`` is treated as "no files".
        variant: requested weight variant, if any.

    Returns:
        Tuple ``(custom_components, folder_names)`` where ``custom_components``
        maps a component name to the module that has a matching
        ``<component>/<module>.py`` in ``filenames``.

    Raises:
        ValueError: if a component's module is neither shipped as a file, nor in
            ``LOADABLE_CLASSES``, nor an attribute of ``diffusers.pipelines``;
            or if ``variant`` is set but ``variant_filenames`` is empty.
    """
    config_dict = config_dict.copy()
    # The signature advertises `None` defaults; normalise them so the membership
    # test and `len()` below do not fail with TypeError.
    filenames = [] if filenames is None else filenames
    variant_filenames = [] if variant_filenames is None else variant_filenames

    # retrieve all folder_names that contain relevant files
    folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

    diffusers_module = importlib.import_module(__name__.split(".")[0])
    pipelines = getattr(diffusers_module, "pipelines")

    # optionally create a custom component <> custom file mapping
    custom_components = {}
    for component in folder_names:
        module_candidate = config_dict[component][0]

        if module_candidate is None or not isinstance(module_candidate, str):
            continue

        # We compute candidate file path on the Hub. Do not use `os.path.join`.
        candidate_file = f"{component}/{module_candidate}.py"

        if candidate_file in filenames:
            custom_components[component] = module_candidate
        elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
            raise ValueError(
                f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
            )

    if len(variant_filenames) == 0 and variant is not None:
        error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
        raise ValueError(error_message)

    return custom_components, folder_names
|
906
|
+
|
907
|
+
|
908
|
+
def _get_ignore_patterns(
|
909
|
+
passed_components,
|
910
|
+
model_folder_names: List[str],
|
911
|
+
model_filenames: List[str],
|
912
|
+
variant_filenames: List[str],
|
913
|
+
use_safetensors: bool,
|
914
|
+
from_flax: bool,
|
915
|
+
allow_pickle: bool,
|
916
|
+
use_onnx: bool,
|
917
|
+
is_onnx: bool,
|
918
|
+
variant: Optional[str] = None,
|
919
|
+
) -> List[str]:
|
920
|
+
if (
|
921
|
+
use_safetensors
|
922
|
+
and not allow_pickle
|
923
|
+
and not is_safetensors_compatible(
|
924
|
+
model_filenames, passed_components=passed_components, folder_names=model_folder_names
|
925
|
+
)
|
926
|
+
):
|
927
|
+
raise EnvironmentError(
|
928
|
+
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
|
929
|
+
)
|
930
|
+
|
931
|
+
if from_flax:
|
932
|
+
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
|
933
|
+
|
934
|
+
elif use_safetensors and is_safetensors_compatible(
|
935
|
+
model_filenames, passed_components=passed_components, folder_names=model_folder_names
|
936
|
+
):
|
937
|
+
ignore_patterns = ["*.bin", "*.msgpack"]
|
938
|
+
|
939
|
+
use_onnx = use_onnx if use_onnx is not None else is_onnx
|
940
|
+
if not use_onnx:
|
941
|
+
ignore_patterns += ["*.onnx", "*.pb"]
|
942
|
+
|
943
|
+
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
|
944
|
+
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
|
945
|
+
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
|
946
|
+
logger.warning(
|
947
|
+
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
948
|
+
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
949
|
+
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
|
950
|
+
f"expected, please check your folder structure."
|
951
|
+
)
|
952
|
+
|
953
|
+
else:
|
954
|
+
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
955
|
+
|
956
|
+
use_onnx = use_onnx if use_onnx is not None else is_onnx
|
957
|
+
if not use_onnx:
|
958
|
+
ignore_patterns += ["*.onnx", "*.pb"]
|
959
|
+
|
960
|
+
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
961
|
+
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
962
|
+
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
963
|
+
logger.warning(
|
964
|
+
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
965
|
+
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
966
|
+
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
|
967
|
+
f"your folder structure."
|
968
|
+
)
|
969
|
+
|
970
|
+
return ignore_patterns
|