diffusers-0.32.2-py3-none-any.whl → diffusers-0.33.0-py3-none-any.whl
This diff covers publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright
+# Copyright 2025 The HuggingFace Inc. team.
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,16 +20,21 @@ import itertools
 import json
 import os
 import re
+import shutil
+import tempfile
 from collections import OrderedDict
-from
+from contextlib import ExitStack, contextmanager
+from functools import wraps
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Type, Union

 import safetensors
 import torch
-
+import torch.utils.checkpoint
+from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
+from typing_extensions import Self

 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
@@ -48,6 +53,7 @@ from ..utils import (
     is_accelerate_available,
     is_bitsandbytes_available,
     is_bitsandbytes_version,
+    is_peft_available,
     is_torch_version,
     logging,
 )
@@ -61,16 +67,49 @@ from .model_loading_utils import (
     _fetch_index_file,
     _fetch_index_file_legacy,
     _load_state_dict_into_model,
-    _merge_sharded_checkpoints,
     load_model_dict_into_meta,
     load_state_dict,
 )


+class ContextManagers:
+    """
+    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
+    in the `fastcore` library.
+    """
+
+    def __init__(self, context_managers: List[ContextManager]):
+        self.context_managers = context_managers
+        self.stack = ExitStack()
+
+    def __enter__(self):
+        for context_manager in self.context_managers:
+            self.stack.enter_context(context_manager)
+
+    def __exit__(self, *args, **kwargs):
+        self.stack.__exit__(*args, **kwargs)
+
+
 logger = logging.get_logger(__name__)

 _REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")

+TORCH_INIT_FUNCTIONS = {
+    "uniform_": nn.init.uniform_,
+    "normal_": nn.init.normal_,
+    "trunc_normal_": nn.init.trunc_normal_,
+    "constant_": nn.init.constant_,
+    "xavier_uniform_": nn.init.xavier_uniform_,
+    "xavier_normal_": nn.init.xavier_normal_,
+    "kaiming_uniform_": nn.init.kaiming_uniform_,
+    "kaiming_normal_": nn.init.kaiming_normal_,
+    "uniform": nn.init.uniform,
+    "normal": nn.init.normal,
+    "xavier_uniform": nn.init.xavier_uniform,
+    "xavier_normal": nn.init.xavier_normal,
+    "kaiming_uniform": nn.init.kaiming_uniform,
+    "kaiming_normal": nn.init.kaiming_normal,
+}

 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
@@ -80,10 +119,22 @@ else:

 if is_accelerate_available():
     import accelerate
+    from accelerate import dispatch_model
+    from accelerate.utils import load_offloaded_weights, save_offload_index


 def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
+    from ..hooks.group_offloading import _get_group_onload_device
+
+    try:
+        # Try to get the onload device from the group offloading hook
+        return _get_group_onload_device(parameter)
+    except ValueError:
+        pass
+
     try:
+        # If the onload device is not available due to no group offloading hooks, try to get the device
+        # from the first parameter or buffer
         parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
         return next(parameters_and_buffers).device
     except StopIteration:
@@ -102,9 +153,24 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     """
     Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
     """
+    # 1. Check if we have attached any dtype modifying hooks (eg. layerwise casting)
+    if isinstance(parameter, nn.Module):
+        for name, submodule in parameter.named_modules():
+            if not hasattr(submodule, "_diffusers_hook"):
+                continue
+            registry = submodule._diffusers_hook
+            hook = registry.get_hook("layerwise_casting")
+            if hook is not None:
+                return hook.compute_dtype
+
+    # 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer
     last_dtype = None
-
+
+    for name, param in parameter.named_parameters():
         last_dtype = param.dtype
+        if parameter._keep_in_fp32_modules and any(m in name for m in parameter._keep_in_fp32_modules):
+            continue
+
         if param.is_floating_point():
             return param.dtype

@@ -134,6 +200,54 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
             return last_tuple[1].dtype


+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
+    parameters.
+
+    """
+    if model_to_load.device.type == "meta":
+        return False
+
+    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+        return False
+
+    # Some models explicitly do not support param buffer assignment
+    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+        logger.debug(
+            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+        )
+        return False
+
+    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+    first_key = next(iter(model_to_load.state_dict().keys()))
+    if start_prefix + first_key in state_dict:
+        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+    return False
+
+
+@contextmanager
+def no_init_weights():
+    """
+    Context manager to globally disable weight initialization to speed up loading large models. To do that, all the
+    torch.nn.init function are all replaced with skip.
+    """
+
+    def _skip_init(*args, **kwargs):
+        pass
+
+    for name, init_func in TORCH_INIT_FUNCTIONS.items():
+        setattr(torch.nn.init, name, _skip_init)
+    try:
+        yield
+    finally:
+        # Restore the original initialization functions
+        for name, init_func in TORCH_INIT_FUNCTIONS.items():
+            setattr(torch.nn.init, name, init_func)
+
+
 class ModelMixin(torch.nn.Module, PushToHubMixin):
     r"""
     Base class for all models.
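The `no_init_weights` helper added above temporarily swaps the listed `torch.nn.init` functions for no-ops so that `cls.from_config` does not spend time initializing weights that the checkpoint will overwrite anyway; the new `ContextManagers` class then stacks it with `accelerate.init_empty_weights()` later in the loading path. A minimal standalone sketch of the same pattern (the `TinyBlock` module and the two-entry init table are illustrative, not part of diffusers):

```python
from contextlib import ExitStack, contextmanager

import torch
from torch import nn

# Illustrative subset of the TORCH_INIT_FUNCTIONS table from the diff.
_INIT_FUNCTIONS = {"normal_": nn.init.normal_, "xavier_uniform_": nn.init.xavier_uniform_}


@contextmanager
def no_init_weights():
    # Swap the listed init functions for a no-op, then restore them on exit.
    def _skip_init(*args, **kwargs):
        pass

    for name in _INIT_FUNCTIONS:
        setattr(nn.init, name, _skip_init)
    try:
        yield
    finally:
        for name, fn in _INIT_FUNCTIONS.items():
            setattr(nn.init, name, fn)


class TinyBlock(nn.Module):  # stand-in for a diffusers model built by from_config
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        nn.init.xavier_uniform_(self.proj.weight)  # skipped inside no_init_weights()


with ExitStack() as stack:  # the diff's ContextManagers wraps exactly this ExitStack idea
    stack.enter_context(no_init_weights())
    block = TinyBlock()  # constructed without running the patched initializers
```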
@@ -150,10 +264,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _keys_to_ignore_on_load_unexpected = None
     _no_split_modules = None
     _keep_in_fp32_modules = None
+    _skip_layerwise_casting_patterns = None
+    _supports_group_offloading = True

     def __init__(self):
         super().__init__()

+        self._gradient_checkpointing_func = None
+
     def __getattr__(self, name: str) -> Any:
         """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
         config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
@@ -179,14 +297,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

-    def enable_gradient_checkpointing(self) -> None:
+    def enable_gradient_checkpointing(self, gradient_checkpointing_func: Optional[Callable] = None) -> None:
         """
         Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
         *checkpoint activations* in other frameworks).
+
+        Args:
+            gradient_checkpointing_func (`Callable`, *optional*):
+                The function to use for gradient checkpointing. If `None`, the default PyTorch checkpointing function
+                is used (`torch.utils.checkpoint.checkpoint`).
         """
         if not self._supports_gradient_checkpointing:
-            raise ValueError(
-
+            raise ValueError(
+                f"{self.__class__.__name__} does not support gradient checkpointing. Please make sure to set the boolean attribute "
+                f"`_supports_gradient_checkpointing` to `True` in the class definition."
+            )
+
+        if gradient_checkpointing_func is None:
+
+            def _gradient_checkpointing_func(module, *args):
+                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                return torch.utils.checkpoint.checkpoint(
+                    module.__call__,
+                    *args,
+                    **ckpt_kwargs,
+                )
+
+            gradient_checkpointing_func = _gradient_checkpointing_func
+
+        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)

     def disable_gradient_checkpointing(self) -> None:
         """
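The reworked `enable_gradient_checkpointing` accepts an optional user-supplied checkpointing function with the signature `func(module, *args)`. A hedged usage sketch; the checkpoint id is only an example, and any `ModelMixin` subclass that sets `_supports_gradient_checkpointing = True` works the same way:

```python
import torch
import torch.utils.checkpoint
from diffusers import UNet2DConditionModel

# Example checkpoint; substitute any model that supports gradient checkpointing.
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)

# Default: checkpointed blocks are wrapped with torch.utils.checkpoint.checkpoint.
unet.enable_gradient_checkpointing()


# New in 0.33.0: supply a custom function. It receives the module followed by its
# positional arguments, mirroring the default implementation shown in the diff above.
def my_checkpoint_fn(module, *args):
    return torch.utils.checkpoint.checkpoint(module.__call__, *args, use_reentrant=False)


unet.enable_gradient_checkpointing(gradient_checkpointing_func=my_checkpoint_fn)
```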
@@ -194,7 +333,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         *checkpoint activations* in other frameworks).
         """
         if self._supports_gradient_checkpointing:
-            self.
+            self._set_gradient_checkpointing(enable=False)

     def set_use_npu_flash_attention(self, valid: bool) -> None:
         r"""
@@ -227,14 +366,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         self.set_use_npu_flash_attention(False)

     def set_use_xla_flash_attention(
-        self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None, **kwargs
     ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_xla_flash_attention method
         # gets the message
         def fn_recursive_set_flash_attention(module: torch.nn.Module):
             if hasattr(module, "set_use_xla_flash_attention"):
-                module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec)
+                module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec, **kwargs)

             for child in module.children():
                 fn_recursive_set_flash_attention(child)
@@ -243,11 +382,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         if isinstance(module, torch.nn.Module):
             fn_recursive_set_flash_attention(module)

-    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
+    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None, **kwargs):
         r"""
         Enable the flash attention pallals kernel for torch_xla.
         """
-        self.set_use_xla_flash_attention(True, partition_spec)
+        self.set_use_xla_flash_attention(True, partition_spec, **kwargs)

     def disable_xla_flash_attention(self):
         r"""
@@ -314,6 +453,152 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         self.set_use_memory_efficient_attention_xformers(False)

+    def enable_layerwise_casting(
+        self,
+        storage_dtype: torch.dtype = torch.float8_e4m3fn,
+        compute_dtype: Optional[torch.dtype] = None,
+        skip_modules_pattern: Optional[Tuple[str, ...]] = None,
+        skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        r"""
+        Activates layerwise casting for the current model.
+
+        Layerwise casting is a technique that casts the model weights to a lower precision dtype for storage but
+        upcasts them on-the-fly to a higher precision dtype for computation. This process can significantly reduce the
+        memory footprint from model weights, but may lead to some quality degradation in the outputs. Most degradations
+        are negligible, mostly stemming from weight casting in normalization and modulation layers.
+
+        By default, most models in diffusers set the `_skip_layerwise_casting_patterns` attribute to ignore patch
+        embedding, positional embedding and normalization layers. This is because these layers are most likely
+        precision-critical for quality. If you wish to change this behavior, you can set the
+        `_skip_layerwise_casting_patterns` attribute to `None`, or call
+        [`~hooks.layerwise_casting.apply_layerwise_casting`] with custom arguments.
+
+        Example:
+            Using [`~models.ModelMixin.enable_layerwise_casting`]:
+
+            ```python
+            >>> from diffusers import CogVideoXTransformer3DModel
+
+            >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+            ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+            ... )
+
+            >>> # Enable layerwise casting via the model, which ignores certain modules by default
+            >>> transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+            ```
+
+        Args:
+            storage_dtype (`torch.dtype`):
+                The dtype to which the model should be cast for storage.
+            compute_dtype (`torch.dtype`):
+                The dtype to which the model weights should be cast during the forward pass.
+            skip_modules_pattern (`Tuple[str, ...]`, *optional*):
+                A list of patterns to match the names of the modules to skip during the layerwise casting process. If
+                set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT
+                layers.
+            skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
+                A list of module classes to skip during the layerwise casting process.
+            non_blocking (`bool`, *optional*, defaults to `False`):
+                If `True`, the weight casting operations are non-blocking.
+        """
+        from ..hooks import apply_layerwise_casting
+
+        user_provided_patterns = True
+        if skip_modules_pattern is None:
+            from ..hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
+
+            skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
+            user_provided_patterns = False
+        if self._keep_in_fp32_modules is not None:
+            skip_modules_pattern += tuple(self._keep_in_fp32_modules)
+        if self._skip_layerwise_casting_patterns is not None:
+            skip_modules_pattern += tuple(self._skip_layerwise_casting_patterns)
+        skip_modules_pattern = tuple(set(skip_modules_pattern))
+
+        if is_peft_available() and not user_provided_patterns:
+            # By default, we want to skip all peft layers because they have a very low memory footprint.
+            # If users want to apply layerwise casting on peft layers as well, they can utilize the
+            # `~diffusers.hooks.layerwise_casting.apply_layerwise_casting` function which provides
+            # them with more flexibility and control.
+
+            from peft.tuners.loha.layer import LoHaLayer
+            from peft.tuners.lokr.layer import LoKrLayer
+            from peft.tuners.lora.layer import LoraLayer
+
+            for layer in (LoHaLayer, LoKrLayer, LoraLayer):
+                skip_modules_pattern += tuple(layer.adapter_layer_names)
+
+        if compute_dtype is None:
+            logger.info("`compute_dtype` not provided when enabling layerwise casting. Using dtype of the model.")
+            compute_dtype = self.dtype
+
+        apply_layerwise_casting(
+            self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
+        )
+
+    def enable_group_offload(
+        self,
+        onload_device: torch.device,
+        offload_device: torch.device = torch.device("cpu"),
+        offload_type: str = "block_level",
+        num_blocks_per_group: Optional[int] = None,
+        non_blocking: bool = False,
+        use_stream: bool = False,
+        record_stream: bool = False,
+        low_cpu_mem_usage=False,
+    ) -> None:
+        r"""
+        Activates group offloading for the current model.
+
+        See [`~hooks.group_offloading.apply_group_offloading`] for more information.
+
+        Example:
+
+            ```python
+            >>> from diffusers import CogVideoXTransformer3DModel
+
+            >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+            ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+            ... )
+
+            >>> transformer.enable_group_offload(
+            ...     onload_device=torch.device("cuda"),
+            ...     offload_device=torch.device("cpu"),
+            ...     offload_type="leaf_level",
+            ...     use_stream=True,
+            ... )
+            ```
+        """
+        from ..hooks import apply_group_offloading
+
+        if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
+            msg = (
+                "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
+                "forward pass is executed with tiling enabled. Please make sure to either:\n"
+                "1. Run a forward pass with small input shapes.\n"
+                "2. Or, run a forward pass with tiling disabled (can still use small dummy inputs)."
+            )
+            logger.warning(msg)
+        if not self._supports_group_offloading:
+            raise ValueError(
+                f"{self.__class__.__name__} does not support group offloading. Please make sure to set the boolean attribute "
+                f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
+                f"open an issue at https://github.com/huggingface/diffusers/issues."
+            )
+        apply_group_offloading(
+            self,
+            onload_device,
+            offload_device,
+            offload_type,
+            num_blocks_per_group,
+            non_blocking,
+            use_stream,
+            record_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
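Both new methods can also be combined on the sub-models of a full pipeline. A sketch assuming a CUDA device and the same CogVideoX checkpoint used in the docstrings above; the exact combination shown is illustrative, not a prescribed recipe:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Store transformer weights in float8 and upcast to bfloat16 on the fly during forward.
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)

# Keep the diffusers sub-models on CPU and stream groups of layers onto the GPU as needed.
onload, offload = torch.device("cuda"), torch.device("cpu")
pipe.transformer.enable_group_offload(onload, offload, offload_type="leaf_level", use_stream=True)
pipe.vae.enable_group_offload(onload, offload, offload_type="leaf_level", use_stream=True)

# The text encoder is a transformers model, so it goes through the functional hook API instead.
apply_group_offloading(pipe.text_encoder, onload, offload, offload_type="leaf_level", use_stream=True)

video = pipe("a panda playing a guitar", num_frames=16).frames[0]
```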
@@ -426,7 +711,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                        os.remove(full_filename)

            for filename, tensors in state_dict_split.filename_to_tensors.items():
-                shard = {tensor: state_dict[tensor] for tensor in tensors}
+                shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
                filepath = os.path.join(save_directory, filename)
                if safe_serialization:
                    # At some point we will need to deal better with save_function (used for TPU and other distributed
@@ -483,7 +768,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

    @classmethod
    @validate_hf_hub_args
-    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self:
        r"""
        Instantiate a pretrained PyTorch model from a pretrained model configuration.

@@ -559,6 +844,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                weights. If set to `False`, `safetensors` weights are not loaded.
+            disable_mmap ('bool', *optional*, defaults to 'False'):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.

        <Tip>

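A short sketch of the new loading kwarg; the VAE checkpoint id is only an example:

```python
from diffusers import AutoencoderKL

# disable_mmap=True (new in 0.33.0) reads the safetensors file into memory up front,
# which can be faster than mmap'd access when the weights live on a network mount or
# a spinning disk.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", disable_mmap=True)
```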
@@ -599,11 +887,19 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        device_map = kwargs.pop("device_map", None)
        max_memory = kwargs.pop("max_memory", None)
        offload_folder = kwargs.pop("offload_folder", None)
-        offload_state_dict = kwargs.pop("offload_state_dict",
+        offload_state_dict = kwargs.pop("offload_state_dict", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        quantization_config = kwargs.pop("quantization_config", None)
+        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
+
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            torch_dtype = torch.float32
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )

        allow_pickle = False
        if use_safetensors is None:
@@ -674,14 +970,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
            raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")

-        # Load config if we don't provide a configuration
-        config_path = pretrained_model_name_or_path
-
        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        }
+        unused_kwargs = {}
+
+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path

        # load config
        config, unused_kwargs, commit_hash = cls.load_config(
@@ -696,6 +993,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
+            dduf_entries=dduf_entries,
            **kwargs,
        )
        # no in-place modification of the original config.
@@ -718,13 +1016,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        hf_quantizer = None

        if hf_quantizer is not None:
-            if device_map is not None:
-                raise NotImplementedError(
-                    "Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
-                )
-
            hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+            device_map = hf_quantizer.update_device_map(device_map)

            # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
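With the `NotImplementedError` guard removed and `update_device_map` in place, a `device_map` can now be passed together with a quantization config. A hedged sketch, assuming `bitsandbytes` is installed; the checkpoint id and the single-device map are only examples:

```python
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda",  # previously rejected for quantized models; now routed through the quantizer
)
```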
@@ -737,9 +1031,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")

        # Check if `_keep_in_fp32_modules` is not None
-        use_keep_in_fp32_modules =
-
+        use_keep_in_fp32_modules = cls._keep_in_fp32_modules is not None and (
+            hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
        )
+
        if use_keep_in_fp32_modules:
            keep_in_fp32_modules = cls._keep_in_fp32_modules
            if not isinstance(keep_in_fp32_modules, list):
@@ -752,10 +1047,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
        else:
            keep_in_fp32_modules = []
-        #######################################

-        # Determine if we're loading from a directory of sharded checkpoints.
        is_sharded = False
+        resolved_model_file = None
+
+        # Determine if we're loading from a directory of sharded checkpoints.
+        sharded_metadata = None
        index_file = None
        is_local = os.path.isdir(pretrained_model_name_or_path)
        index_file_kwargs = {
@@ -772,22 +1069,22 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            "revision": revision,
            "user_agent": user_agent,
            "commit_hash": commit_hash,
+            "dduf_entries": dduf_entries,
        }
        index_file = _fetch_index_file(**index_file_kwargs)
        # In case the index file was not found we still have to consider the legacy format.
        # this becomes applicable when the variant is not None.
        if variant is not None and (index_file is None or not os.path.exists(index_file)):
            index_file = _fetch_index_file_legacy(**index_file_kwargs)
-        if index_file is not None and index_file.is_file():
+        if index_file is not None and (dduf_entries or index_file.is_file()):
            is_sharded = True

        if is_sharded and from_flax:
            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")

        # load model
-        model_file = None
        if from_flax:
-
+            resolved_model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=FLAX_WEIGHTS_NAME,
                cache_dir=cache_dir,
@@ -805,10 +1102,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            # Convert the weights
            from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

-            model = load_flax_checkpoint_in_pytorch_model(model,
+            model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
        else:
+            # in the case it is sharded, we have already the index
            if is_sharded:
-
+                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
                    pretrained_model_name_or_path,
                    index_file,
                    cache_dir=cache_dir,
@@ -818,16 +1116,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                    user_agent=user_agent,
                    revision=revision,
                    subfolder=subfolder or "",
+                    dduf_entries=dduf_entries,
                )
-
-            if hf_quantizer is not None:
-                model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
-                logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
-                is_sharded = False
-
-            elif use_safetensors and not is_sharded:
+            elif use_safetensors:
                try:
-
+                    resolved_model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                        cache_dir=cache_dir,
@@ -839,6 +1132,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                        subfolder=subfolder,
                        user_agent=user_agent,
                        commit_hash=commit_hash,
+                        dduf_entries=dduf_entries,
                    )

                except IOError as e:
@@ -849,8 +1143,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
                    )

-            if
-
+            if resolved_model_file is None and not is_sharded:
+                resolved_model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(WEIGHTS_NAME, variant),
                    cache_dir=cache_dir,
@@ -862,156 +1156,107 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 subfolder=subfolder,
                 user_agent=user_agent,
                 commit_hash=commit_hash,
+                dduf_entries=dduf_entries,
             )

-
-
-            with accelerate.init_empty_weights():
-                model = cls.from_config(config, **unused_kwargs)
+        if not isinstance(resolved_model_file, list):
+            resolved_model_file = [resolved_model_file]

-
-
-
-
+        # set dtype to instantiate the model under:
+        # 1. If torch_dtype is not None, we use that dtype
+        # 2. If torch_dtype is float8, we don't use _set_default_torch_dtype and we downcast after loading the model
+        dtype_orig = None
+        if torch_dtype is not None and not torch_dtype == getattr(torch, "float8_e4m3fn", None):
+            if not isinstance(torch_dtype, torch.dtype):
+                raise ValueError(
+                    f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+                )
+            dtype_orig = cls._set_default_torch_dtype(torch_dtype)

-
-            if device_map is None and not is_sharded:
-                # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
-                # It would error out during the `validate_environment()` call above in the absence of cuda.
-                if hf_quantizer is None:
-                    param_device = "cpu"
-                # TODO (sayakpaul, SunMarc): remove this after model loading refactor
-                else:
-                    param_device = torch.device(torch.cuda.current_device())
-                state_dict = load_state_dict(model_file, variant=variant)
-                model._convert_deprecated_attention_blocks(state_dict)
-
-                # move the params from meta device to cpu
-                missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
-                if hf_quantizer is not None:
-                    missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
-                if len(missing_keys) > 0:
-                    raise ValueError(
-                        f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
-                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
-                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
-                        " those weights or else make sure your checkpoint file is correct."
-                    )
+        init_contexts = [no_init_weights()]

-
-
-                    state_dict,
-                    device=param_device,
-                    dtype=torch_dtype,
-                    model_name_or_path=pretrained_model_name_or_path,
-                    hf_quantizer=hf_quantizer,
-                    keep_in_fp32_modules=keep_in_fp32_modules,
-                )
+        if low_cpu_mem_usage:
+            init_contexts.append(accelerate.init_empty_weights())

-
-
-                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+        with ContextManagers(init_contexts):
+            model = cls.from_config(config, **unused_kwargs)

-
-
-                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-                    )
+        if dtype_orig is not None:
+            torch.set_default_dtype(dtype_orig)

-
-
-
-
-
-
-                )
-                if device_map is None and is_sharded:
-                    # we load the parameters on the cpu
-                    device_map = {"": "cpu"}
-                    force_hook = False
-                try:
-                    accelerate.load_checkpoint_and_dispatch(
-                        model,
-                        model_file if not is_sharded else index_file,
-                        device_map,
-                        max_memory=max_memory,
-                        offload_folder=offload_folder,
-                        offload_state_dict=offload_state_dict,
-                        dtype=torch_dtype,
-                        force_hooks=force_hook,
-                        strict=True,
-                    )
-                except AttributeError as e:
-                    # When using accelerate loading, we do not have the ability to load the state
-                    # dict and rename the weight names manually. Additionally, accelerate skips
-                    # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
-                    # (which look like they should be private variables?), so we can't use the standard hooks
-                    # to rename parameters on load. We need to mimic the original weight names so the correct
-                    # attributes are available. After we have loaded the weights, we convert the deprecated
-                    # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
-                    # the weights so we don't have to do this again.
-
-                    if "'Attention' object has no attribute" in str(e):
-                        logger.warning(
-                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
-                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
-                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
-                            " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
-                            " please also re-upload it or open a PR on the original repository."
-                        )
-                        model._temp_convert_self_to_deprecated_attention_blocks()
-                        accelerate.load_checkpoint_and_dispatch(
-                            model,
-                            model_file if not is_sharded else index_file,
-                            device_map,
-                            max_memory=max_memory,
-                            offload_folder=offload_folder,
-                            offload_state_dict=offload_state_dict,
-                            dtype=torch_dtype,
-                            force_hooks=force_hook,
-                            strict=True,
-                        )
-                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
-                    else:
-                        raise e
-
-                loading_info = {
-                    "missing_keys": [],
-                    "unexpected_keys": [],
-                    "mismatched_keys": [],
-                    "error_msgs": [],
-                }
-        else:
-            model = cls.from_config(config, **unused_kwargs)
+        state_dict = None
+        if not is_sharded:
+            # Time to load the checkpoint
+            state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries)
+            # We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
+            model._fix_state_dict_keys_on_load(state_dict)

-
-
+        if is_sharded:
+            loaded_keys = sharded_metadata["all_checkpoint_keys"]
+        else:
+            loaded_keys = list(state_dict.keys())

-
-
-
-
-
-
-
+        if hf_quantizer is not None:
+            hf_quantizer.preprocess_model(
+                model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+            )
+
+        # Now that the model is loaded, we can determine the device_map
+        device_map = _determine_device_map(
+            model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+        )
+        if hf_quantizer is not None:
+            hf_quantizer.validate_environment(device_map=device_map)
+
+        (
+            model,
+            missing_keys,
+            unexpected_keys,
+            mismatched_keys,
+            offload_index,
+            error_msgs,
+        ) = cls._load_pretrained_model(
+            model,
+            state_dict,
+            resolved_model_file,
+            pretrained_model_name_or_path,
+            loaded_keys,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            device_map=device_map,
+            offload_folder=offload_folder,
+            offload_state_dict=offload_state_dict,
+            dtype=torch_dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            dduf_entries=dduf_entries,
+        )
+        loading_info = {
+            "missing_keys": missing_keys,
+            "unexpected_keys": unexpected_keys,
+            "mismatched_keys": mismatched_keys,
+            "error_msgs": error_msgs,
+        }

-
-
-
-
-
-
+        # Dispatch model with hooks on all devices if necessary
+        if device_map is not None:
+            device_map_kwargs = {
+                "device_map": device_map,
+                "offload_dir": offload_folder,
+                "offload_index": offload_index,
+            }
+            dispatch_model(model, **device_map_kwargs)

         if hf_quantizer is not None:
             hf_quantizer.postprocess_model(model)
             model.hf_quantizer = hf_quantizer

-        if
-
-
-
-
-
-        elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
+        if (
+            torch_dtype is not None
+            and torch_dtype == getattr(torch, "float8_e4m3fn", None)
+            and hf_quantizer is None
+            and not use_keep_in_fp32_modules
+        ):
             model = model.to(torch_dtype)

         if hf_quantizer is not None:
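Note: a minimal sketch of the caller-facing effect of the refactored loading path above. With a non-float8 `torch_dtype`, the model is built under that default dtype (via `_set_default_torch_dtype`), created under `init_empty_weights` when `low_cpu_mem_usage` is enabled, and then materialized by `_load_pretrained_model`. The checkpoint below is only an example:

```python
import torch
from diffusers import AutoencoderKL

# Example checkpoint; any ModelMixin subclass goes through the same path.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse",
    torch_dtype=torch.float16,   # parameters are instantiated directly in fp16
    low_cpu_mem_usage=True,      # weights are materialized from the checkpoint, not initialized twice
)
print(vae.dtype)  # torch.float16
```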
@@ -1023,6 +1268,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
+
         if output_loading_info:
             return model, loading_info

@@ -1031,6 +1277,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     # Adapted from `transformers`.
     @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
+        from ..hooks.group_offloading import _is_group_offload_enabled
+
         # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
             if getattr(self, "is_loaded_in_8bit", False):
@@ -1043,13 +1291,34 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                 )
+
+        # Checks if group offloading is enabled
+        if _is_group_offload_enabled(self):
+            logger.warning(
+                f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.cuda()` is not supported."
+            )
+            return self
+
         return super().cuda(*args, **kwargs)

     # Adapted from `transformers`.
     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
+        from ..hooks.group_offloading import _is_group_offload_enabled
+
+        device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
         dtype_present_in_args = "dtype" in kwargs

+        # Try converting arguments to torch.device in case they are passed as strings
+        for arg in args:
+            if not isinstance(arg, str):
+                continue
+            try:
+                torch.device(arg)
+                device_arg_or_kwarg_present = True
+            except RuntimeError:
+                pass
+
         if not dtype_present_in_args:
             for arg in args:
                 if isinstance(arg, torch.dtype):
@@ -1074,6 +1343,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                 )
+
+        if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
+            logger.warning(
+                f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
+            )
+            return self
+
         return super().to(*args, **kwargs)

     # Taken from `transformers`.
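Note: `.cuda()` and `.to()` now return the model unchanged (with a warning) when group offloading manages device placement. A sketch assuming the `enable_group_offload` helper added alongside diffusers/hooks/group_offloading.py in this release; the exact signature may differ:

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
)

# Assumed API: groups of blocks live on `offload_device` and are moved to
# `onload_device` only while they run.
unet.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)

# Because _is_group_offload_enabled(unet) is now True, this only logs a warning and
# returns the module as-is instead of moving it.
unet.to("cuda")
```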
@@ -1103,54 +1379,127 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         cls,
         model,
         state_dict: OrderedDict,
-
+        resolved_model_file: List[str],
         pretrained_model_name_or_path: Union[str, os.PathLike],
+        loaded_keys: List[str],
         ignore_mismatched_sizes: bool = False,
+        assign_to_params_buffers: bool = False,
+        hf_quantizer: Optional[DiffusersQuantizer] = None,
+        low_cpu_mem_usage: bool = True,
+        dtype: Optional[Union[str, torch.dtype]] = None,
+        keep_in_fp32_modules: Optional[List[str]] = None,
+        device_map: Dict[str, Union[int, str, torch.device]] = None,
+        offload_state_dict: Optional[bool] = None,
+        offload_folder: Optional[Union[str, os.PathLike]] = None,
+        dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
     ):
-        # Retrieve missing & unexpected_keys
         model_state_dict = model.state_dict()
-        loaded_keys = list(state_dict.keys())
-
         expected_keys = list(model_state_dict.keys())
-
-        original_loaded_keys = loaded_keys
-
         missing_keys = list(set(expected_keys) - set(loaded_keys))
+        if hf_quantizer is not None:
+            missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+        # Some models may have keys that are not in the state by design, removing them before needlessly warning
+        # the user.
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pat in cls._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

-
-        model_to_load = model
+        mismatched_keys = []

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        assign_to_params_buffers = None
+        error_msgs = []
+
+        # Deal with offload
+        if device_map is not None and "disk" in device_map.values():
+            if offload_folder is None:
+                raise ValueError(
+                    "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
+                    " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
+                    " offers the weights in this format."
+                )
+            if offload_folder is not None:
+                os.makedirs(offload_folder, exist_ok=True)
+            if offload_state_dict is None:
+                offload_state_dict = True
+
+        offload_index = {} if device_map is not None and "disk" in device_map.values() else None
+        if offload_state_dict:
+            state_dict_folder = tempfile.mkdtemp()
+            state_dict_index = {}
+        else:
+            state_dict_folder = None
+            state_dict_index = None

         if state_dict is not None:
-            #
-
+            # load_state_dict will manage the case where we pass a dict instead of a file
+            # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
+            resolved_model_file = [state_dict]
+
+        if len(resolved_model_file) > 1:
+            resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+
+        for shard_file in resolved_model_file:
+            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+
+            def _find_mismatched_keys(
                 state_dict,
                 model_state_dict,
-
+                loaded_keys,
+                ignore_mismatched_sizes,
+            ):
+                mismatched_keys = []
+                if ignore_mismatched_sizes:
+                    for checkpoint_key in loaded_keys:
+                        model_key = checkpoint_key
+                        # If the checkpoint is sharded, we may not have the key here.
+                        if checkpoint_key not in state_dict:
+                            continue
+
+                        if (
+                            model_key in model_state_dict
+                            and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+                        ):
+                            mismatched_keys.append(
+                                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                            )
+                            del state_dict[checkpoint_key]
+                return mismatched_keys
+
+            mismatched_keys += _find_mismatched_keys(
+                state_dict,
+                model_state_dict,
+                loaded_keys,
                 ignore_mismatched_sizes,
             )
-
+
+            if low_cpu_mem_usage:
+                offload_index, state_dict_index = load_model_dict_into_meta(
+                    model,
+                    state_dict,
+                    device_map=device_map,
+                    dtype=dtype,
+                    hf_quantizer=hf_quantizer,
+                    keep_in_fp32_modules=keep_in_fp32_modules,
+                    unexpected_keys=unexpected_keys,
+                    offload_folder=offload_folder,
+                    offload_index=offload_index,
+                    state_dict_index=state_dict_index,
+                    state_dict_folder=state_dict_folder,
+                )
+            else:
+                if assign_to_params_buffers is None:
+                    assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+                error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+
+        if offload_index is not None and len(offload_index) > 0:
+            save_offload_index(offload_index, offload_folder)
+            offload_index = None
+
+        if offload_state_dict:
+            load_offloaded_weights(model, state_dict_index, state_dict_folder)
+            shutil.rmtree(state_dict_folder)

         if len(error_msgs) > 0:
             error_msg = "\n\t".join(error_msgs)
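Note: a standalone sketch of the per-shard shape check that `_find_mismatched_keys` performs above; the function and variable names here are illustrative, not part of the diffusers API. Keys whose shapes disagree with the model are recorded and dropped so they are re-initialized instead of raising a size-mismatch error when `ignore_mismatched_sizes=True`:

```python
from typing import Dict, List, Tuple

import torch


def find_mismatched_keys(
    checkpoint: Dict[str, torch.Tensor], model_state: Dict[str, torch.Tensor]
) -> List[Tuple[str, torch.Size, torch.Size]]:
    mismatched = []
    for key in list(checkpoint.keys()):
        if key in model_state and checkpoint[key].shape != model_state[key].shape:
            # Record the offending key and drop it so loading falls back to the
            # model's own (randomly initialized) parameter for that entry.
            mismatched.append((key, checkpoint[key].shape, model_state[key].shape))
            del checkpoint[key]
    return mismatched


ckpt = {"proj.weight": torch.zeros(4, 8)}
model_sd = {"proj.weight": torch.zeros(4, 4)}
print(find_mismatched_keys(ckpt, model_sd))  # [('proj.weight', torch.Size([4, 8]), torch.Size([4, 4]))]
```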
@@ -1162,17 +1511,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

         if len(unexpected_keys) > 0:
             logger.warning(
-                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
-                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
-                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
-                " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
-                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
-                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
-                " identical (initializing a BertForSequenceClassification model from a"
-                " BertForSequenceClassification model)."
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
             )
         else:
             logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+
         if len(missing_keys) > 0:
             logger.warning(
                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
@@ -1200,7 +1543,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 " able to use it for predictions and inference."
             )

-        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+        return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs

     @classmethod
     def _get_signature_keys(cls, obj):
@@ -1214,7 +1557,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     # Adapted from `transformers` modeling_utils.py
     def _get_no_split_modules(self, device_map: str):
         """
-        Get the modules of the model that should not be
+        Get the modules of the model that should not be split when using device_map. We iterate through the modules to
         get the underlying `_no_split_modules`.

         Args:
@@ -1241,6 +1584,33 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 modules_to_check += list(module.children())
         return list(_no_split_modules)

+    @classmethod
+    def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
+        """
+        Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
+        under specific dtype.
+
+        Args:
+            dtype (`torch.dtype`):
+                a floating dtype to set to.
+
+        Returns:
+            `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
+            modified. If it wasn't, returns `None`.
+
+        Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
+        `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
+        """
+        if not dtype.is_floating_point:
+            raise ValueError(
+                f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
+            )
+
+        logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
+        dtype_orig = torch.get_default_dtype()
+        torch.set_default_dtype(dtype)
+        return dtype_orig
+
     @property
     def device(self) -> torch.device:
         """
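Note: a minimal sketch of the pattern `_set_default_torch_dtype` enables in `from_pretrained`: temporarily switch the global default floating dtype so newly constructed parameters are created directly in the target dtype, then restore the previous default. Only plain torch APIs are used here:

```python
import torch

dtype_orig = torch.get_default_dtype()
torch.set_default_dtype(torch.float16)
try:
    # Parameters are created directly in float16, no post-hoc .to() cast needed.
    layer = torch.nn.Linear(4, 4)
    assert layer.weight.dtype == torch.float16
finally:
    # Always restore the previous default, mirroring torch.set_default_dtype(dtype_orig)
    # in the loading code above.
    torch.set_default_dtype(dtype_orig)
```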
@@ -1338,7 +1708,31 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             mem = mem + mem_bufs
         return mem

-    def
+    def _set_gradient_checkpointing(
+        self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
+    ) -> None:
+        is_gradient_checkpointing_set = False
+
+        for name, module in self.named_modules():
+            if hasattr(module, "gradient_checkpointing"):
+                logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
+                module._gradient_checkpointing_func = gradient_checkpointing_func
+                module.gradient_checkpointing = enable
+                is_gradient_checkpointing_set = True
+
+        if not is_gradient_checkpointing_set:
+            raise ValueError(
+                f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to "
+                f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
+            )
+
+    def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None:
+        """
+        This function fix the state dict of the model to take into account some changes that were made in the model
+        architecture:
+        - deprecated attention blocks (happened before we introduced sharded checkpoint,
+          so this is why we apply this method only when loading non sharded checkpoints for now)
+        """
         deprecated_attention_block_paths = []

         def recursive_find_attn_block(name, module):
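Note: `_set_gradient_checkpointing` is the internal hook behind the public toggle; it walks `named_modules()` and flips the `gradient_checkpointing` flag on every submodule that defines it. A short usage sketch (the checkpoint is just an example):

```python
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# Public entry point; internally it ends up setting `gradient_checkpointing = True`
# on the encoder/decoder blocks that expose the attribute.
vae.enable_gradient_checkpointing()
assert vae.is_gradient_checkpointing
```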
@@ -1381,56 +1775,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
             if f"{path}.proj_attn.bias" in state_dict:
                 state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
-
-    def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
-        deprecated_attention_block_modules = []
-
-        def recursive_find_attn_block(module):
-            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
-                deprecated_attention_block_modules.append(module)
-
-            for sub_module in module.children():
-                recursive_find_attn_block(sub_module)
-
-        recursive_find_attn_block(self)
-
-        for module in deprecated_attention_block_modules:
-            module.query = module.to_q
-            module.key = module.to_k
-            module.value = module.to_v
-            module.proj_attn = module.to_out[0]
-
-            # We don't _have_ to delete the old attributes, but it's helpful to ensure
-            # that _all_ the weights are loaded into the new attributes and we're not
-            # making an incorrect assumption that this model should be converted when
-            # it really shouldn't be.
-            del module.to_q
-            del module.to_k
-            del module.to_v
-            del module.to_out
-
-    def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
-        deprecated_attention_block_modules = []
-
-        def recursive_find_attn_block(module) -> None:
-            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
-                deprecated_attention_block_modules.append(module)
-
-            for sub_module in module.children():
-                recursive_find_attn_block(sub_module)
-
-        recursive_find_attn_block(self)
-
-        for module in deprecated_attention_block_modules:
-            module.to_q = module.query
-            module.to_k = module.key
-            module.to_v = module.value
-            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
-
-            del module.query
-            del module.key
-            del module.value
-            del module.proj_attn
+        return state_dict


 class LegacyModelMixin(ModelMixin):