diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +72 -26
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
diffusers/loaders/peft.py
CHANGED
@@ -13,30 +13,103 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import os
 from functools import partial
+from pathlib import Path
 from typing import Dict, List, Optional, Union
 
+import safetensors
+import torch
+import torch.nn as nn
+
 from ..utils import (
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
     check_peft_version,
+    convert_unet_state_dict_to_peft,
     delete_adapter_layers,
+    get_adapter_name,
+    get_peft_kwargs,
+    is_accelerate_available,
     is_peft_available,
+    is_peft_version,
+    logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from .lora_base import _fetch_state_dict
 from .unet_loader_utils import _maybe_expand_lora_scales
 
 
+if is_accelerate_available():
+    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+
+logger = logging.get_logger(__name__)
+
 _SET_ADAPTER_SCALE_FN_MAPPING = {
     "UNet2DConditionModel": _maybe_expand_lora_scales,
     "UNetMotionModel": _maybe_expand_lora_scales,
     "SD3Transformer2DModel": lambda model_cls, weights: weights,
     "FluxTransformer2DModel": lambda model_cls, weights: weights,
     "CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
+    "MochiTransformer3DModel": lambda model_cls, weights: weights,
+    "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
+    "LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
+    "SanaTransformer2DModel": lambda model_cls, weights: weights,
 }
 
 
+def _maybe_adjust_config(config):
+    """
+    We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
+    (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
+    method removes the ambiguity by following what is described here:
+    https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
+    """
+    rank_pattern = config["rank_pattern"].copy()
+    target_modules = config["target_modules"]
+    original_r = config["r"]
+
+    for key in list(rank_pattern.keys()):
+        key_rank = rank_pattern[key]
+
+        # try to detect ambiguity
+        # `target_modules` can also be a str, in which case this loop would loop
+        # over the chars of the str. The technically correct way to match LoRA keys
+        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
+        # But this cuts it for now.
+        exact_matches = [mod for mod in target_modules if mod == key]
+        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
+        ambiguous_key = key
+
+        if exact_matches and substring_matches:
+            # if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
+            config["r"] = key_rank
+            # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
+            del config["rank_pattern"][key]
+            for mod in substring_matches:
+                # avoid overwriting if the module already has a specific rank
+                if mod not in config["rank_pattern"]:
+                    config["rank_pattern"][mod] = original_r
+
+            # update the rest of the keys with the `original_r`
+            for mod in target_modules:
+                if mod != ambiguous_key and mod not in config["rank_pattern"]:
+                    config["rank_pattern"][mod] = original_r
+
+    # handle alphas to deal with cases like
+    # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
+    has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
+    if has_different_ranks:
+        config["lora_alpha"] = config["r"]
+        alpha_pattern = {}
+        for module_name, rank in config["rank_pattern"].items():
+            alpha_pattern[module_name] = rank
+        config["alpha_pattern"] = alpha_pattern
+
+    return config
+
+
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
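To make the disambiguation concrete, here is a hypothetical before/after for `_maybe_adjust_config` (the module names and ranks below are invented for illustration, not taken from any real checkpoint):

# Hypothetical input: `proj_out` matches both exactly and as a substring of
# other target modules, so its entry in `rank_pattern` is ambiguous.
config = {
    "r": 4,
    "lora_alpha": 4,
    "rank_pattern": {"proj_out": 8},
    "target_modules": ["proj_out", "blocks.0.proj_out", "blocks.1.proj_out"],
}
adjusted = _maybe_adjust_config(config)
# adjusted["r"] == 8  (the ambiguous key's rank becomes the global rank)
# adjusted["rank_pattern"] == {"blocks.0.proj_out": 4, "blocks.1.proj_out": 4}
# adjusted["lora_alpha"] == 8
# adjusted["alpha_pattern"] == {"blocks.0.proj_out": 4, "blocks.1.proj_out": 4}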
@@ -53,6 +126,312 @@ class PeftAdapterMixin:
 
     _hf_peft_config_loaded = False
 
+    @classmethod
+    # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
+    def _optionally_disable_offloading(cls, _pipeline):
+        """
+        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+
+        Args:
+            _pipeline (`DiffusionPipeline`):
+                The pipeline to disable offloading for.
+
+        Returns:
+            tuple:
+                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+        """
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+
+        if _pipeline is not None and _pipeline.hf_device_map is None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                    if not is_model_cpu_offload:
+                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                    if not is_sequential_cpu_offload:
+                        is_sequential_cpu_offload = (
+                            isinstance(component._hf_hook, AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )
+
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+        return (is_model_cpu_offload, is_sequential_cpu_offload)
+
+    def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
+        r"""
+        Loads a LoRA adapter into the underlying model.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            prefix (`str`, *optional*): Prefix to filter the state dict.
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+        """
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        adapter_name = kwargs.pop("adapter_name", None)
+        network_alphas = kwargs.pop("network_alphas", None)
+        _pipeline = kwargs.pop("_pipeline", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+        allow_pickle = False
+
+        if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+        if network_alphas is not None and prefix is None:
+            raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
+
+        if prefix is not None:
+            keys = list(state_dict.keys())
+            model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
+            if len(model_keys) > 0:
+                state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
+
+        if len(state_dict) > 0:
+            if adapter_name in getattr(self, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
+                )
+
+            # check with first key if is not in peft format
+            first_key = next(iter(state_dict.keys()))
+            if "lora_A" not in first_key:
+                state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+            rank = {}
+            for key, val in state_dict.items():
+                # Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
+                # Bias layers in LoRA only have a single dimension
+                if "lora_B" in key and val.ndim > 1:
+                    rank[key] = val.shape[1]
+
+            if network_alphas is not None and len(network_alphas) >= 1:
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
+                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
+            lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
+
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
+
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+
+            lora_config = LoraConfig(**lora_config_kwargs)
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(self)
+
+            # <Unsafe code
+            # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
+            # Now we remove any existing hooks to `_pipeline`.
+
+            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+            # otherwise loading LoRA weights will lead to an error
+            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+
+            peft_kwargs = {}
+            if is_peft_version(">=", "0.13.1"):
+                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+            # To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
+            # we should also delete the `peft_config` associated to the `adapter_name`.
+            try:
+                inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+                incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+            except RuntimeError as e:
+                for module in self.modules():
+                    if isinstance(module, BaseTunerLayer):
+                        active_adapters = module.active_adapters
+                        for active_adapter in active_adapters:
+                            if adapter_name in active_adapter:
+                                module.delete_adapter(adapter_name)
+
+                self.peft_config.pop(adapter_name)
+                logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
+                raise
+
+            warn_msg = ""
+            if incompatible_keys is not None:
+                # Check only for unexpected keys.
+                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                if unexpected_keys:
+                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                    if lora_unexpected_keys:
+                        warn_msg = (
+                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                            f" {', '.join(lora_unexpected_keys)}. "
+                        )
+
+                # Filter missing keys specific to the current adapter.
+                missing_keys = getattr(incompatible_keys, "missing_keys", None)
+                if missing_keys:
+                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                    if lora_missing_keys:
+                        warn_msg += (
+                            f"Loading adapter weights from state_dict led to missing keys in the model:"
+                            f" {', '.join(lora_missing_keys)}."
+                        )
+
+                if warn_msg:
+                    logger.warning(warn_msg)
+
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
+    def save_lora_adapter(
+        self,
+        save_directory,
+        adapter_name: str = "default",
+        upcast_before_saving: bool = False,
+        safe_serialization: bool = True,
+        weight_name: Optional[str] = None,
+    ):
+        """
+        Save the LoRA parameters corresponding to the underlying model.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            adapter_name: (`str`, defaults to "default"): The name of the adapter to serialize. Useful when the
+                underlying model has multiple adapters loaded.
+            upcast_before_saving (`bool`, defaults to `False`):
+                Whether to cast the underlying model to `torch.float32` before serialization.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
+        """
+        from peft.utils import get_peft_model_state_dict
+
+        from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+        if adapter_name is None:
+            adapter_name = get_adapter_name(self)
+
+        if adapter_name not in getattr(self, "peft_config", {}):
+            raise ValueError(f"Adapter name {adapter_name} not found in the model.")
+
+        lora_layers_to_save = get_peft_model_state_dict(
+            self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
+        )
+        if os.path.isfile(save_directory):
+            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        if safe_serialization:
+
+            def save_function(weights, filename):
+                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+        else:
+            save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if weight_name is None:
+            if safe_serialization:
+                weight_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weight_name = LORA_WEIGHT_NAME
+
+        # TODO: we could consider saving the `peft_config` as well.
+        save_path = Path(save_directory, weight_name).as_posix()
+        save_function(lora_layers_to_save, save_path)
+        logger.info(f"Model weights saved in {save_path}")
+
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
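The two new methods give model classes (not just pipelines) a direct LoRA workflow. A minimal usage sketch, assuming a Flux transformer and a LoRA whose state-dict keys carry the default `transformer.` prefix; the LoRA repo id and file name below are placeholders:

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# `prefix="transformer"` (the default) keeps only keys like `transformer.<module>...`.
transformer.load_lora_adapter(
    "some-user/some-flux-lora",      # placeholder repo id
    weight_name="lora.safetensors",  # placeholder file name
    adapter_name="my_adapter",
)

# Serializes only the named adapter's LoRA parameters, with safetensors by default.
transformer.save_lora_adapter("./my-lora-dir", adapter_name="my_adapter")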
diffusers/loaders/single_file_model.py
CHANGED
@@ -17,16 +17,23 @@ import re
 from contextlib import nullcontext
 from typing import Optional
 
+import torch
 from huggingface_hub.utils import validate_hf_hub_args
 
+from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, logging
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
+    convert_autoencoder_dc_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
+    convert_hunyuan_video_transformer_to_diffusers,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
+    convert_ltx_transformer_checkpoint_to_diffusers,
+    convert_ltx_vae_checkpoint_to_diffusers,
+    convert_mochi_transformer_checkpoint_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,
@@ -82,6 +89,23 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "LTXVideoTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
+    "AutoencoderKLLTXVideo": {
+        "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
+        "default_subfolder": "vae",
+    },
+    "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
+    "MochiTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
+    "HunyuanVideoTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
 
 
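With these mappings registered, the newly added video models can be instantiated from original-format checkpoints through `from_single_file`. A sketch, assuming a local Mochi transformer checkpoint (the file path is a placeholder):

import torch
from diffusers import MochiTransformer3DModel

# Placeholder path to an original-format Mochi DiT checkpoint.
transformer = MochiTransformer3DModel.from_single_file(
    "./mochi_dit.safetensors", torch_dtype=torch.bfloat16
)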
@@ -201,7 +225,10 @@ class FromOriginalModelMixin:
         local_files_only = kwargs.pop("local_files_only", None)
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
+        config_revision = kwargs.pop("config_revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
+        quantization_config = kwargs.pop("quantization_config", None)
+        device = kwargs.pop("device", None)
 
         if isinstance(pretrained_model_link_or_path_or_dict, dict):
             checkpoint = pretrained_model_link_or_path_or_dict
@@ -215,11 +242,17 @@
                 local_files_only=local_files_only,
                 revision=revision,
             )
+        if quantization_config is not None:
+            hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
+            hf_quantizer.validate_environment()
+
+        else:
+            hf_quantizer = None
 
         mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
 
         checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
-        if original_config:
+        if original_config is not None:
             if "config_mapping_fn" in mapping_functions:
                 config_mapping_fn = mapping_functions["config_mapping_fn"]
             else:
@@ -243,7 +276,7 @@
                 original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
             )
         else:
-            if config:
+            if config is not None:
                 if isinstance(config, str):
                     default_pretrained_model_config_name = config
                 else:
@@ -269,6 +302,8 @@
                 pretrained_model_name_or_path=default_pretrained_model_config_name,
                 subfolder=subfolder,
                 local_files_only=local_files_only,
+                token=token,
+                revision=config_revision,
             )
         expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
 
@@ -295,8 +330,36 @@
         with ctx():
             model = cls.from_config(diffusers_model_config)
 
+        # Check if `_keep_in_fp32_modules` is not None
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        )
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+            if not isinstance(keep_in_fp32_modules, list):
+                keep_in_fp32_modules = [keep_in_fp32_modules]
+
+        else:
+            keep_in_fp32_modules = []
+
+        if hf_quantizer is not None:
+            hf_quantizer.preprocess_model(
+                model=model,
+                device_map=None,
+                state_dict=diffusers_format_checkpoint,
+                keep_in_fp32_modules=keep_in_fp32_modules,
+            )
+
         if is_accelerate_available():
-            unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+            param_device = torch.device(device) if device else torch.device("cpu")
+            unexpected_keys = load_model_dict_into_meta(
+                model,
+                diffusers_format_checkpoint,
+                dtype=torch_dtype,
+                device=param_device,
+                hf_quantizer=hf_quantizer,
+                keep_in_fp32_modules=keep_in_fp32_modules,
+            )
 
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -310,7 +373,11 @@
                 f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
             )
 
-        if torch_dtype is not None:
+        if hf_quantizer is not None:
+            hf_quantizer.postprocess_model(model)
+            model.hf_quantizer = hf_quantizer
+
+        if torch_dtype is not None and hf_quantizer is None:
             model.to(torch_dtype)
 
         model.eval()
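Together with the new `diffusers/quantizers/gguf` and `diffusers/quantizers/torchao` modules listed above, the `quantization_config` branch lets `from_single_file` quantize weights while loading instead of afterwards. A sketch, assuming a GGUF-quantized Flux transformer checkpoint (the file path is a placeholder):

import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

transformer = FluxTransformer2DModel.from_single_file(
    "./flux1-dev-Q4_0.gguf",  # placeholder path to a GGUF checkpoint
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)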