diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -92,6 +92,21 @@ class AnimateDiffPipeline(metaclass=DummyObject):
|
|
92
92
|
requires_backends(cls, ["torch", "transformers"])
|
93
93
|
|
94
94
|
|
95
|
+
class AnimateDiffSDXLPipeline(metaclass=DummyObject):
|
96
|
+
_backends = ["torch", "transformers"]
|
97
|
+
|
98
|
+
def __init__(self, *args, **kwargs):
|
99
|
+
requires_backends(self, ["torch", "transformers"])
|
100
|
+
|
101
|
+
@classmethod
|
102
|
+
def from_config(cls, *args, **kwargs):
|
103
|
+
requires_backends(cls, ["torch", "transformers"])
|
104
|
+
|
105
|
+
@classmethod
|
106
|
+
def from_pretrained(cls, *args, **kwargs):
|
107
|
+
requires_backends(cls, ["torch", "transformers"])
|
108
|
+
|
109
|
+
|
95
110
|
class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
|
96
111
|
_backends = ["torch", "transformers"]
|
97
112
|
|
@@ -677,6 +692,36 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
|
|
677
692
|
requires_backends(cls, ["torch", "transformers"])
|
678
693
|
|
679
694
|
|
695
|
+
class MarigoldDepthPipeline(metaclass=DummyObject):
|
696
|
+
_backends = ["torch", "transformers"]
|
697
|
+
|
698
|
+
def __init__(self, *args, **kwargs):
|
699
|
+
requires_backends(self, ["torch", "transformers"])
|
700
|
+
|
701
|
+
@classmethod
|
702
|
+
def from_config(cls, *args, **kwargs):
|
703
|
+
requires_backends(cls, ["torch", "transformers"])
|
704
|
+
|
705
|
+
@classmethod
|
706
|
+
def from_pretrained(cls, *args, **kwargs):
|
707
|
+
requires_backends(cls, ["torch", "transformers"])
|
708
|
+
|
709
|
+
|
710
|
+
class MarigoldNormalsPipeline(metaclass=DummyObject):
|
711
|
+
_backends = ["torch", "transformers"]
|
712
|
+
|
713
|
+
def __init__(self, *args, **kwargs):
|
714
|
+
requires_backends(self, ["torch", "transformers"])
|
715
|
+
|
716
|
+
@classmethod
|
717
|
+
def from_config(cls, *args, **kwargs):
|
718
|
+
requires_backends(cls, ["torch", "transformers"])
|
719
|
+
|
720
|
+
@classmethod
|
721
|
+
def from_pretrained(cls, *args, **kwargs):
|
722
|
+
requires_backends(cls, ["torch", "transformers"])
|
723
|
+
|
724
|
+
|
680
725
|
class MusicLDMPipeline(metaclass=DummyObject):
|
681
726
|
_backends = ["torch", "transformers"]
|
682
727
|
|
@@ -737,6 +782,21 @@ class PixArtAlphaPipeline(metaclass=DummyObject):
|
|
737
782
|
requires_backends(cls, ["torch", "transformers"])
|
738
783
|
|
739
784
|
|
785
|
+
class PixArtSigmaPipeline(metaclass=DummyObject):
|
786
|
+
_backends = ["torch", "transformers"]
|
787
|
+
|
788
|
+
def __init__(self, *args, **kwargs):
|
789
|
+
requires_backends(self, ["torch", "transformers"])
|
790
|
+
|
791
|
+
@classmethod
|
792
|
+
def from_config(cls, *args, **kwargs):
|
793
|
+
requires_backends(cls, ["torch", "transformers"])
|
794
|
+
|
795
|
+
@classmethod
|
796
|
+
def from_pretrained(cls, *args, **kwargs):
|
797
|
+
requires_backends(cls, ["torch", "transformers"])
|
798
|
+
|
799
|
+
|
740
800
|
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
|
741
801
|
_backends = ["torch", "transformers"]
|
742
802
|
|
@@ -902,6 +962,21 @@ class StableDiffusionControlNetPipeline(metaclass=DummyObject):
|
|
902
962
|
requires_backends(cls, ["torch", "transformers"])
|
903
963
|
|
904
964
|
|
965
|
+
class StableDiffusionControlNetXSPipeline(metaclass=DummyObject):
|
966
|
+
_backends = ["torch", "transformers"]
|
967
|
+
|
968
|
+
def __init__(self, *args, **kwargs):
|
969
|
+
requires_backends(self, ["torch", "transformers"])
|
970
|
+
|
971
|
+
@classmethod
|
972
|
+
def from_config(cls, *args, **kwargs):
|
973
|
+
requires_backends(cls, ["torch", "transformers"])
|
974
|
+
|
975
|
+
@classmethod
|
976
|
+
def from_pretrained(cls, *args, **kwargs):
|
977
|
+
requires_backends(cls, ["torch", "transformers"])
|
978
|
+
|
979
|
+
|
905
980
|
class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject):
|
906
981
|
_backends = ["torch", "transformers"]
|
907
982
|
|
@@ -1247,6 +1322,21 @@ class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
|
|
1247
1322
|
requires_backends(cls, ["torch", "transformers"])
|
1248
1323
|
|
1249
1324
|
|
1325
|
+
class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject):
|
1326
|
+
_backends = ["torch", "transformers"]
|
1327
|
+
|
1328
|
+
def __init__(self, *args, **kwargs):
|
1329
|
+
requires_backends(self, ["torch", "transformers"])
|
1330
|
+
|
1331
|
+
@classmethod
|
1332
|
+
def from_config(cls, *args, **kwargs):
|
1333
|
+
requires_backends(cls, ["torch", "transformers"])
|
1334
|
+
|
1335
|
+
@classmethod
|
1336
|
+
def from_pretrained(cls, *args, **kwargs):
|
1337
|
+
requires_backends(cls, ["torch", "transformers"])
|
1338
|
+
|
1339
|
+
|
1250
1340
|
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
|
1251
1341
|
_backends = ["torch", "transformers"]
|
1252
1342
|
|
@@ -201,7 +201,7 @@ def get_cached_module_file(
|
|
201
201
|
module_file: str,
|
202
202
|
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
203
203
|
force_download: bool = False,
|
204
|
-
resume_download: bool =
|
204
|
+
resume_download: Optional[bool] = None,
|
205
205
|
proxies: Optional[Dict[str, str]] = None,
|
206
206
|
token: Optional[Union[bool, str]] = None,
|
207
207
|
revision: Optional[str] = None,
|
@@ -228,9 +228,9 @@ def get_cached_module_file(
|
|
228
228
|
cache should not be used.
|
229
229
|
force_download (`bool`, *optional*, defaults to `False`):
|
230
230
|
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
231
|
-
exist.
|
232
|
-
|
233
|
-
|
231
|
+
exist. resume_download:
|
232
|
+
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
233
|
+
of Diffusers.
|
234
234
|
proxies (`Dict[str, str]`, *optional*):
|
235
235
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
236
236
|
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
@@ -246,8 +246,8 @@ def get_cached_module_file(
|
|
246
246
|
|
247
247
|
<Tip>
|
248
248
|
|
249
|
-
You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
|
250
|
-
|
249
|
+
You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
|
250
|
+
[gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
251
251
|
|
252
252
|
</Tip>
|
253
253
|
|
@@ -329,6 +329,11 @@ def get_cached_module_file(
|
|
329
329
|
# The only reason we do the copy is to avoid putting too many folders in sys.path.
|
330
330
|
shutil.copy(resolved_module_file, submodule_path / module_file)
|
331
331
|
for module_needed in modules_needed:
|
332
|
+
if len(module_needed.split(".")) == 2:
|
333
|
+
module_needed = "/".join(module_needed.split("."))
|
334
|
+
module_folder = module_needed.split("/")[0]
|
335
|
+
if not os.path.exists(submodule_path / module_folder):
|
336
|
+
os.makedirs(submodule_path / module_folder)
|
332
337
|
module_needed = f"{module_needed}.py"
|
333
338
|
shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
|
334
339
|
else:
|
@@ -343,9 +348,16 @@ def get_cached_module_file(
|
|
343
348
|
create_dynamic_module(full_submodule)
|
344
349
|
|
345
350
|
if not (submodule_path / module_file).exists():
|
351
|
+
if len(module_file.split("/")) == 2:
|
352
|
+
module_folder = module_file.split("/")[0]
|
353
|
+
if not os.path.exists(submodule_path / module_folder):
|
354
|
+
os.makedirs(submodule_path / module_folder)
|
346
355
|
shutil.copy(resolved_module_file, submodule_path / module_file)
|
356
|
+
|
347
357
|
# Make sure we also have every file with relative
|
348
358
|
for module_needed in modules_needed:
|
359
|
+
if len(module_needed.split(".")) == 2:
|
360
|
+
module_needed = "/".join(module_needed.split("."))
|
349
361
|
if not (submodule_path / module_needed).exists():
|
350
362
|
get_cached_module_file(
|
351
363
|
pretrained_model_name_or_path,
|
@@ -368,7 +380,7 @@ def get_class_from_dynamic_module(
|
|
368
380
|
class_name: Optional[str] = None,
|
369
381
|
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
370
382
|
force_download: bool = False,
|
371
|
-
resume_download: bool =
|
383
|
+
resume_download: Optional[bool] = None,
|
372
384
|
proxies: Optional[Dict[str, str]] = None,
|
373
385
|
token: Optional[Union[bool, str]] = None,
|
374
386
|
revision: Optional[str] = None,
|
@@ -405,8 +417,9 @@ def get_class_from_dynamic_module(
|
|
405
417
|
force_download (`bool`, *optional*, defaults to `False`):
|
406
418
|
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
407
419
|
exist.
|
408
|
-
resume_download
|
409
|
-
|
420
|
+
resume_download:
|
421
|
+
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 of
|
422
|
+
Diffusers.
|
410
423
|
proxies (`Dict[str, str]`, *optional*):
|
411
424
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
412
425
|
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
@@ -422,8 +435,8 @@ def get_class_from_dynamic_module(
|
|
422
435
|
|
423
436
|
<Tip>
|
424
437
|
|
425
|
-
You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
|
426
|
-
|
438
|
+
You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
|
439
|
+
[gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
427
440
|
|
428
441
|
</Tip>
|
429
442
|
|
diffusers/utils/hub_utils.py
CHANGED
@@ -112,7 +112,8 @@ def load_or_create_model_card(
|
|
112
112
|
repo_id_or_path (`str`):
|
113
113
|
The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card.
|
114
114
|
token (`str`, *optional*):
|
115
|
-
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
|
115
|
+
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
|
116
|
+
details.
|
116
117
|
is_pipeline (`bool`):
|
117
118
|
Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
|
118
119
|
from_training: (`bool`): Boolean flag to denote if the model card is being created from a training script.
|
@@ -282,7 +283,7 @@ def _get_model_file(
|
|
282
283
|
cache_dir: Optional[str] = None,
|
283
284
|
force_download: bool = False,
|
284
285
|
proxies: Optional[Dict] = None,
|
285
|
-
resume_download: bool =
|
286
|
+
resume_download: Optional[bool] = None,
|
286
287
|
local_files_only: bool = False,
|
287
288
|
token: Optional[str] = None,
|
288
289
|
user_agent: Optional[Union[Dict, str]] = None,
|
diffusers/utils/import_utils.py
CHANGED
@@ -295,6 +295,46 @@ try:
|
|
295
295
|
except importlib_metadata.PackageNotFoundError:
|
296
296
|
_torchvision_available = False
|
297
297
|
|
298
|
+
_matplotlib_available = importlib.util.find_spec("matplotlib") is not None
|
299
|
+
try:
|
300
|
+
_matplotlib_version = importlib_metadata.version("matplotlib")
|
301
|
+
logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
|
302
|
+
except importlib_metadata.PackageNotFoundError:
|
303
|
+
_matplotlib_available = False
|
304
|
+
|
305
|
+
_timm_available = importlib.util.find_spec("timm") is not None
|
306
|
+
if _timm_available:
|
307
|
+
try:
|
308
|
+
_timm_version = importlib_metadata.version("timm")
|
309
|
+
logger.info(f"Timm version {_timm_version} available.")
|
310
|
+
except importlib_metadata.PackageNotFoundError:
|
311
|
+
_timm_available = False
|
312
|
+
|
313
|
+
|
314
|
+
def is_timm_available():
|
315
|
+
return _timm_available
|
316
|
+
|
317
|
+
|
318
|
+
_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
|
319
|
+
try:
|
320
|
+
_bitsandbytes_version = importlib_metadata.version("bitsandbytes")
|
321
|
+
logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
|
322
|
+
except importlib_metadata.PackageNotFoundError:
|
323
|
+
_bitsandbytes_available = False
|
324
|
+
|
325
|
+
# Taken from `huggingface_hub`.
|
326
|
+
_is_notebook = False
|
327
|
+
try:
|
328
|
+
shell_class = get_ipython().__class__ # type: ignore # noqa: F821
|
329
|
+
for parent_class in shell_class.__mro__: # e.g. "is subclass of"
|
330
|
+
if parent_class.__name__ == "ZMQInteractiveShell":
|
331
|
+
_is_notebook = True # Jupyter notebook, Google colab or qtconsole
|
332
|
+
break
|
333
|
+
except NameError:
|
334
|
+
pass # Probably standard Python interpreter
|
335
|
+
|
336
|
+
_is_google_colab = "google.colab" in sys.modules
|
337
|
+
|
298
338
|
|
299
339
|
def is_torch_available():
|
300
340
|
return _torch_available
|
@@ -392,6 +432,26 @@ def is_torchvision_available():
|
|
392
432
|
return _torchvision_available
|
393
433
|
|
394
434
|
|
435
|
+
def is_matplotlib_available():
|
436
|
+
return _matplotlib_available
|
437
|
+
|
438
|
+
|
439
|
+
def is_safetensors_available():
|
440
|
+
return _safetensors_available
|
441
|
+
|
442
|
+
|
443
|
+
def is_bitsandbytes_available():
|
444
|
+
return _bitsandbytes_available
|
445
|
+
|
446
|
+
|
447
|
+
def is_notebook():
|
448
|
+
return _is_notebook
|
449
|
+
|
450
|
+
|
451
|
+
def is_google_colab():
|
452
|
+
return _is_google_colab
|
453
|
+
|
454
|
+
|
395
455
|
# docstyle-ignore
|
396
456
|
FLAX_IMPORT_ERROR = """
|
397
457
|
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
@@ -499,6 +559,20 @@ INVISIBLE_WATERMARK_IMPORT_ERROR = """
|
|
499
559
|
{0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0`
|
500
560
|
"""
|
501
561
|
|
562
|
+
# docstyle-ignore
|
563
|
+
PEFT_IMPORT_ERROR = """
|
564
|
+
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft`
|
565
|
+
"""
|
566
|
+
|
567
|
+
# docstyle-ignore
|
568
|
+
SAFETENSORS_IMPORT_ERROR = """
|
569
|
+
{0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors`
|
570
|
+
"""
|
571
|
+
|
572
|
+
# docstyle-ignore
|
573
|
+
BITSANDBYTES_IMPORT_ERROR = """
|
574
|
+
{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes`
|
575
|
+
"""
|
502
576
|
|
503
577
|
BACKENDS_MAPPING = OrderedDict(
|
504
578
|
[
|
@@ -520,6 +594,9 @@ BACKENDS_MAPPING = OrderedDict(
|
|
520
594
|
("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
|
521
595
|
("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)),
|
522
596
|
("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)),
|
597
|
+
("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
|
598
|
+
("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
|
599
|
+
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
|
523
600
|
]
|
524
601
|
)
|
525
602
|
|
@@ -628,6 +705,20 @@ def is_accelerate_version(operation: str, version: str):
|
|
628
705
|
return compare_versions(parse(_accelerate_version), operation, version)
|
629
706
|
|
630
707
|
|
708
|
+
def is_peft_version(operation: str, version: str):
|
709
|
+
"""
|
710
|
+
Args:
|
711
|
+
Compares the current PEFT version to a given reference with an operation.
|
712
|
+
operation (`str`):
|
713
|
+
A string representation of an operator, such as `">"` or `"<="`
|
714
|
+
version (`str`):
|
715
|
+
A version string
|
716
|
+
"""
|
717
|
+
if not _peft_version:
|
718
|
+
return False
|
719
|
+
return compare_versions(parse(_peft_version), operation, version)
|
720
|
+
|
721
|
+
|
631
722
|
def is_k_diffusion_version(operation: str, version: str):
|
632
723
|
"""
|
633
724
|
Args:
|
diffusers/utils/loading_utils.py
CHANGED
@@ -16,8 +16,8 @@ def load_image(
|
|
16
16
|
image (`str` or `PIL.Image.Image`):
|
17
17
|
The image to convert to the PIL Image format.
|
18
18
|
convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
|
19
|
-
A conversion method to apply to the image after loading it.
|
20
|
-
|
19
|
+
A conversion method to apply to the image after loading it. When set to `None` the image will be converted
|
20
|
+
"RGB".
|
21
21
|
|
22
22
|
Returns:
|
23
23
|
`PIL.Image.Image`:
|
diffusers/utils/logging.py
CHANGED
@@ -12,7 +12,7 @@
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
13
|
# See the License for the specific language governing permissions and
|
14
14
|
# limitations under the License.
|
15
|
-
"""
|
15
|
+
"""Logging utilities."""
|
16
16
|
|
17
17
|
import logging
|
18
18
|
import os
|
diffusers/utils/peft_utils.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14
14
|
"""
|
15
15
|
PEFT utilities: Utilities related to peft library
|
16
16
|
"""
|
17
|
+
|
17
18
|
import collections
|
18
19
|
import importlib
|
19
20
|
from typing import Optional
|
@@ -63,9 +64,11 @@ def recurse_remove_peft_layers(model):
|
|
63
64
|
module_replaced = False
|
64
65
|
|
65
66
|
if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
|
66
|
-
new_module = torch.nn.Linear(
|
67
|
-
module.
|
68
|
-
|
67
|
+
new_module = torch.nn.Linear(
|
68
|
+
module.in_features,
|
69
|
+
module.out_features,
|
70
|
+
bias=module.bias is not None,
|
71
|
+
).to(module.weight.device)
|
69
72
|
new_module.weight = module.weight
|
70
73
|
if module.bias is not None:
|
71
74
|
new_module.bias = module.bias
|
@@ -109,6 +112,9 @@ def scale_lora_layers(model, weight):
|
|
109
112
|
"""
|
110
113
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
111
114
|
|
115
|
+
if weight == 1.0:
|
116
|
+
return
|
117
|
+
|
112
118
|
for module in model.modules():
|
113
119
|
if isinstance(module, BaseTunerLayer):
|
114
120
|
module.scale_layer(weight)
|
@@ -128,6 +134,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
|
|
128
134
|
"""
|
129
135
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
130
136
|
|
137
|
+
if weight == 1.0:
|
138
|
+
return
|
139
|
+
|
131
140
|
for module in model.modules():
|
132
141
|
if isinstance(module, BaseTunerLayer):
|
133
142
|
if weight is not None and weight != 0:
|
@@ -170,6 +179,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
|
|
170
179
|
|
171
180
|
# layer names without the Diffusers specific
|
172
181
|
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
|
182
|
+
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
|
173
183
|
|
174
184
|
lora_config_kwargs = {
|
175
185
|
"r": r,
|
@@ -177,6 +187,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
|
|
177
187
|
"rank_pattern": rank_pattern,
|
178
188
|
"alpha_pattern": alpha_pattern,
|
179
189
|
"target_modules": target_modules,
|
190
|
+
"use_dora": use_dora,
|
180
191
|
}
|
181
192
|
return lora_config_kwargs
|
182
193
|
|
@@ -227,16 +238,32 @@ def delete_adapter_layers(model, adapter_name):
|
|
227
238
|
def set_weights_and_activate_adapters(model, adapter_names, weights):
|
228
239
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
229
240
|
|
241
|
+
def get_module_weight(weight_for_adapter, module_name):
|
242
|
+
if not isinstance(weight_for_adapter, dict):
|
243
|
+
# If weight_for_adapter is a single number, always return it.
|
244
|
+
return weight_for_adapter
|
245
|
+
|
246
|
+
for layer_name, weight_ in weight_for_adapter.items():
|
247
|
+
if layer_name in module_name:
|
248
|
+
return weight_
|
249
|
+
|
250
|
+
parts = module_name.split(".")
|
251
|
+
# e.g. key = "down_blocks.1.attentions.0"
|
252
|
+
key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
|
253
|
+
block_weight = weight_for_adapter.get(key, 1.0)
|
254
|
+
|
255
|
+
return block_weight
|
256
|
+
|
230
257
|
# iterate over each adapter, make it active and set the corresponding scaling weight
|
231
258
|
for adapter_name, weight in zip(adapter_names, weights):
|
232
|
-
for module in model.
|
259
|
+
for module_name, module in model.named_modules():
|
233
260
|
if isinstance(module, BaseTunerLayer):
|
234
261
|
# For backward compatbility with previous PEFT versions
|
235
262
|
if hasattr(module, "set_adapter"):
|
236
263
|
module.set_adapter(adapter_name)
|
237
264
|
else:
|
238
265
|
module.active_adapter = adapter_name
|
239
|
-
module.set_scale(adapter_name, weight)
|
266
|
+
module.set_scale(adapter_name, get_module_weight(weight, module_name))
|
240
267
|
|
241
268
|
# set multiple active adapters
|
242
269
|
for module in model.modules():
|
@@ -14,6 +14,7 @@
|
|
14
14
|
"""
|
15
15
|
State dict utilities: utility methods for converting state dicts easily
|
16
16
|
"""
|
17
|
+
|
17
18
|
import enum
|
18
19
|
|
19
20
|
from .logging import get_logger
|
@@ -46,6 +47,7 @@ UNET_TO_DIFFUSERS = {
|
|
46
47
|
".to_v_lora.up": ".to_v.lora_B",
|
47
48
|
".lora.up": ".lora_B",
|
48
49
|
".lora.down": ".lora_A",
|
50
|
+
".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
|
49
51
|
}
|
50
52
|
|
51
53
|
|
@@ -103,6 +105,10 @@ DIFFUSERS_OLD_TO_DIFFUSERS = {
|
|
103
105
|
".to_v_lora.down": ".v_proj.lora_linear_layer.down",
|
104
106
|
".to_out_lora.up": ".out_proj.lora_linear_layer.up",
|
105
107
|
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
|
108
|
+
".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector",
|
109
|
+
".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector",
|
110
|
+
".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector",
|
111
|
+
".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector",
|
106
112
|
}
|
107
113
|
|
108
114
|
PEFT_TO_KOHYA_SS = {
|
@@ -247,8 +253,8 @@ def convert_unet_state_dict_to_peft(state_dict):
|
|
247
253
|
|
248
254
|
def convert_all_state_dict_to_peft(state_dict):
|
249
255
|
r"""
|
250
|
-
Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer`
|
251
|
-
|
256
|
+
Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
|
257
|
+
`DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
|
252
258
|
"""
|
253
259
|
try:
|
254
260
|
peft_dict = convert_state_dict_to_peft(state_dict)
|
@@ -314,6 +320,9 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
|
|
314
320
|
kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
|
315
321
|
elif "unet" in kohya_key:
|
316
322
|
kohya_key = kohya_key.replace("unet", "lora_unet")
|
323
|
+
elif "lora_magnitude_vector" in kohya_key:
|
324
|
+
kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")
|
325
|
+
|
317
326
|
kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
|
318
327
|
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
|
319
328
|
kohya_ss_state_dict[kohya_key] = weight
|
diffusers/utils/testing_utils.py
CHANGED
@@ -14,7 +14,6 @@ import time
|
|
14
14
|
import unittest
|
15
15
|
import urllib.parse
|
16
16
|
from contextlib import contextmanager
|
17
|
-
from distutils.util import strtobool
|
18
17
|
from io import BytesIO, StringIO
|
19
18
|
from pathlib import Path
|
20
19
|
from typing import Callable, Dict, List, Optional, Union
|
@@ -34,6 +33,7 @@ from .import_utils import (
|
|
34
33
|
is_onnx_available,
|
35
34
|
is_opencv_available,
|
36
35
|
is_peft_available,
|
36
|
+
is_timm_available,
|
37
37
|
is_torch_available,
|
38
38
|
is_torch_version,
|
39
39
|
is_torchsde_available,
|
@@ -106,10 +106,21 @@ def numpy_cosine_similarity_distance(a, b):
|
|
106
106
|
return distance
|
107
107
|
|
108
108
|
|
109
|
-
def print_tensor_test(
|
109
|
+
def print_tensor_test(
|
110
|
+
tensor,
|
111
|
+
limit_to_slices=None,
|
112
|
+
max_torch_print=None,
|
113
|
+
filename="test_corrections.txt",
|
114
|
+
expected_tensor_name="expected_slice",
|
115
|
+
):
|
116
|
+
if max_torch_print:
|
117
|
+
torch.set_printoptions(threshold=10_000)
|
118
|
+
|
110
119
|
test_name = os.environ.get("PYTEST_CURRENT_TEST")
|
111
120
|
if not torch.is_tensor(tensor):
|
112
121
|
tensor = torch.from_numpy(tensor)
|
122
|
+
if limit_to_slices:
|
123
|
+
tensor = tensor[0, -3:, -3:, -1]
|
113
124
|
|
114
125
|
tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "")
|
115
126
|
# format is usually:
|
@@ -118,7 +129,7 @@ def print_tensor_test(tensor, filename="test_corrections.txt", expected_tensor_n
|
|
118
129
|
test_file, test_class, test_fn = test_name.split("::")
|
119
130
|
test_fn = test_fn.split()[0]
|
120
131
|
with open(filename, "a") as f:
|
121
|
-
print("
|
132
|
+
print("::".join([test_file, test_class, test_fn, output_str]), file=f)
|
122
133
|
|
123
134
|
|
124
135
|
def get_tests_dir(append_path=None):
|
@@ -142,6 +153,22 @@ def get_tests_dir(append_path=None):
|
|
142
153
|
return tests_dir
|
143
154
|
|
144
155
|
|
156
|
+
# Taken from the following PR:
|
157
|
+
# https://github.com/huggingface/accelerate/pull/1964
|
158
|
+
def str_to_bool(value) -> int:
|
159
|
+
"""
|
160
|
+
Converts a string representation of truth to `True` (1) or `False` (0). True values are `y`, `yes`, `t`, `true`,
|
161
|
+
`on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
|
162
|
+
"""
|
163
|
+
value = value.lower()
|
164
|
+
if value in ("y", "yes", "t", "true", "on", "1"):
|
165
|
+
return 1
|
166
|
+
elif value in ("n", "no", "f", "false", "off", "0"):
|
167
|
+
return 0
|
168
|
+
else:
|
169
|
+
raise ValueError(f"invalid truth value {value}")
|
170
|
+
|
171
|
+
|
145
172
|
def parse_flag_from_env(key, default=False):
|
146
173
|
try:
|
147
174
|
value = os.environ[key]
|
@@ -151,7 +178,7 @@ def parse_flag_from_env(key, default=False):
|
|
151
178
|
else:
|
152
179
|
# KEY is set, convert it to True or False.
|
153
180
|
try:
|
154
|
-
_value =
|
181
|
+
_value = str_to_bool(value)
|
155
182
|
except ValueError:
|
156
183
|
# More values are supported, but let's keep the message simple.
|
157
184
|
raise ValueError(f"If set, {key} must be yes or no.")
|
@@ -229,6 +256,20 @@ def require_torch_accelerator(test_case):
|
|
229
256
|
)
|
230
257
|
|
231
258
|
|
259
|
+
def require_torch_multi_gpu(test_case):
|
260
|
+
"""
|
261
|
+
Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
|
262
|
+
multiple GPUs. To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests
|
263
|
+
-k "multi_gpu"
|
264
|
+
"""
|
265
|
+
if not is_torch_available():
|
266
|
+
return unittest.skip("test requires PyTorch")(test_case)
|
267
|
+
|
268
|
+
import torch
|
269
|
+
|
270
|
+
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
271
|
+
|
272
|
+
|
232
273
|
def require_torch_accelerator_with_fp16(test_case):
|
233
274
|
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
|
234
275
|
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
|
@@ -300,6 +341,13 @@ def require_peft_backend(test_case):
|
|
300
341
|
return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
|
301
342
|
|
302
343
|
|
344
|
+
def require_timm(test_case):
|
345
|
+
"""
|
346
|
+
Decorator marking a test that requires timm. These tests are skipped when timm isn't installed.
|
347
|
+
"""
|
348
|
+
return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
|
349
|
+
|
350
|
+
|
303
351
|
def require_peft_version_greater(peft_version):
|
304
352
|
"""
|
305
353
|
Decorator marking a test that requires PEFT backend with a specific version, this would require some specific
|
@@ -317,6 +365,18 @@ def require_peft_version_greater(peft_version):
|
|
317
365
|
return decorator
|
318
366
|
|
319
367
|
|
368
|
+
def require_accelerate_version_greater(accelerate_version):
|
369
|
+
def decorator(test_case):
|
370
|
+
correct_accelerate_version = is_peft_available() and version.parse(
|
371
|
+
version.parse(importlib.metadata.version("accelerate")).base_version
|
372
|
+
) > version.parse(accelerate_version)
|
373
|
+
return unittest.skipUnless(
|
374
|
+
correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
|
375
|
+
)(test_case)
|
376
|
+
|
377
|
+
return decorator
|
378
|
+
|
379
|
+
|
320
380
|
def deprecate_after_peft_backend(test_case):
|
321
381
|
"""
|
322
382
|
Decorator marking a test that will be skipped after PEFT backend
|
@@ -324,10 +384,15 @@ def deprecate_after_peft_backend(test_case):
|
|
324
384
|
return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
|
325
385
|
|
326
386
|
|
387
|
+
def get_python_version():
|
388
|
+
sys_info = sys.version_info
|
389
|
+
major, minor = sys_info.major, sys_info.minor
|
390
|
+
return major, minor
|
391
|
+
|
392
|
+
|
327
393
|
def require_python39_or_higher(test_case):
|
328
394
|
def python39_available():
|
329
|
-
|
330
|
-
major, minor = sys_info.major, sys_info.minor
|
395
|
+
major, minor = get_python_version()
|
331
396
|
return major == 3 and minor >= 9
|
332
397
|
|
333
398
|
return unittest.skipUnless(python39_available(), "test requires Python 3.9 or higher")(test_case)
|