diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +26 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +33 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +8 -0
- diffusers/models/activations.py +23 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +475 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +363 -32
- diffusers/models/model_loading_utils.py +177 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +175 -99
- diffusers/models/normalization.py +2 -1
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/__init__.py +3 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
- diffusers/models/transformers/pixart_transformer_2d.py +336 -0
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +292 -184
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +27 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +269 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
- diffusers-0.28.1.dist-info/RECORD +419 -0
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
@@ -92,6 +92,51 @@ class ControlNetModel(metaclass=DummyObject):
|
|
92
92
|
requires_backends(cls, ["torch"])
|
93
93
|
|
94
94
|
|
95
|
+
class ControlNetXSAdapter(metaclass=DummyObject):
|
96
|
+
_backends = ["torch"]
|
97
|
+
|
98
|
+
def __init__(self, *args, **kwargs):
|
99
|
+
requires_backends(self, ["torch"])
|
100
|
+
|
101
|
+
@classmethod
|
102
|
+
def from_config(cls, *args, **kwargs):
|
103
|
+
requires_backends(cls, ["torch"])
|
104
|
+
|
105
|
+
@classmethod
|
106
|
+
def from_pretrained(cls, *args, **kwargs):
|
107
|
+
requires_backends(cls, ["torch"])
|
108
|
+
|
109
|
+
|
110
|
+
class DiTTransformer2DModel(metaclass=DummyObject):
|
111
|
+
_backends = ["torch"]
|
112
|
+
|
113
|
+
def __init__(self, *args, **kwargs):
|
114
|
+
requires_backends(self, ["torch"])
|
115
|
+
|
116
|
+
@classmethod
|
117
|
+
def from_config(cls, *args, **kwargs):
|
118
|
+
requires_backends(cls, ["torch"])
|
119
|
+
|
120
|
+
@classmethod
|
121
|
+
def from_pretrained(cls, *args, **kwargs):
|
122
|
+
requires_backends(cls, ["torch"])
|
123
|
+
|
124
|
+
|
125
|
+
class HunyuanDiT2DModel(metaclass=DummyObject):
|
126
|
+
_backends = ["torch"]
|
127
|
+
|
128
|
+
def __init__(self, *args, **kwargs):
|
129
|
+
requires_backends(self, ["torch"])
|
130
|
+
|
131
|
+
@classmethod
|
132
|
+
def from_config(cls, *args, **kwargs):
|
133
|
+
requires_backends(cls, ["torch"])
|
134
|
+
|
135
|
+
@classmethod
|
136
|
+
def from_pretrained(cls, *args, **kwargs):
|
137
|
+
requires_backends(cls, ["torch"])
|
138
|
+
|
139
|
+
|
95
140
|
class I2VGenXLUNet(metaclass=DummyObject):
|
96
141
|
_backends = ["torch"]
|
97
142
|
|
@@ -167,6 +212,21 @@ class MultiAdapter(metaclass=DummyObject):
|
|
167
212
|
requires_backends(cls, ["torch"])
|
168
213
|
|
169
214
|
|
215
|
+
class PixArtTransformer2DModel(metaclass=DummyObject):
|
216
|
+
_backends = ["torch"]
|
217
|
+
|
218
|
+
def __init__(self, *args, **kwargs):
|
219
|
+
requires_backends(self, ["torch"])
|
220
|
+
|
221
|
+
@classmethod
|
222
|
+
def from_config(cls, *args, **kwargs):
|
223
|
+
requires_backends(cls, ["torch"])
|
224
|
+
|
225
|
+
@classmethod
|
226
|
+
def from_pretrained(cls, *args, **kwargs):
|
227
|
+
requires_backends(cls, ["torch"])
|
228
|
+
|
229
|
+
|
170
230
|
class PriorTransformer(metaclass=DummyObject):
|
171
231
|
_backends = ["torch"]
|
172
232
|
|
@@ -287,6 +347,21 @@ class UNet3DConditionModel(metaclass=DummyObject):
|
|
287
347
|
requires_backends(cls, ["torch"])
|
288
348
|
|
289
349
|
|
350
|
+
class UNetControlNetXSModel(metaclass=DummyObject):
|
351
|
+
_backends = ["torch"]
|
352
|
+
|
353
|
+
def __init__(self, *args, **kwargs):
|
354
|
+
requires_backends(self, ["torch"])
|
355
|
+
|
356
|
+
@classmethod
|
357
|
+
def from_config(cls, *args, **kwargs):
|
358
|
+
requires_backends(cls, ["torch"])
|
359
|
+
|
360
|
+
@classmethod
|
361
|
+
def from_pretrained(cls, *args, **kwargs):
|
362
|
+
requires_backends(cls, ["torch"])
|
363
|
+
|
364
|
+
|
290
365
|
class UNetMotionModel(metaclass=DummyObject):
|
291
366
|
_backends = ["torch"]
|
292
367
|
|
@@ -92,6 +92,21 @@ class AnimateDiffPipeline(metaclass=DummyObject):
|
|
92
92
|
requires_backends(cls, ["torch", "transformers"])
|
93
93
|
|
94
94
|
|
95
|
+
class AnimateDiffSDXLPipeline(metaclass=DummyObject):
|
96
|
+
_backends = ["torch", "transformers"]
|
97
|
+
|
98
|
+
def __init__(self, *args, **kwargs):
|
99
|
+
requires_backends(self, ["torch", "transformers"])
|
100
|
+
|
101
|
+
@classmethod
|
102
|
+
def from_config(cls, *args, **kwargs):
|
103
|
+
requires_backends(cls, ["torch", "transformers"])
|
104
|
+
|
105
|
+
@classmethod
|
106
|
+
def from_pretrained(cls, *args, **kwargs):
|
107
|
+
requires_backends(cls, ["torch", "transformers"])
|
108
|
+
|
109
|
+
|
95
110
|
class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
|
96
111
|
_backends = ["torch", "transformers"]
|
97
112
|
|
@@ -197,6 +212,21 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
|
|
197
212
|
requires_backends(cls, ["torch", "transformers"])
|
198
213
|
|
199
214
|
|
215
|
+
class HunyuanDiTPipeline(metaclass=DummyObject):
|
216
|
+
_backends = ["torch", "transformers"]
|
217
|
+
|
218
|
+
def __init__(self, *args, **kwargs):
|
219
|
+
requires_backends(self, ["torch", "transformers"])
|
220
|
+
|
221
|
+
@classmethod
|
222
|
+
def from_config(cls, *args, **kwargs):
|
223
|
+
requires_backends(cls, ["torch", "transformers"])
|
224
|
+
|
225
|
+
@classmethod
|
226
|
+
def from_pretrained(cls, *args, **kwargs):
|
227
|
+
requires_backends(cls, ["torch", "transformers"])
|
228
|
+
|
229
|
+
|
200
230
|
class I2VGenXLPipeline(metaclass=DummyObject):
|
201
231
|
_backends = ["torch", "transformers"]
|
202
232
|
|
@@ -677,6 +707,36 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
|
|
677
707
|
requires_backends(cls, ["torch", "transformers"])
|
678
708
|
|
679
709
|
|
710
|
+
class MarigoldDepthPipeline(metaclass=DummyObject):
|
711
|
+
_backends = ["torch", "transformers"]
|
712
|
+
|
713
|
+
def __init__(self, *args, **kwargs):
|
714
|
+
requires_backends(self, ["torch", "transformers"])
|
715
|
+
|
716
|
+
@classmethod
|
717
|
+
def from_config(cls, *args, **kwargs):
|
718
|
+
requires_backends(cls, ["torch", "transformers"])
|
719
|
+
|
720
|
+
@classmethod
|
721
|
+
def from_pretrained(cls, *args, **kwargs):
|
722
|
+
requires_backends(cls, ["torch", "transformers"])
|
723
|
+
|
724
|
+
|
725
|
+
class MarigoldNormalsPipeline(metaclass=DummyObject):
|
726
|
+
_backends = ["torch", "transformers"]
|
727
|
+
|
728
|
+
def __init__(self, *args, **kwargs):
|
729
|
+
requires_backends(self, ["torch", "transformers"])
|
730
|
+
|
731
|
+
@classmethod
|
732
|
+
def from_config(cls, *args, **kwargs):
|
733
|
+
requires_backends(cls, ["torch", "transformers"])
|
734
|
+
|
735
|
+
@classmethod
|
736
|
+
def from_pretrained(cls, *args, **kwargs):
|
737
|
+
requires_backends(cls, ["torch", "transformers"])
|
738
|
+
|
739
|
+
|
680
740
|
class MusicLDMPipeline(metaclass=DummyObject):
|
681
741
|
_backends = ["torch", "transformers"]
|
682
742
|
|
@@ -737,6 +797,21 @@ class PixArtAlphaPipeline(metaclass=DummyObject):
|
|
737
797
|
requires_backends(cls, ["torch", "transformers"])
|
738
798
|
|
739
799
|
|
800
|
+
class PixArtSigmaPipeline(metaclass=DummyObject):
|
801
|
+
_backends = ["torch", "transformers"]
|
802
|
+
|
803
|
+
def __init__(self, *args, **kwargs):
|
804
|
+
requires_backends(self, ["torch", "transformers"])
|
805
|
+
|
806
|
+
@classmethod
|
807
|
+
def from_config(cls, *args, **kwargs):
|
808
|
+
requires_backends(cls, ["torch", "transformers"])
|
809
|
+
|
810
|
+
@classmethod
|
811
|
+
def from_pretrained(cls, *args, **kwargs):
|
812
|
+
requires_backends(cls, ["torch", "transformers"])
|
813
|
+
|
814
|
+
|
740
815
|
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
|
741
816
|
_backends = ["torch", "transformers"]
|
742
817
|
|
@@ -902,6 +977,21 @@ class StableDiffusionControlNetPipeline(metaclass=DummyObject):
|
|
902
977
|
requires_backends(cls, ["torch", "transformers"])
|
903
978
|
|
904
979
|
|
980
|
+
class StableDiffusionControlNetXSPipeline(metaclass=DummyObject):
|
981
|
+
_backends = ["torch", "transformers"]
|
982
|
+
|
983
|
+
def __init__(self, *args, **kwargs):
|
984
|
+
requires_backends(self, ["torch", "transformers"])
|
985
|
+
|
986
|
+
@classmethod
|
987
|
+
def from_config(cls, *args, **kwargs):
|
988
|
+
requires_backends(cls, ["torch", "transformers"])
|
989
|
+
|
990
|
+
@classmethod
|
991
|
+
def from_pretrained(cls, *args, **kwargs):
|
992
|
+
requires_backends(cls, ["torch", "transformers"])
|
993
|
+
|
994
|
+
|
905
995
|
class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject):
|
906
996
|
_backends = ["torch", "transformers"]
|
907
997
|
|
@@ -1247,6 +1337,21 @@ class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
|
|
1247
1337
|
requires_backends(cls, ["torch", "transformers"])
|
1248
1338
|
|
1249
1339
|
|
1340
|
+
class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject):
|
1341
|
+
_backends = ["torch", "transformers"]
|
1342
|
+
|
1343
|
+
def __init__(self, *args, **kwargs):
|
1344
|
+
requires_backends(self, ["torch", "transformers"])
|
1345
|
+
|
1346
|
+
@classmethod
|
1347
|
+
def from_config(cls, *args, **kwargs):
|
1348
|
+
requires_backends(cls, ["torch", "transformers"])
|
1349
|
+
|
1350
|
+
@classmethod
|
1351
|
+
def from_pretrained(cls, *args, **kwargs):
|
1352
|
+
requires_backends(cls, ["torch", "transformers"])
|
1353
|
+
|
1354
|
+
|
1250
1355
|
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
|
1251
1356
|
_backends = ["torch", "transformers"]
|
1252
1357
|
|
@@ -201,7 +201,7 @@ def get_cached_module_file(
|
|
201
201
|
module_file: str,
|
202
202
|
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
203
203
|
force_download: bool = False,
|
204
|
-
resume_download: bool =
|
204
|
+
resume_download: Optional[bool] = None,
|
205
205
|
proxies: Optional[Dict[str, str]] = None,
|
206
206
|
token: Optional[Union[bool, str]] = None,
|
207
207
|
revision: Optional[str] = None,
|
@@ -228,9 +228,9 @@ def get_cached_module_file(
|
|
228
228
|
cache should not be used.
|
229
229
|
force_download (`bool`, *optional*, defaults to `False`):
|
230
230
|
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
231
|
-
exist.
|
232
|
-
|
233
|
-
|
231
|
+
exist. resume_download:
|
232
|
+
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
233
|
+
of Diffusers.
|
234
234
|
proxies (`Dict[str, str]`, *optional*):
|
235
235
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
236
236
|
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
@@ -246,8 +246,8 @@ def get_cached_module_file(
|
|
246
246
|
|
247
247
|
<Tip>
|
248
248
|
|
249
|
-
You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
|
250
|
-
|
249
|
+
You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
|
250
|
+
[gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
251
251
|
|
252
252
|
</Tip>
|
253
253
|
|
@@ -329,6 +329,11 @@ def get_cached_module_file(
|
|
329
329
|
# The only reason we do the copy is to avoid putting too many folders in sys.path.
|
330
330
|
shutil.copy(resolved_module_file, submodule_path / module_file)
|
331
331
|
for module_needed in modules_needed:
|
332
|
+
if len(module_needed.split(".")) == 2:
|
333
|
+
module_needed = "/".join(module_needed.split("."))
|
334
|
+
module_folder = module_needed.split("/")[0]
|
335
|
+
if not os.path.exists(submodule_path / module_folder):
|
336
|
+
os.makedirs(submodule_path / module_folder)
|
332
337
|
module_needed = f"{module_needed}.py"
|
333
338
|
shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
|
334
339
|
else:
|
@@ -343,9 +348,16 @@ def get_cached_module_file(
|
|
343
348
|
create_dynamic_module(full_submodule)
|
344
349
|
|
345
350
|
if not (submodule_path / module_file).exists():
|
351
|
+
if len(module_file.split("/")) == 2:
|
352
|
+
module_folder = module_file.split("/")[0]
|
353
|
+
if not os.path.exists(submodule_path / module_folder):
|
354
|
+
os.makedirs(submodule_path / module_folder)
|
346
355
|
shutil.copy(resolved_module_file, submodule_path / module_file)
|
356
|
+
|
347
357
|
# Make sure we also have every file with relative
|
348
358
|
for module_needed in modules_needed:
|
359
|
+
if len(module_needed.split(".")) == 2:
|
360
|
+
module_needed = "/".join(module_needed.split("."))
|
349
361
|
if not (submodule_path / module_needed).exists():
|
350
362
|
get_cached_module_file(
|
351
363
|
pretrained_model_name_or_path,
|
@@ -368,7 +380,7 @@ def get_class_from_dynamic_module(
|
|
368
380
|
class_name: Optional[str] = None,
|
369
381
|
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
370
382
|
force_download: bool = False,
|
371
|
-
resume_download: bool =
|
383
|
+
resume_download: Optional[bool] = None,
|
372
384
|
proxies: Optional[Dict[str, str]] = None,
|
373
385
|
token: Optional[Union[bool, str]] = None,
|
374
386
|
revision: Optional[str] = None,
|
@@ -405,8 +417,9 @@ def get_class_from_dynamic_module(
|
|
405
417
|
force_download (`bool`, *optional*, defaults to `False`):
|
406
418
|
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
407
419
|
exist.
|
408
|
-
resume_download
|
409
|
-
|
420
|
+
resume_download:
|
421
|
+
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 of
|
422
|
+
Diffusers.
|
410
423
|
proxies (`Dict[str, str]`, *optional*):
|
411
424
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
412
425
|
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
@@ -422,8 +435,8 @@ def get_class_from_dynamic_module(
|
|
422
435
|
|
423
436
|
<Tip>
|
424
437
|
|
425
|
-
You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
|
426
|
-
|
438
|
+
You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
|
439
|
+
[gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
427
440
|
|
428
441
|
</Tip>
|
429
442
|
|
diffusers/utils/hub_utils.py
CHANGED
@@ -112,7 +112,8 @@ def load_or_create_model_card(
|
|
112
112
|
repo_id_or_path (`str`):
|
113
113
|
The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card.
|
114
114
|
token (`str`, *optional*):
|
115
|
-
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
|
115
|
+
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
|
116
|
+
details.
|
116
117
|
is_pipeline (`bool`):
|
117
118
|
Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
|
118
119
|
from_training: (`bool`): Boolean flag to denote if the model card is being created from a training script.
|
@@ -282,7 +283,7 @@ def _get_model_file(
|
|
282
283
|
cache_dir: Optional[str] = None,
|
283
284
|
force_download: bool = False,
|
284
285
|
proxies: Optional[Dict] = None,
|
285
|
-
resume_download: bool =
|
286
|
+
resume_download: Optional[bool] = None,
|
286
287
|
local_files_only: bool = False,
|
287
288
|
token: Optional[str] = None,
|
288
289
|
user_agent: Optional[Union[Dict, str]] = None,
|
diffusers/utils/import_utils.py
CHANGED
@@ -295,6 +295,46 @@ try:
|
|
295
295
|
except importlib_metadata.PackageNotFoundError:
|
296
296
|
_torchvision_available = False
|
297
297
|
|
298
|
+
_matplotlib_available = importlib.util.find_spec("matplotlib") is not None
|
299
|
+
try:
|
300
|
+
_matplotlib_version = importlib_metadata.version("matplotlib")
|
301
|
+
logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
|
302
|
+
except importlib_metadata.PackageNotFoundError:
|
303
|
+
_matplotlib_available = False
|
304
|
+
|
305
|
+
_timm_available = importlib.util.find_spec("timm") is not None
|
306
|
+
if _timm_available:
|
307
|
+
try:
|
308
|
+
_timm_version = importlib_metadata.version("timm")
|
309
|
+
logger.info(f"Timm version {_timm_version} available.")
|
310
|
+
except importlib_metadata.PackageNotFoundError:
|
311
|
+
_timm_available = False
|
312
|
+
|
313
|
+
|
314
|
+
def is_timm_available():
|
315
|
+
return _timm_available
|
316
|
+
|
317
|
+
|
318
|
+
_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
|
319
|
+
try:
|
320
|
+
_bitsandbytes_version = importlib_metadata.version("bitsandbytes")
|
321
|
+
logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
|
322
|
+
except importlib_metadata.PackageNotFoundError:
|
323
|
+
_bitsandbytes_available = False
|
324
|
+
|
325
|
+
# Taken from `huggingface_hub`.
|
326
|
+
_is_notebook = False
|
327
|
+
try:
|
328
|
+
shell_class = get_ipython().__class__ # type: ignore # noqa: F821
|
329
|
+
for parent_class in shell_class.__mro__: # e.g. "is subclass of"
|
330
|
+
if parent_class.__name__ == "ZMQInteractiveShell":
|
331
|
+
_is_notebook = True # Jupyter notebook, Google colab or qtconsole
|
332
|
+
break
|
333
|
+
except NameError:
|
334
|
+
pass # Probably standard Python interpreter
|
335
|
+
|
336
|
+
_is_google_colab = "google.colab" in sys.modules
|
337
|
+
|
298
338
|
|
299
339
|
def is_torch_available():
|
300
340
|
return _torch_available
|
@@ -392,6 +432,26 @@ def is_torchvision_available():
|
|
392
432
|
return _torchvision_available
|
393
433
|
|
394
434
|
|
435
|
+
def is_matplotlib_available():
|
436
|
+
return _matplotlib_available
|
437
|
+
|
438
|
+
|
439
|
+
def is_safetensors_available():
|
440
|
+
return _safetensors_available
|
441
|
+
|
442
|
+
|
443
|
+
def is_bitsandbytes_available():
|
444
|
+
return _bitsandbytes_available
|
445
|
+
|
446
|
+
|
447
|
+
def is_notebook():
|
448
|
+
return _is_notebook
|
449
|
+
|
450
|
+
|
451
|
+
def is_google_colab():
|
452
|
+
return _is_google_colab
|
453
|
+
|
454
|
+
|
395
455
|
# docstyle-ignore
|
396
456
|
FLAX_IMPORT_ERROR = """
|
397
457
|
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
@@ -499,6 +559,20 @@ INVISIBLE_WATERMARK_IMPORT_ERROR = """
|
|
499
559
|
{0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0`
|
500
560
|
"""
|
501
561
|
|
562
|
+
# docstyle-ignore
|
563
|
+
PEFT_IMPORT_ERROR = """
|
564
|
+
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft`
|
565
|
+
"""
|
566
|
+
|
567
|
+
# docstyle-ignore
|
568
|
+
SAFETENSORS_IMPORT_ERROR = """
|
569
|
+
{0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors`
|
570
|
+
"""
|
571
|
+
|
572
|
+
# docstyle-ignore
|
573
|
+
BITSANDBYTES_IMPORT_ERROR = """
|
574
|
+
{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes`
|
575
|
+
"""
|
502
576
|
|
503
577
|
BACKENDS_MAPPING = OrderedDict(
|
504
578
|
[
|
@@ -520,6 +594,9 @@ BACKENDS_MAPPING = OrderedDict(
|
|
520
594
|
("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
|
521
595
|
("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)),
|
522
596
|
("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)),
|
597
|
+
("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
|
598
|
+
("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
|
599
|
+
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
|
523
600
|
]
|
524
601
|
)
|
525
602
|
|
@@ -628,6 +705,20 @@ def is_accelerate_version(operation: str, version: str):
|
|
628
705
|
return compare_versions(parse(_accelerate_version), operation, version)
|
629
706
|
|
630
707
|
|
708
|
+
def is_peft_version(operation: str, version: str):
|
709
|
+
"""
|
710
|
+
Args:
|
711
|
+
Compares the current PEFT version to a given reference with an operation.
|
712
|
+
operation (`str`):
|
713
|
+
A string representation of an operator, such as `">"` or `"<="`
|
714
|
+
version (`str`):
|
715
|
+
A version string
|
716
|
+
"""
|
717
|
+
if not _peft_version:
|
718
|
+
return False
|
719
|
+
return compare_versions(parse(_peft_version), operation, version)
|
720
|
+
|
721
|
+
|
631
722
|
def is_k_diffusion_version(operation: str, version: str):
|
632
723
|
"""
|
633
724
|
Args:
|
diffusers/utils/loading_utils.py
CHANGED
@@ -16,8 +16,8 @@ def load_image(
|
|
16
16
|
image (`str` or `PIL.Image.Image`):
|
17
17
|
The image to convert to the PIL Image format.
|
18
18
|
convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
|
19
|
-
A conversion method to apply to the image after loading it.
|
20
|
-
|
19
|
+
A conversion method to apply to the image after loading it. When set to `None` the image will be converted
|
20
|
+
"RGB".
|
21
21
|
|
22
22
|
Returns:
|
23
23
|
`PIL.Image.Image`:
|
diffusers/utils/logging.py
CHANGED
@@ -12,7 +12,7 @@
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
13
|
# See the License for the specific language governing permissions and
|
14
14
|
# limitations under the License.
|
15
|
-
"""
|
15
|
+
"""Logging utilities."""
|
16
16
|
|
17
17
|
import logging
|
18
18
|
import os
|
diffusers/utils/peft_utils.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14
14
|
"""
|
15
15
|
PEFT utilities: Utilities related to peft library
|
16
16
|
"""
|
17
|
+
|
17
18
|
import collections
|
18
19
|
import importlib
|
19
20
|
from typing import Optional
|
@@ -63,9 +64,11 @@ def recurse_remove_peft_layers(model):
|
|
63
64
|
module_replaced = False
|
64
65
|
|
65
66
|
if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
|
66
|
-
new_module = torch.nn.Linear(
|
67
|
-
module.
|
68
|
-
|
67
|
+
new_module = torch.nn.Linear(
|
68
|
+
module.in_features,
|
69
|
+
module.out_features,
|
70
|
+
bias=module.bias is not None,
|
71
|
+
).to(module.weight.device)
|
69
72
|
new_module.weight = module.weight
|
70
73
|
if module.bias is not None:
|
71
74
|
new_module.bias = module.bias
|
@@ -109,6 +112,9 @@ def scale_lora_layers(model, weight):
|
|
109
112
|
"""
|
110
113
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
111
114
|
|
115
|
+
if weight == 1.0:
|
116
|
+
return
|
117
|
+
|
112
118
|
for module in model.modules():
|
113
119
|
if isinstance(module, BaseTunerLayer):
|
114
120
|
module.scale_layer(weight)
|
@@ -128,6 +134,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
|
|
128
134
|
"""
|
129
135
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
130
136
|
|
137
|
+
if weight == 1.0:
|
138
|
+
return
|
139
|
+
|
131
140
|
for module in model.modules():
|
132
141
|
if isinstance(module, BaseTunerLayer):
|
133
142
|
if weight is not None and weight != 0:
|
@@ -170,6 +179,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
|
|
170
179
|
|
171
180
|
# layer names without the Diffusers specific
|
172
181
|
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
|
182
|
+
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
|
173
183
|
|
174
184
|
lora_config_kwargs = {
|
175
185
|
"r": r,
|
@@ -177,6 +187,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
|
|
177
187
|
"rank_pattern": rank_pattern,
|
178
188
|
"alpha_pattern": alpha_pattern,
|
179
189
|
"target_modules": target_modules,
|
190
|
+
"use_dora": use_dora,
|
180
191
|
}
|
181
192
|
return lora_config_kwargs
|
182
193
|
|
@@ -227,16 +238,32 @@ def delete_adapter_layers(model, adapter_name):
|
|
227
238
|
def set_weights_and_activate_adapters(model, adapter_names, weights):
|
228
239
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
229
240
|
|
241
|
+
def get_module_weight(weight_for_adapter, module_name):
|
242
|
+
if not isinstance(weight_for_adapter, dict):
|
243
|
+
# If weight_for_adapter is a single number, always return it.
|
244
|
+
return weight_for_adapter
|
245
|
+
|
246
|
+
for layer_name, weight_ in weight_for_adapter.items():
|
247
|
+
if layer_name in module_name:
|
248
|
+
return weight_
|
249
|
+
|
250
|
+
parts = module_name.split(".")
|
251
|
+
# e.g. key = "down_blocks.1.attentions.0"
|
252
|
+
key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
|
253
|
+
block_weight = weight_for_adapter.get(key, 1.0)
|
254
|
+
|
255
|
+
return block_weight
|
256
|
+
|
230
257
|
# iterate over each adapter, make it active and set the corresponding scaling weight
|
231
258
|
for adapter_name, weight in zip(adapter_names, weights):
|
232
|
-
for module in model.
|
259
|
+
for module_name, module in model.named_modules():
|
233
260
|
if isinstance(module, BaseTunerLayer):
|
234
261
|
# For backward compatbility with previous PEFT versions
|
235
262
|
if hasattr(module, "set_adapter"):
|
236
263
|
module.set_adapter(adapter_name)
|
237
264
|
else:
|
238
265
|
module.active_adapter = adapter_name
|
239
|
-
module.set_scale(adapter_name, weight)
|
266
|
+
module.set_scale(adapter_name, get_module_weight(weight, module_name))
|
240
267
|
|
241
268
|
# set multiple active adapters
|
242
269
|
for module in model.modules():
|
@@ -14,6 +14,7 @@
|
|
14
14
|
"""
|
15
15
|
State dict utilities: utility methods for converting state dicts easily
|
16
16
|
"""
|
17
|
+
|
17
18
|
import enum
|
18
19
|
|
19
20
|
from .logging import get_logger
|
@@ -46,6 +47,7 @@ UNET_TO_DIFFUSERS = {
|
|
46
47
|
".to_v_lora.up": ".to_v.lora_B",
|
47
48
|
".lora.up": ".lora_B",
|
48
49
|
".lora.down": ".lora_A",
|
50
|
+
".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
|
49
51
|
}
|
50
52
|
|
51
53
|
|
@@ -103,6 +105,10 @@ DIFFUSERS_OLD_TO_DIFFUSERS = {
|
|
103
105
|
".to_v_lora.down": ".v_proj.lora_linear_layer.down",
|
104
106
|
".to_out_lora.up": ".out_proj.lora_linear_layer.up",
|
105
107
|
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
|
108
|
+
".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector",
|
109
|
+
".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector",
|
110
|
+
".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector",
|
111
|
+
".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector",
|
106
112
|
}
|
107
113
|
|
108
114
|
PEFT_TO_KOHYA_SS = {
|
@@ -247,8 +253,8 @@ def convert_unet_state_dict_to_peft(state_dict):
|
|
247
253
|
|
248
254
|
def convert_all_state_dict_to_peft(state_dict):
|
249
255
|
r"""
|
250
|
-
Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer`
|
251
|
-
|
256
|
+
Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
|
257
|
+
`DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
|
252
258
|
"""
|
253
259
|
try:
|
254
260
|
peft_dict = convert_state_dict_to_peft(state_dict)
|
@@ -314,6 +320,9 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
|
|
314
320
|
kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
|
315
321
|
elif "unet" in kohya_key:
|
316
322
|
kohya_key = kohya_key.replace("unet", "lora_unet")
|
323
|
+
elif "lora_magnitude_vector" in kohya_key:
|
324
|
+
kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")
|
325
|
+
|
317
326
|
kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
|
318
327
|
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
|
319
328
|
kohya_ss_state_dict[kohya_key] = weight
|