diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
- diffusers/__init__.py +34 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +170 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +35 -6
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
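
The hunks below cover diffusers/loaders/lora_pipeline.py (listed above with +707 -41). The most user-visible change is a new `low_cpu_mem_usage` option threaded through the LoRA loading entry points; it defaults to `_LOW_CPU_MEM_USAGE_DEFAULT_LORA`, which is only enabled when `peft >= 0.13.1` and `transformers > 4.45.2` are available. A minimal usage sketch follows (the LoRA repository id is a placeholder, not taken from this diff):

# Minimal sketch of the new `low_cpu_mem_usage` flag added to the LoRA loaders in 0.31.0.
# The LoRA repo id below is a placeholder; requires peft >= 0.13.1 (and transformers > 4.45.2
# when text-encoder LoRA layers are involved).
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Only the pretrained LoRA weights are loaded; the adapter layers are not first
# materialized with random initialization, which lowers peak memory while loading.
pipe.load_lora_weights("your-username/your-sdxl-lora", low_cpu_mem_usage=True)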
@@ -25,13 +25,32 @@ from ..utils import (
|
|
25
25
|
deprecate,
|
26
26
|
get_adapter_name,
|
27
27
|
get_peft_kwargs,
|
28
|
+
is_peft_available,
|
28
29
|
is_peft_version,
|
30
|
+
is_torch_version,
|
29
31
|
is_transformers_available,
|
32
|
+
is_transformers_version,
|
30
33
|
logging,
|
31
34
|
scale_lora_layers,
|
32
35
|
)
|
33
36
|
from .lora_base import LoraBaseMixin
|
34
|
-
from .lora_conversion_utils import
|
37
|
+
from .lora_conversion_utils import (
|
38
|
+
_convert_kohya_flux_lora_to_diffusers,
|
39
|
+
_convert_non_diffusers_lora_to_diffusers,
|
40
|
+
_convert_xlabs_flux_lora_to_diffusers,
|
41
|
+
_maybe_map_sgm_blocks_to_diffusers,
|
42
|
+
)
|
43
|
+
|
44
|
+
|
45
|
+
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
|
46
|
+
if is_torch_version(">=", "1.9.0"):
|
47
|
+
if (
|
48
|
+
is_peft_available()
|
49
|
+
and is_peft_version(">=", "0.13.1")
|
50
|
+
and is_transformers_available()
|
51
|
+
and is_transformers_version(">", "4.45.2")
|
52
|
+
):
|
53
|
+
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
|
35
54
|
|
36
55
|
|
37
56
|
if is_transformers_available():
|
@@ -78,15 +97,24 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
|
78
97
|
Parameters:
|
79
98
|
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
80
99
|
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
81
|
-
kwargs (`dict`, *optional*):
|
82
|
-
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
83
100
|
adapter_name (`str`, *optional*):
|
84
101
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
85
102
|
`default_{i}` where i is the total number of adapters being loaded.
|
103
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
104
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
105
|
+
weights.
|
106
|
+
kwargs (`dict`, *optional*):
|
107
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
86
108
|
"""
|
87
109
|
if not USE_PEFT_BACKEND:
|
88
110
|
raise ValueError("PEFT backend is required for this method.")
|
89
111
|
|
112
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
113
|
+
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
114
|
+
raise ValueError(
|
115
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
116
|
+
)
|
117
|
+
|
90
118
|
# if a dict is passed, copy it instead of modifying it inplace
|
91
119
|
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
92
120
|
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
@@ -94,7 +122,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
|
94
122
|
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
95
123
|
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
96
124
|
|
97
|
-
is_correct_format = all("lora" in key
|
125
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
98
126
|
if not is_correct_format:
|
99
127
|
raise ValueError("Invalid LoRA checkpoint.")
|
100
128
|
|
@@ -104,6 +132,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
|
104
132
|
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
|
105
133
|
adapter_name=adapter_name,
|
106
134
|
_pipeline=self,
|
135
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
107
136
|
)
|
108
137
|
self.load_lora_into_text_encoder(
|
109
138
|
state_dict,
|
@@ -114,6 +143,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
|
114
143
|
lora_scale=self.lora_scale,
|
115
144
|
adapter_name=adapter_name,
|
116
145
|
_pipeline=self,
|
146
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
117
147
|
)
|
118
148
|
|
119
149
|
@classmethod
|
@@ -206,6 +236,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
|
206
236
|
user_agent=user_agent,
|
207
237
|
allow_pickle=allow_pickle,
|
208
238
|
)
|
239
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
240
|
+
if is_dora_scale_present:
|
241
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
242
|
+
logger.warning(warn_msg)
|
243
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
209
244
|
|
210
245
|
network_alphas = None
|
211
246
|
# TODO: replace it with a method from `state_dict_utils`
|
@@ -227,7 +262,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
|
227
262
|
return state_dict, network_alphas
|
228
263
|
|
229
264
|
@classmethod
|
230
|
-
def load_lora_into_unet(
|
265
|
+
def load_lora_into_unet(
|
266
|
+
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
267
|
+
):
|
231
268
|
"""
|
232
269
|
This will load the LoRA layers specified in `state_dict` into `unet`.
|
233
270
|
|
@@ -245,10 +282,16 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
|
245
282
|
adapter_name (`str`, *optional*):
|
246
283
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
247
284
|
`default_{i}` where i is the total number of adapters being loaded.
|
285
|
+
Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
|
248
286
|
"""
|
249
287
|
if not USE_PEFT_BACKEND:
|
250
288
|
raise ValueError("PEFT backend is required for this method.")
|
251
289
|
|
290
|
+
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
291
|
+
raise ValueError(
|
292
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
293
|
+
)
|
294
|
+
|
252
295
|
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
253
296
|
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
254
297
|
# their prefixes.
|
@@ -258,7 +301,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
|
258
301
|
# Load the layers corresponding to UNet.
|
259
302
|
logger.info(f"Loading {cls.unet_name}.")
|
260
303
|
unet.load_attn_procs(
|
261
|
-
state_dict,
|
304
|
+
state_dict,
|
305
|
+
network_alphas=network_alphas,
|
306
|
+
adapter_name=adapter_name,
|
307
|
+
_pipeline=_pipeline,
|
308
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
262
309
|
)
|
263
310
|
|
264
311
|
@classmethod
|
@@ -271,6 +318,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
|
271
318
|
lora_scale=1.0,
|
272
319
|
adapter_name=None,
|
273
320
|
_pipeline=None,
|
321
|
+
low_cpu_mem_usage=False,
|
274
322
|
):
|
275
323
|
"""
|
276
324
|
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
@@ -280,7 +328,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
|
280
328
|
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
281
329
|
additional `text_encoder` to distinguish between unet lora layers.
|
282
330
|
network_alphas (`Dict[str, float]`):
|
283
|
-
|
331
|
+
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
332
|
+
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
333
|
+
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
284
334
|
text_encoder (`CLIPTextModel`):
|
285
335
|
The text encoder model to load the LoRA layers into.
|
286
336
|
prefix (`str`):
|
@@ -291,10 +341,25 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
|
291
341
|
adapter_name (`str`, *optional*):
|
292
342
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
293
343
|
`default_{i}` where i is the total number of adapters being loaded.
|
344
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
294
345
|
"""
|
295
346
|
if not USE_PEFT_BACKEND:
|
296
347
|
raise ValueError("PEFT backend is required for this method.")
|
297
348
|
|
349
|
+
peft_kwargs = {}
|
350
|
+
if low_cpu_mem_usage:
|
351
|
+
if not is_peft_version(">=", "0.13.1"):
|
352
|
+
raise ValueError(
|
353
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
354
|
+
)
|
355
|
+
if not is_transformers_version(">", "4.45.2"):
|
356
|
+
# Note from sayakpaul: It's not in `transformers` stable yet.
|
357
|
+
# https://github.com/huggingface/transformers/pull/33725/
|
358
|
+
raise ValueError(
|
359
|
+
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
360
|
+
)
|
361
|
+
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
362
|
+
|
298
363
|
from peft import LoraConfig
|
299
364
|
|
300
365
|
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
@@ -365,6 +430,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
|
365
430
|
adapter_name=adapter_name,
|
366
431
|
adapter_state_dict=text_encoder_lora_state_dict,
|
367
432
|
peft_config=lora_config,
|
433
|
+
**peft_kwargs,
|
368
434
|
)
|
369
435
|
|
370
436
|
# scale LoRA layers with `lora_scale`
|
@@ -535,12 +601,19 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
535
601
|
adapter_name (`str`, *optional*):
|
536
602
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
537
603
|
`default_{i}` where i is the total number of adapters being loaded.
|
604
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
538
605
|
kwargs (`dict`, *optional*):
|
539
606
|
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
540
607
|
"""
|
541
608
|
if not USE_PEFT_BACKEND:
|
542
609
|
raise ValueError("PEFT backend is required for this method.")
|
543
610
|
|
611
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
612
|
+
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
613
|
+
raise ValueError(
|
614
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
615
|
+
)
|
616
|
+
|
544
617
|
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
545
618
|
# it here explicitly to be able to tell that it's coming from an SDXL
|
546
619
|
# pipeline.
|
@@ -555,12 +628,18 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
555
628
|
unet_config=self.unet.config,
|
556
629
|
**kwargs,
|
557
630
|
)
|
558
|
-
|
631
|
+
|
632
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
559
633
|
if not is_correct_format:
|
560
634
|
raise ValueError("Invalid LoRA checkpoint.")
|
561
635
|
|
562
636
|
self.load_lora_into_unet(
|
563
|
-
state_dict,
|
637
|
+
state_dict,
|
638
|
+
network_alphas=network_alphas,
|
639
|
+
unet=self.unet,
|
640
|
+
adapter_name=adapter_name,
|
641
|
+
_pipeline=self,
|
642
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
564
643
|
)
|
565
644
|
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
566
645
|
if len(text_encoder_state_dict) > 0:
|
@@ -572,6 +651,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
572
651
|
lora_scale=self.lora_scale,
|
573
652
|
adapter_name=adapter_name,
|
574
653
|
_pipeline=self,
|
654
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
575
655
|
)
|
576
656
|
|
577
657
|
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
@@ -584,6 +664,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
584
664
|
lora_scale=self.lora_scale,
|
585
665
|
adapter_name=adapter_name,
|
586
666
|
_pipeline=self,
|
667
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
587
668
|
)
|
588
669
|
|
589
670
|
@classmethod
|
@@ -677,6 +758,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
677
758
|
user_agent=user_agent,
|
678
759
|
allow_pickle=allow_pickle,
|
679
760
|
)
|
761
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
762
|
+
if is_dora_scale_present:
|
763
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
764
|
+
logger.warning(warn_msg)
|
765
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
680
766
|
|
681
767
|
network_alphas = None
|
682
768
|
# TODO: replace it with a method from `state_dict_utils`
|
@@ -699,7 +785,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
699
785
|
|
700
786
|
@classmethod
|
701
787
|
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
|
702
|
-
def load_lora_into_unet(
|
788
|
+
def load_lora_into_unet(
|
789
|
+
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
790
|
+
):
|
703
791
|
"""
|
704
792
|
This will load the LoRA layers specified in `state_dict` into `unet`.
|
705
793
|
|
@@ -717,10 +805,16 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
717
805
|
adapter_name (`str`, *optional*):
|
718
806
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
719
807
|
`default_{i}` where i is the total number of adapters being loaded.
|
808
|
+
Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
|
720
809
|
"""
|
721
810
|
if not USE_PEFT_BACKEND:
|
722
811
|
raise ValueError("PEFT backend is required for this method.")
|
723
812
|
|
813
|
+
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
814
|
+
raise ValueError(
|
815
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
816
|
+
)
|
817
|
+
|
724
818
|
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
725
819
|
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
726
820
|
# their prefixes.
|
@@ -730,7 +824,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
730
824
|
# Load the layers corresponding to UNet.
|
731
825
|
logger.info(f"Loading {cls.unet_name}.")
|
732
826
|
unet.load_attn_procs(
|
733
|
-
state_dict,
|
827
|
+
state_dict,
|
828
|
+
network_alphas=network_alphas,
|
829
|
+
adapter_name=adapter_name,
|
830
|
+
_pipeline=_pipeline,
|
831
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
734
832
|
)
|
735
833
|
|
736
834
|
@classmethod
|
@@ -744,6 +842,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
744
842
|
lora_scale=1.0,
|
745
843
|
adapter_name=None,
|
746
844
|
_pipeline=None,
|
845
|
+
low_cpu_mem_usage=False,
|
747
846
|
):
|
748
847
|
"""
|
749
848
|
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
@@ -753,7 +852,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
753
852
|
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
754
853
|
additional `text_encoder` to distinguish between unet lora layers.
|
755
854
|
network_alphas (`Dict[str, float]`):
|
756
|
-
|
855
|
+
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
856
|
+
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
857
|
+
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
757
858
|
text_encoder (`CLIPTextModel`):
|
758
859
|
The text encoder model to load the LoRA layers into.
|
759
860
|
prefix (`str`):
|
@@ -764,10 +865,25 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
764
865
|
adapter_name (`str`, *optional*):
|
765
866
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
766
867
|
`default_{i}` where i is the total number of adapters being loaded.
|
868
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
767
869
|
"""
|
768
870
|
if not USE_PEFT_BACKEND:
|
769
871
|
raise ValueError("PEFT backend is required for this method.")
|
770
872
|
|
873
|
+
peft_kwargs = {}
|
874
|
+
if low_cpu_mem_usage:
|
875
|
+
if not is_peft_version(">=", "0.13.1"):
|
876
|
+
raise ValueError(
|
877
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
878
|
+
)
|
879
|
+
if not is_transformers_version(">", "4.45.2"):
|
880
|
+
# Note from sayakpaul: It's not in `transformers` stable yet.
|
881
|
+
# https://github.com/huggingface/transformers/pull/33725/
|
882
|
+
raise ValueError(
|
883
|
+
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
884
|
+
)
|
885
|
+
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
886
|
+
|
771
887
|
from peft import LoraConfig
|
772
888
|
|
773
889
|
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
@@ -838,6 +954,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|
838
954
|
adapter_name=adapter_name,
|
839
955
|
adapter_state_dict=text_encoder_lora_state_dict,
|
840
956
|
peft_config=lora_config,
|
957
|
+
**peft_kwargs,
|
841
958
|
)
|
842
959
|
|
843
960
|
# scale LoRA layers with `lora_scale`
|
@@ -1080,6 +1197,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1080
1197
|
allow_pickle=allow_pickle,
|
1081
1198
|
)
|
1082
1199
|
|
1200
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
1201
|
+
if is_dora_scale_present:
|
1202
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
1203
|
+
logger.warning(warn_msg)
|
1204
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
1205
|
+
|
1083
1206
|
return state_dict
|
1084
1207
|
|
1085
1208
|
def load_lora_weights(
|
@@ -1100,15 +1223,22 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1100
1223
|
Parameters:
|
1101
1224
|
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
1102
1225
|
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
1103
|
-
kwargs (`dict`, *optional*):
|
1104
|
-
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
1105
1226
|
adapter_name (`str`, *optional*):
|
1106
1227
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
1107
1228
|
`default_{i}` where i is the total number of adapters being loaded.
|
1229
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
1230
|
+
kwargs (`dict`, *optional*):
|
1231
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
1108
1232
|
"""
|
1109
1233
|
if not USE_PEFT_BACKEND:
|
1110
1234
|
raise ValueError("PEFT backend is required for this method.")
|
1111
1235
|
|
1236
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
1237
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
1238
|
+
raise ValueError(
|
1239
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
1240
|
+
)
|
1241
|
+
|
1112
1242
|
# if a dict is passed, copy it instead of modifying it inplace
|
1113
1243
|
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
1114
1244
|
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
@@ -1116,7 +1246,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1116
1246
|
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
1117
1247
|
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
1118
1248
|
|
1119
|
-
is_correct_format = all("lora" in key
|
1249
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
1120
1250
|
if not is_correct_format:
|
1121
1251
|
raise ValueError("Invalid LoRA checkpoint.")
|
1122
1252
|
|
@@ -1125,6 +1255,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1125
1255
|
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
1126
1256
|
adapter_name=adapter_name,
|
1127
1257
|
_pipeline=self,
|
1258
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1128
1259
|
)
|
1129
1260
|
|
1130
1261
|
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
@@ -1137,6 +1268,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1137
1268
|
lora_scale=self.lora_scale,
|
1138
1269
|
adapter_name=adapter_name,
|
1139
1270
|
_pipeline=self,
|
1271
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1140
1272
|
)
|
1141
1273
|
|
1142
1274
|
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
@@ -1149,10 +1281,13 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1149
1281
|
lora_scale=self.lora_scale,
|
1150
1282
|
adapter_name=adapter_name,
|
1151
1283
|
_pipeline=self,
|
1284
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1152
1285
|
)
|
1153
1286
|
|
1154
1287
|
@classmethod
|
1155
|
-
def load_lora_into_transformer(
|
1288
|
+
def load_lora_into_transformer(
|
1289
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
1290
|
+
):
|
1156
1291
|
"""
|
1157
1292
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
1158
1293
|
|
@@ -1166,7 +1301,13 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1166
1301
|
adapter_name (`str`, *optional*):
|
1167
1302
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
1168
1303
|
`default_{i}` where i is the total number of adapters being loaded.
|
1304
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
1169
1305
|
"""
|
1306
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
1307
|
+
raise ValueError(
|
1308
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
1309
|
+
)
|
1310
|
+
|
1170
1311
|
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
1171
1312
|
|
1172
1313
|
keys = list(state_dict.keys())
|
@@ -1210,17 +1351,37 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1210
1351
|
# otherwise loading LoRA weights will lead to an error
|
1211
1352
|
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
1212
1353
|
|
1213
|
-
|
1214
|
-
|
1354
|
+
peft_kwargs = {}
|
1355
|
+
if is_peft_version(">=", "0.13.1"):
|
1356
|
+
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
1357
|
+
|
1358
|
+
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
|
1359
|
+
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
|
1215
1360
|
|
1361
|
+
warn_msg = ""
|
1216
1362
|
if incompatible_keys is not None:
|
1217
|
-
#
|
1363
|
+
# Check only for unexpected keys.
|
1218
1364
|
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
1219
1365
|
if unexpected_keys:
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1223
|
-
|
1366
|
+
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
1367
|
+
if lora_unexpected_keys:
|
1368
|
+
warn_msg = (
|
1369
|
+
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
1370
|
+
f" {', '.join(lora_unexpected_keys)}. "
|
1371
|
+
)
|
1372
|
+
|
1373
|
+
# Filter missing keys specific to the current adapter.
|
1374
|
+
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
1375
|
+
if missing_keys:
|
1376
|
+
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
1377
|
+
if lora_missing_keys:
|
1378
|
+
warn_msg += (
|
1379
|
+
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
1380
|
+
f" {', '.join(lora_missing_keys)}."
|
1381
|
+
)
|
1382
|
+
|
1383
|
+
if warn_msg:
|
1384
|
+
logger.warning(warn_msg)
|
1224
1385
|
|
1225
1386
|
# Offload back.
|
1226
1387
|
if is_model_cpu_offload:
|
@@ -1240,6 +1401,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1240
1401
|
lora_scale=1.0,
|
1241
1402
|
adapter_name=None,
|
1242
1403
|
_pipeline=None,
|
1404
|
+
low_cpu_mem_usage=False,
|
1243
1405
|
):
|
1244
1406
|
"""
|
1245
1407
|
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
@@ -1249,7 +1411,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1249
1411
|
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
1250
1412
|
additional `text_encoder` to distinguish between unet lora layers.
|
1251
1413
|
network_alphas (`Dict[str, float]`):
|
1252
|
-
|
1414
|
+
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
1415
|
+
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
1416
|
+
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
1253
1417
|
text_encoder (`CLIPTextModel`):
|
1254
1418
|
The text encoder model to load the LoRA layers into.
|
1255
1419
|
prefix (`str`):
|
@@ -1260,10 +1424,25 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1260
1424
|
adapter_name (`str`, *optional*):
|
1261
1425
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
1262
1426
|
`default_{i}` where i is the total number of adapters being loaded.
|
1427
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
1263
1428
|
"""
|
1264
1429
|
if not USE_PEFT_BACKEND:
|
1265
1430
|
raise ValueError("PEFT backend is required for this method.")
|
1266
1431
|
|
1432
|
+
peft_kwargs = {}
|
1433
|
+
if low_cpu_mem_usage:
|
1434
|
+
if not is_peft_version(">=", "0.13.1"):
|
1435
|
+
raise ValueError(
|
1436
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
1437
|
+
)
|
1438
|
+
if not is_transformers_version(">", "4.45.2"):
|
1439
|
+
# Note from sayakpaul: It's not in `transformers` stable yet.
|
1440
|
+
# https://github.com/huggingface/transformers/pull/33725/
|
1441
|
+
raise ValueError(
|
1442
|
+
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
1443
|
+
)
|
1444
|
+
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
1445
|
+
|
1267
1446
|
from peft import LoraConfig
|
1268
1447
|
|
1269
1448
|
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
@@ -1334,6 +1513,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|
1334
1513
|
adapter_name=adapter_name,
|
1335
1514
|
adapter_state_dict=text_encoder_lora_state_dict,
|
1336
1515
|
peft_config=lora_config,
|
1516
|
+
**peft_kwargs,
|
1337
1517
|
)
|
1338
1518
|
|
1339
1519
|
# scale LoRA layers with `lora_scale`
|
@@ -1576,6 +1756,24 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1576
1756
|
user_agent=user_agent,
|
1577
1757
|
allow_pickle=allow_pickle,
|
1578
1758
|
)
|
1759
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
1760
|
+
if is_dora_scale_present:
|
1761
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
1762
|
+
logger.warning(warn_msg)
|
1763
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
1764
|
+
|
1765
|
+
# TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
|
1766
|
+
is_kohya = any(".lora_down.weight" in k for k in state_dict)
|
1767
|
+
if is_kohya:
|
1768
|
+
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
|
1769
|
+
# Kohya already takes care of scaling the LoRA parameters with alpha.
|
1770
|
+
return (state_dict, None) if return_alphas else state_dict
|
1771
|
+
|
1772
|
+
is_xlabs = any("processor" in k for k in state_dict)
|
1773
|
+
if is_xlabs:
|
1774
|
+
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
|
1775
|
+
# xlabs doesn't use `alpha`.
|
1776
|
+
return (state_dict, None) if return_alphas else state_dict
|
1579
1777
|
|
1580
1778
|
# For state dicts like
|
1581
1779
|
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
|
@@ -1621,10 +1819,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1621
1819
|
adapter_name (`str`, *optional*):
|
1622
1820
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
1623
1821
|
`default_{i}` where i is the total number of adapters being loaded.
|
1822
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
1624
1823
|
"""
|
1625
1824
|
if not USE_PEFT_BACKEND:
|
1626
1825
|
raise ValueError("PEFT backend is required for this method.")
|
1627
1826
|
|
1827
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
1828
|
+
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
1829
|
+
raise ValueError(
|
1830
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
1831
|
+
)
|
1832
|
+
|
1628
1833
|
# if a dict is passed, copy it instead of modifying it inplace
|
1629
1834
|
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
1630
1835
|
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
@@ -1634,7 +1839,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1634
1839
|
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
|
1635
1840
|
)
|
1636
1841
|
|
1637
|
-
is_correct_format = all("lora" in key
|
1842
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
1638
1843
|
if not is_correct_format:
|
1639
1844
|
raise ValueError("Invalid LoRA checkpoint.")
|
1640
1845
|
|
@@ -1644,6 +1849,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1644
1849
|
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
1645
1850
|
adapter_name=adapter_name,
|
1646
1851
|
_pipeline=self,
|
1852
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1647
1853
|
)
|
1648
1854
|
|
1649
1855
|
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
@@ -1656,10 +1862,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1656
1862
|
lora_scale=self.lora_scale,
|
1657
1863
|
adapter_name=adapter_name,
|
1658
1864
|
_pipeline=self,
|
1865
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1659
1866
|
)
|
1660
1867
|
|
1661
1868
|
@classmethod
|
1662
|
-
def load_lora_into_transformer(
|
1869
|
+
def load_lora_into_transformer(
|
1870
|
+
cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
1871
|
+
):
|
1663
1872
|
"""
|
1664
1873
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
1665
1874
|
|
@@ -1677,7 +1886,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1677
1886
|
adapter_name (`str`, *optional*):
|
1678
1887
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
1679
1888
|
`default_{i}` where i is the total number of adapters being loaded.
|
1889
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
1680
1890
|
"""
|
1891
|
+
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
1892
|
+
raise ValueError(
|
1893
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
1894
|
+
)
|
1895
|
+
|
1681
1896
|
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
1682
1897
|
|
1683
1898
|
keys = list(state_dict.keys())
|
@@ -1726,17 +1941,37 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1726
1941
|
# otherwise loading LoRA weights will lead to an error
|
1727
1942
|
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
1728
1943
|
|
1729
|
-
|
1730
|
-
|
1944
|
+
peft_kwargs = {}
|
1945
|
+
if is_peft_version(">=", "0.13.1"):
|
1946
|
+
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
1947
|
+
|
1948
|
+
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
|
1949
|
+
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
|
1731
1950
|
|
1951
|
+
warn_msg = ""
|
1732
1952
|
if incompatible_keys is not None:
|
1733
|
-
#
|
1953
|
+
# Check only for unexpected keys.
|
1734
1954
|
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
1735
1955
|
if unexpected_keys:
|
1736
|
-
|
1737
|
-
|
1738
|
-
|
1739
|
-
|
1956
|
+
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
1957
|
+
if lora_unexpected_keys:
|
1958
|
+
warn_msg = (
|
1959
|
+
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
1960
|
+
f" {', '.join(lora_unexpected_keys)}. "
|
1961
|
+
)
|
1962
|
+
|
1963
|
+
# Filter missing keys specific to the current adapter.
|
1964
|
+
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
1965
|
+
if missing_keys:
|
1966
|
+
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
1967
|
+
if lora_missing_keys:
|
1968
|
+
warn_msg += (
|
1969
|
+
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
1970
|
+
f" {', '.join(lora_missing_keys)}."
|
1971
|
+
)
|
1972
|
+
|
1973
|
+
if warn_msg:
|
1974
|
+
logger.warning(warn_msg)
|
1740
1975
|
|
1741
1976
|
# Offload back.
|
1742
1977
|
if is_model_cpu_offload:
|
@@ -1756,6 +1991,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1756
1991
|
lora_scale=1.0,
|
1757
1992
|
adapter_name=None,
|
1758
1993
|
_pipeline=None,
|
1994
|
+
low_cpu_mem_usage=False,
|
1759
1995
|
):
|
1760
1996
|
"""
|
1761
1997
|
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
@@ -1765,7 +2001,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1765
2001
|
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
1766
2002
|
additional `text_encoder` to distinguish between unet lora layers.
|
1767
2003
|
network_alphas (`Dict[str, float]`):
|
1768
|
-
|
2004
|
+
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
2005
|
+
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
2006
|
+
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
1769
2007
|
text_encoder (`CLIPTextModel`):
|
1770
2008
|
The text encoder model to load the LoRA layers into.
|
1771
2009
|
prefix (`str`):
|
@@ -1776,10 +2014,25 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1776
2014
|
adapter_name (`str`, *optional*):
|
1777
2015
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
1778
2016
|
`default_{i}` where i is the total number of adapters being loaded.
|
2017
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
1779
2018
|
"""
|
1780
2019
|
if not USE_PEFT_BACKEND:
|
1781
2020
|
raise ValueError("PEFT backend is required for this method.")
|
1782
2021
|
|
2022
|
+
peft_kwargs = {}
|
2023
|
+
if low_cpu_mem_usage:
|
2024
|
+
if not is_peft_version(">=", "0.13.1"):
|
2025
|
+
raise ValueError(
|
2026
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
2027
|
+
)
|
2028
|
+
if not is_transformers_version(">", "4.45.2"):
|
2029
|
+
# Note from sayakpaul: It's not in `transformers` stable yet.
|
2030
|
+
# https://github.com/huggingface/transformers/pull/33725/
|
2031
|
+
raise ValueError(
|
2032
|
+
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
2033
|
+
)
|
2034
|
+
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
2035
|
+
|
1783
2036
|
from peft import LoraConfig
|
1784
2037
|
|
1785
2038
|
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
@@ -1850,6 +2103,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1850
2103
|
adapter_name=adapter_name,
|
1851
2104
|
adapter_state_dict=text_encoder_lora_state_dict,
|
1852
2105
|
peft_config=lora_config,
|
2106
|
+
**peft_kwargs,
|
1853
2107
|
)
|
1854
2108
|
|
1855
2109
|
# scale LoRA layers with `lora_scale`
|
@@ -1998,7 +2252,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
1998
2252
|
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
1999
2253
|
encoder lora layers.
|
2000
2254
|
network_alphas (`Dict[str, float]`):
|
2001
|
-
|
2255
|
+
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
2256
|
+
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
2257
|
+
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
2002
2258
|
unet (`UNet2DConditionModel`):
|
2003
2259
|
The UNet model to load the LoRA layers into.
|
2004
2260
|
adapter_name (`str`, *optional*):
|
@@ -2055,14 +2311,30 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2055
2311
|
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
|
2056
2312
|
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
|
2057
2313
|
|
2314
|
+
warn_msg = ""
|
2058
2315
|
if incompatible_keys is not None:
|
2059
|
-
#
|
2316
|
+
# Check only for unexpected keys.
|
2060
2317
|
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
2061
2318
|
if unexpected_keys:
|
2062
|
-
|
2063
|
-
|
2064
|
-
|
2065
|
-
|
2319
|
+
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
2320
|
+
if lora_unexpected_keys:
|
2321
|
+
warn_msg = (
|
2322
|
+
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
2323
|
+
f" {', '.join(lora_unexpected_keys)}. "
|
2324
|
+
)
|
2325
|
+
|
2326
|
+
# Filter missing keys specific to the current adapter.
|
2327
|
+
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
2328
|
+
if missing_keys:
|
2329
|
+
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
2330
|
+
if lora_missing_keys:
|
2331
|
+
warn_msg += (
|
2332
|
+
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
2333
|
+
f" {', '.join(lora_missing_keys)}."
|
2334
|
+
)
|
2335
|
+
|
2336
|
+
if warn_msg:
|
2337
|
+
logger.warning(warn_msg)
|
2066
2338
|
|
2067
2339
|
# Offload back.
|
2068
2340
|
if is_model_cpu_offload:
|
@@ -2082,6 +2354,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2082
2354
|
lora_scale=1.0,
|
2083
2355
|
adapter_name=None,
|
2084
2356
|
_pipeline=None,
|
2357
|
+
low_cpu_mem_usage=False,
|
2085
2358
|
):
|
2086
2359
|
"""
|
2087
2360
|
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
@@ -2091,7 +2364,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2091
2364
|
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
2092
2365
|
additional `text_encoder` to distinguish between unet lora layers.
|
2093
2366
|
network_alphas (`Dict[str, float]`):
|
2094
|
-
|
2367
|
+
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
2368
|
+
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
2369
|
+
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
2095
2370
|
text_encoder (`CLIPTextModel`):
|
2096
2371
|
The text encoder model to load the LoRA layers into.
|
2097
2372
|
prefix (`str`):
|
@@ -2102,10 +2377,25 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2102
2377
|
adapter_name (`str`, *optional*):
|
2103
2378
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
2104
2379
|
`default_{i}` where i is the total number of adapters being loaded.
|
2380
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
2105
2381
|
"""
|
2106
2382
|
if not USE_PEFT_BACKEND:
|
2107
2383
|
raise ValueError("PEFT backend is required for this method.")
|
2108
2384
|
|
2385
|
+
peft_kwargs = {}
|
2386
|
+
if low_cpu_mem_usage:
|
2387
|
+
if not is_peft_version(">=", "0.13.1"):
|
2388
|
+
raise ValueError(
|
2389
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
2390
|
+
)
|
2391
|
+
if not is_transformers_version(">", "4.45.2"):
|
2392
|
+
# Note from sayakpaul: It's not in `transformers` stable yet.
|
2393
|
+
# https://github.com/huggingface/transformers/pull/33725/
|
2394
|
+
raise ValueError(
|
2395
|
+
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
2396
|
+
)
|
2397
|
+
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
2398
|
+
|
2109
2399
|
from peft import LoraConfig
|
2110
2400
|
|
2111
2401
|
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
@@ -2176,6 +2466,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
|
2176
2466
|
adapter_name=adapter_name,
|
2177
2467
|
adapter_state_dict=text_encoder_lora_state_dict,
|
2178
2468
|
peft_config=lora_config,
|
2469
|
+
**peft_kwargs,
|
2179
2470
|
)
|
2180
2471
|
|
2181
2472
|
# scale LoRA layers with `lora_scale`
|
@@ -2245,6 +2536,381 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
         )


+class CogVideoXLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoX`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = cls._fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        return state_dict
+
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
2662
|
+
kwargs (`dict`, *optional*):
|
2663
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
2664
|
+
"""
|
2665
|
+
if not USE_PEFT_BACKEND:
|
2666
|
+
raise ValueError("PEFT backend is required for this method.")
|
2667
|
+
|
2668
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
2669
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
2670
|
+
raise ValueError(
|
2671
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
2672
|
+
)
|
2673
|
+
|
2674
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
2675
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
2676
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
2677
|
+
|
2678
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
2679
|
+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
2680
|
+
|
2681
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
2682
|
+
if not is_correct_format:
|
2683
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
2684
|
+
|
2685
|
+
self.load_lora_into_transformer(
|
2686
|
+
state_dict,
|
2687
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
2688
|
+
adapter_name=adapter_name,
|
2689
|
+
_pipeline=self,
|
2690
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
2691
|
+
)
|
2692
|
+
|
2693
|
+
@classmethod
|
2694
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
|
2695
|
+
def load_lora_into_transformer(
|
2696
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
2697
|
+
):
|
2698
|
+
"""
|
2699
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
2700
|
+
|
2701
|
+
Parameters:
|
2702
|
+
state_dict (`dict`):
|
2703
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
2704
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
2705
|
+
encoder lora layers.
|
2706
|
+
transformer (`SD3Transformer2DModel`):
|
2707
|
+
The Transformer model to load the LoRA layers into.
|
2708
|
+
adapter_name (`str`, *optional*):
|
2709
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
2710
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
+            low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+        keys = list(state_dict.keys())
+
+        transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
+        state_dict = {
+            k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+        }
+
+        if len(state_dict.keys()) > 0:
+            # check with first key if is not in peft format
+            first_key = next(iter(state_dict.keys()))
+            if "lora_A" not in first_key:
+                state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+            if adapter_name in getattr(transformer, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+                )
+
+            rank = {}
+            for key, val in state_dict.items():
+                if "lora_B" in key:
+                    rank[key] = val.shape[1]
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+                    raise ValueError(
+                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                    )
+                else:
+                    lora_config_kwargs.pop("use_dora")
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(transformer)
+
+            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+            # otherwise loading LoRA weights will lead to an error
+            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+            peft_kwargs = {}
+            if is_peft_version(">=", "0.13.1"):
+                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
+            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
+
+            warn_msg = ""
+            if incompatible_keys is not None:
+                # Check only for unexpected keys.
+                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                if unexpected_keys:
+                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                    if lora_unexpected_keys:
+                        warn_msg = (
+                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                            f" {', '.join(lora_unexpected_keys)}. "
+                        )
+
+                # Filter missing keys specific to the current adapter.
+                missing_keys = getattr(incompatible_keys, "missing_keys", None)
+                if missing_keys:
+                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                    if lora_missing_keys:
+                        warn_msg += (
+                            f"Loading adapter weights from state_dict led to missing keys in the model:"
+                            f" {', '.join(lora_missing_keys)}."
+                        )
+
+                if warn_msg:
+                    logger.warning(warn_msg)
+
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
+    @classmethod
+    # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the UNet and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer", "text_encoder"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters then it won't have any effect.
+        """
+        super().unfuse_lora(components=components)
+
+
 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
         deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
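The bulk of this hunk adds `CogVideoXLoraLoaderMixin`, which exposes the usual LoRA entry points (`lora_state_dict`, `load_lora_weights`, `load_lora_into_transformer`, `save_lora_weights`, `fuse_lora`, `unfuse_lora`) but declares only `"transformer"` as a LoRA-loadable module. A hedged usage sketch follows; it is not part of the diff, assumes a CogVideoX pipeline class that inherits this mixin, and uses a placeholder LoRA repository name.

```py
# Hedged sketch (not from the diff): assumes a CogVideoX pipeline that inherits the new
# CogVideoXLoraLoaderMixin. "your-namespace/cogvideox-lora" is a placeholder for any
# transformer-only LoRA checkpoint, for example one written by save_lora_weights().
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# load_lora_weights() routes the state dict into pipe.transformer via
# load_lora_into_transformer(); "transformer" is the only LoRA-loadable module here.
pipe.load_lora_weights("your-namespace/cogvideox-lora", adapter_name="cogvideox-lora")

# Optionally fold the adapter into the base weights (experimental API per the docstring above).
pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
```

Because `_lora_loadable_modules` lists only the transformer, `save_lora_weights` likewise accepts just `transformer_lora_layers` and raises if they are missing.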