diffusers 0.28.2__py3-none-any.whl → 0.29.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- diffusers/__init__.py +9 -1
- diffusers/commands/env.py +1 -5
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +2 -1
- diffusers/loaders/__init__.py +2 -2
- diffusers/loaders/lora.py +406 -140
- diffusers/loaders/lora_conversion_utils.py +7 -1
- diffusers/loaders/single_file.py +1 -1
- diffusers/loaders/single_file_model.py +5 -0
- diffusers/loaders/single_file_utils.py +242 -2
- diffusers/loaders/unet.py +307 -272
- diffusers/models/__init__.py +5 -3
- diffusers/models/attention.py +125 -1
- diffusers/models/attention_processor.py +169 -1
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +17 -6
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet_xs.py +6 -6
- diffusers/models/embeddings.py +112 -84
- diffusers/models/model_loading_utils.py +55 -0
- diffusers/models/modeling_utils.py +128 -17
- diffusers/models/normalization.py +11 -6
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/dual_transformer_2d.py +5 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
- diffusers/models/transformers/prior_transformer.py +5 -5
- diffusers/models/transformers/transformer_2d.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +344 -0
- diffusers/models/transformers/transformer_temporal.py +12 -10
- diffusers/models/unets/unet_1d.py +3 -3
- diffusers/models/unets/unet_2d.py +3 -3
- diffusers/models/unets/unet_2d_condition.py +4 -15
- diffusers/models/unets/unet_3d_condition.py +5 -17
- diffusers/models/unets/unet_i2vgen_xl.py +4 -4
- diffusers/models/unets/unet_motion_model.py +4 -4
- diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
- diffusers/models/vq_model.py +8 -165
- diffusers/pipelines/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
- diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
- diffusers/pipelines/pia/pipeline_pia.py +4 -3
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +886 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +923 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
- diffusers/schedulers/__init__.py +2 -0
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
- diffusers/schedulers/scheduling_edm_euler.py +2 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/training_utils.py +4 -4
- diffusers/utils/__init__.py +3 -0
- diffusers/utils/constants.py +2 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +30 -0
- diffusers/utils/dynamic_modules_utils.py +15 -13
- diffusers/utils/hub_utils.py +106 -0
- diffusers/utils/import_utils.py +0 -1
- diffusers/utils/logging.py +3 -1
- diffusers/utils/state_dict_utils.py +2 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/METADATA +45 -45
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/RECORD +108 -111
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/WHEEL +1 -1
- diffusers/models/dual_transformer_2d.py +0 -20
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/LICENSE +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/top_level.txt +0 -0
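The headline change in 0.29.0 is initial Stable Diffusion 3 support: a new `StableDiffusion3Pipeline` (`pipelines/stable_diffusion_3/`), the `SD3Transformer2DModel` backbone (`transformers/transformer_sd3.py`), and the flow-match Euler scheduler (`scheduling_flow_match_euler_discrete.py`). As a rough orientation only, the sketch below shows how these pieces are typically driven end to end; the checkpoint id and generation settings are placeholders, not anything prescribed by this diff.

```python
# Minimal sketch, assuming diffusers >= 0.29.0 and access to an SD3 checkpoint.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # placeholder model id
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# The pipeline wires the SD3 transformer and the flow-match Euler scheduler together internally.
image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3.png")
```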
diffusers/loaders/lora.py
CHANGED
@@ -22,17 +22,14 @@ import torch
 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE
 from huggingface_hub.utils import validate_hf_hub_args
-from packaging import version
 from torch import nn
 
-from .. import
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
+from ..models.modeling_utils import load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
-    convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
     get_peft_kwargs,
@@ -119,13 +116,10 @@ class LoraLoaderMixin:
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-
         self.load_lora_into_unet(
             state_dict,
             network_alphas=network_alphas,
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
-            low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
             _pipeline=self,
         )
@@ -136,7 +130,6 @@ class LoraLoaderMixin:
             if not hasattr(self, "text_encoder")
             else self.text_encoder,
             lora_scale=self.lora_scale,
-            low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
             _pipeline=self,
         )
@@ -193,16 +186,8 @@ class LoraLoaderMixin:
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-
-
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
-            mirror (`str`, *optional*):
-                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
-                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
-                information.
-
+            weight_name (`str`, *optional*, defaults to None):
+                Name of the serialized state dict file.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -383,9 +368,7 @@ class LoraLoaderMixin:
         return (is_model_cpu_offload, is_sequential_cpu_offload)
 
     @classmethod
-    def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
-    ):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
 
@@ -395,14 +378,11 @@ class LoraLoaderMixin:
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
             network_alphas (`Dict[str, float]`):
-
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
@@ -410,94 +390,18 @@ class LoraLoaderMixin:
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())
+        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
 
-        if
+        if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-
-
-
-
-            if network_alphas is not None:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
-                network_alphas = {
-                    k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                }
-
-        else:
-            # Otherwise, we're dealing with the old format. This means the `state_dict` should only
-            # contain the module names of the `unet` as its keys WITHOUT any prefix.
-            if not USE_PEFT_BACKEND:
-                warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-                logger.warning(warn_message)
-
-        if len(state_dict.keys()) > 0:
-            if adapter_name in getattr(unet, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
-                )
-
-            state_dict = convert_unet_state_dict_to_peft(state_dict)
-
-            if network_alphas is not None:
-                # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
-                # `convert_unet_state_dict_to_peft` method.
-                network_alphas = convert_unet_state_dict_to_peft(network_alphas)
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(unet)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
-
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
-
-            unet.load_attn_procs(
-                state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
-            )
+            unet.load_attn_procs(
+                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+            )
 
     @classmethod
     def load_lora_into_text_encoder(
@@ -507,7 +411,6 @@ class LoraLoaderMixin:
         text_encoder,
         prefix=None,
         lora_scale=1.0,
-        low_cpu_mem_usage=None,
         adapter_name=None,
         _pipeline=None,
     ):
@@ -527,11 +430,6 @@ class LoraLoaderMixin:
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
@@ -541,8 +439,6 @@ class LoraLoaderMixin:
 
         from peft import LoraConfig
 
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
@@ -625,9 +521,7 @@ class LoraLoaderMixin:
             # Unsafe code />
 
     @classmethod
-    def load_lora_into_transformer(
-        cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
-    ):
+    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -640,19 +534,12 @@ class LoraLoaderMixin:
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
 
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
         keys = list(state_dict.keys())
 
         transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
@@ -846,22 +733,11 @@ class LoraLoaderMixin:
        >>> ...
        ```
        """
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-
        if not USE_PEFT_BACKEND:
-
-            logger.warning(
-                "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
-                "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
-            )
+            raise ValueError("PEFT backend is required for this method.")
 
-
-
-                module.set_lora_layer(None)
-        else:
-            recurse_remove_peft_layers(unet)
-            if hasattr(unet, "peft_config"):
-                del unet.peft_config
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.unload_lora()
 
        # Safe to call the following regardless of LoRA.
        self._remove_text_encoder_monkey_patch()
@@ -1461,3 +1337,393 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         if getattr(self.text_encoder_2, "peft_config", None) is not None:
             del self.text_encoder_2.peft_config
             self.text_encoder_2._hf_peft_config_loaded = None
+
+
+class SD3LoraLoaderMixin:
+    r"""
+    Load LoRA layers into [`SD3Transformer2DModel`].
+    """
+
+    transformer_name = TRANSFORMER_NAME
+    num_fused_loras = 0
+
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+        `self.text_encoder`.
+
+        All kwargs are forwarded to `self.lora_state_dict`.
+
+        See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+
+        See [`~loaders.LoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is loaded
+        into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+            kwargs (`dict`, *optional*):
+                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+        )
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+                incompletely downloaded files are deleted.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # UNet and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        model_file = None
+        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+            # Let's first try to load .safetensors weights
+            if (use_safetensors and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path_or_dict,
+                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        token=token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except (IOError, safetensors.SafetensorError) as e:
+                    if not allow_pickle:
+                        raise e
+                    # try loading non-safetensors weights
+                    model_file = None
+                    pass
+
+            if model_file is None:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name or LORA_WEIGHT_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = load_state_dict(model_file)
+        else:
+            state_dict = pretrained_model_name_or_path_or_dict
+
+        return state_dict
+
+    @classmethod
+    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`SD3Transformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+        """
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+        keys = list(state_dict.keys())
+
+        transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
+        state_dict = {
+            k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+        }
+
+        if len(state_dict.keys()) > 0:
+            if adapter_name in getattr(transformer, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+                )
+
+            rank = {}
+            for key, val in state_dict.items():
+                if "lora_B" in key:
+                    rank[key] = val.shape[1]
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+                    raise ValueError(
+                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                    )
+                else:
+                    lora_config_kwargs.pop("use_dora")
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(transformer)
+
+            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+            # otherwise loading LoRA weights will lead to an error
+            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
+            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+
+            if incompatible_keys is not None:
+                # check only for unexpected keys
+                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                if unexpected_keys:
+                    logger.warning(
+                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                        f" {unexpected_keys}. "
+                    )
+
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
+    @classmethod
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the UNet and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    @staticmethod
+    def write_lora_layers(
+        state_dict: Dict[str, torch.Tensor],
+        save_directory: str,
+        is_main_process: bool,
+        weight_name: str,
+        save_function: Callable,
+        safe_serialization: bool,
+    ):
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        if save_function is None:
+            if safe_serialization:
+
+                def save_function(weights, filename):
+                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+            else:
+                save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if weight_name is None:
+            if safe_serialization:
+                weight_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weight_name = LORA_WEIGHT_NAME
+
+        save_path = Path(save_directory, weight_name).as_posix()
+        save_function(state_dict, save_path)
+        logger.info(f"Model weights saved in {save_path}")
+
+    def unload_lora_weights(self):
+        """
+        Unloads the LoRA parameters.
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+        >>> pipeline.unload_lora_weights()
+        >>> ...
+        ```
+        """
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        recurse_remove_peft_layers(transformer)
+        if hasattr(transformer, "peft_config"):
+            del transformer.peft_config
+
+    @classmethod
+    # Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
+    def _optionally_disable_offloading(cls, _pipeline):
+        """
+        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+
+        Args:
+            _pipeline (`DiffusionPipeline`):
+                The pipeline to disable offloading for.
+
+        Returns:
+            tuple:
+                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+        """
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+
+        if _pipeline is not None and _pipeline.hf_device_map is None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                    if not is_model_cpu_offload:
+                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                    if not is_sequential_cpu_offload:
+                        is_sequential_cpu_offload = (
+                            isinstance(component._hf_hook, AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )
+
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+        return (is_model_cpu_offload, is_sequential_cpu_offload)