diffusers 0.24.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +11 -1
- diffusers/commands/fp16_safetensors.py +10 -11
- diffusers/configuration_utils.py +12 -8
- diffusers/dependency_versions_table.py +2 -1
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +286 -46
- diffusers/loaders/ip_adapter.py +11 -9
- diffusers/loaders/lora.py +198 -60
- diffusers/loaders/single_file.py +24 -18
- diffusers/loaders/textual_inversion.py +10 -14
- diffusers/loaders/unet.py +130 -37
- diffusers/models/__init__.py +18 -12
- diffusers/models/activations.py +9 -6
- diffusers/models/attention.py +137 -16
- diffusers/models/attention_processor.py +133 -46
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +4 -4
- diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +45 -6
- diffusers/models/{autoencoder_kl_temporal_decoder.py → autoencoders/autoencoder_kl_temporal_decoder.py} +8 -8
- diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +4 -4
- diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +14 -14
- diffusers/models/{vae.py → autoencoders/vae.py} +9 -5
- diffusers/models/downsampling.py +338 -0
- diffusers/models/embeddings.py +112 -29
- diffusers/models/modeling_flax_utils.py +12 -7
- diffusers/models/modeling_utils.py +10 -10
- diffusers/models/normalization.py +108 -2
- diffusers/models/resnet.py +15 -699
- diffusers/models/transformer_2d.py +2 -2
- diffusers/models/unet_2d_condition.py +37 -0
- diffusers/models/{unet_kandi3.py → unet_kandinsky3.py} +105 -159
- diffusers/models/upsampling.py +454 -0
- diffusers/models/uvit_2d.py +471 -0
- diffusers/models/vq_model.py +9 -2
- diffusers/pipelines/__init__.py +81 -73
- diffusers/pipelines/amused/__init__.py +62 -0
- diffusers/pipelines/amused/pipeline_amused.py +328 -0
- diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +38 -10
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/controlnet/pipeline_controlnet.py +27 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +47 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +25 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +4 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +26 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +4 -3
- diffusers/pipelines/deprecated/__init__.py +153 -0
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +91 -18
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +91 -18
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
- diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
- diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
- diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
- diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +4 -4
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
- diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +16 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +11 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +16 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +10 -10
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +13 -13
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +54 -11
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +6 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +6 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +6 -6
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/kandinsky3/__init__.py +4 -4
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
- diffusers/pipelines/kandinsky3/{kandinsky3_pipeline.py → pipeline_kandinsky3.py} +172 -35
- diffusers/pipelines/kandinsky3/{kandinsky3img2img_pipeline.py → pipeline_kandinsky3_img2img.py} +228 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +46 -5
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +47 -6
- diffusers/pipelines/onnx_utils.py +8 -5
- diffusers/pipelines/pipeline_flax_utils.py +7 -6
- diffusers/pipelines/pipeline_utils.py +30 -29
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +51 -2
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/__init__.py +1 -72
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +67 -75
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +92 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +138 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +57 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +6 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +5 -2
- diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +2 -3
- diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +2 -2
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +6 -1
- diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +50 -7
- diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +56 -8
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
- diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +67 -10
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +97 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +97 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +7 -5
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +12 -9
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +6 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +5 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +331 -9
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +468 -9
- diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +4 -0
- diffusers/schedulers/__init__.py +2 -0
- diffusers/schedulers/scheduling_amused.py +162 -0
- diffusers/schedulers/scheduling_consistency_models.py +2 -0
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -4
- diffusers/schedulers/scheduling_ddpm.py +46 -0
- diffusers/schedulers/scheduling_ddpm_parallel.py +46 -0
- diffusers/schedulers/scheduling_deis_multistep.py +13 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +13 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +13 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -1
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -0
- diffusers/schedulers/scheduling_euler_discrete.py +62 -3
- diffusers/schedulers/scheduling_heun_discrete.py +2 -0
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -0
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -0
- diffusers/schedulers/scheduling_lms_discrete.py +2 -0
- diffusers/schedulers/scheduling_unipc_multistep.py +13 -1
- diffusers/schedulers/scheduling_utils.py +3 -1
- diffusers/schedulers/scheduling_utils_flax.py +3 -1
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +0 -2
- diffusers/utils/constants.py +2 -5
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
- diffusers/utils/dynamic_modules_utils.py +14 -18
- diffusers/utils/hub_utils.py +24 -36
- diffusers/utils/logging.py +1 -1
- diffusers/utils/state_dict_utils.py +8 -0
- diffusers/utils/testing_utils.py +199 -1
- diffusers/utils/torch_utils.py +3 -3
- {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/METADATA +54 -53
- {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/RECORD +174 -155
- {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
- {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
- /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
- {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
- {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/ip_adapter.py
CHANGED
@@ -15,11 +15,10 @@ import os
|
|
15
15
|
from typing import Dict, Union
|
16
16
|
|
17
17
|
import torch
|
18
|
+
from huggingface_hub.utils import validate_hf_hub_args
|
18
19
|
from safetensors import safe_open
|
19
20
|
|
20
21
|
from ..utils import (
|
21
|
-
DIFFUSERS_CACHE,
|
22
|
-
HF_HUB_OFFLINE,
|
23
22
|
_get_model_file,
|
24
23
|
is_transformers_available,
|
25
24
|
logging,
|
@@ -43,6 +42,7 @@ logger = logging.get_logger(__name__)
|
|
43
42
|
class IPAdapterMixin:
|
44
43
|
"""Mixin for handling IP Adapters."""
|
45
44
|
|
45
|
+
@validate_hf_hub_args
|
46
46
|
def load_ip_adapter(
|
47
47
|
self,
|
48
48
|
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
@@ -77,7 +77,7 @@ class IPAdapterMixin:
|
|
77
77
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
78
78
|
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
79
79
|
won't be downloaded from the Hub.
|
80
|
-
|
80
|
+
token (`str` or *bool*, *optional*):
|
81
81
|
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
82
82
|
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
83
83
|
revision (`str`, *optional*, defaults to `"main"`):
|
@@ -88,12 +88,12 @@ class IPAdapterMixin:
|
|
88
88
|
"""
|
89
89
|
|
90
90
|
# Load the main state dict first.
|
91
|
-
cache_dir = kwargs.pop("cache_dir",
|
91
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
92
92
|
force_download = kwargs.pop("force_download", False)
|
93
93
|
resume_download = kwargs.pop("resume_download", False)
|
94
94
|
proxies = kwargs.pop("proxies", None)
|
95
|
-
local_files_only = kwargs.pop("local_files_only",
|
96
|
-
|
95
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
96
|
+
token = kwargs.pop("token", None)
|
97
97
|
revision = kwargs.pop("revision", None)
|
98
98
|
|
99
99
|
user_agent = {
|
@@ -110,7 +110,7 @@ class IPAdapterMixin:
|
|
110
110
|
resume_download=resume_download,
|
111
111
|
proxies=proxies,
|
112
112
|
local_files_only=local_files_only,
|
113
|
-
|
113
|
+
token=token,
|
114
114
|
revision=revision,
|
115
115
|
subfolder=subfolder,
|
116
116
|
user_agent=user_agent,
|
@@ -149,9 +149,11 @@ class IPAdapterMixin:
|
|
149
149
|
self.feature_extractor = CLIPImageProcessor()
|
150
150
|
|
151
151
|
# load ip-adapter into unet
|
152
|
-
self.unet.
|
152
|
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
153
|
+
unet._load_ip_adapter_weights(state_dict)
|
153
154
|
|
154
155
|
def set_ip_adapter_scale(self, scale):
|
155
|
-
|
156
|
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
157
|
+
for attn_processor in unet.attn_processors.values():
|
156
158
|
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
157
159
|
attn_processor.scale = scale
|
diffusers/loaders/lora.py
CHANGED
@@ -11,6 +11,7 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
+
import inspect
|
14
15
|
import os
|
15
16
|
from contextlib import nullcontext
|
16
17
|
from typing import Callable, Dict, List, Optional, Union
|
@@ -18,14 +19,14 @@ from typing import Callable, Dict, List, Optional, Union
|
|
18
19
|
import safetensors
|
19
20
|
import torch
|
20
21
|
from huggingface_hub import model_info
|
22
|
+
from huggingface_hub.constants import HF_HUB_OFFLINE
|
23
|
+
from huggingface_hub.utils import validate_hf_hub_args
|
21
24
|
from packaging import version
|
22
25
|
from torch import nn
|
23
26
|
|
24
27
|
from .. import __version__
|
25
28
|
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
26
29
|
from ..utils import (
|
27
|
-
DIFFUSERS_CACHE,
|
28
|
-
HF_HUB_OFFLINE,
|
29
30
|
USE_PEFT_BACKEND,
|
30
31
|
_get_model_file,
|
31
32
|
convert_state_dict_to_diffusers,
|
@@ -59,6 +60,7 @@ logger = logging.get_logger(__name__)
|
|
59
60
|
|
60
61
|
TEXT_ENCODER_NAME = "text_encoder"
|
61
62
|
UNET_NAME = "unet"
|
63
|
+
TRANSFORMER_NAME = "transformer"
|
62
64
|
|
63
65
|
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
64
66
|
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
@@ -74,6 +76,7 @@ class LoraLoaderMixin:
|
|
74
76
|
|
75
77
|
text_encoder_name = TEXT_ENCODER_NAME
|
76
78
|
unet_name = UNET_NAME
|
79
|
+
transformer_name = TRANSFORMER_NAME
|
77
80
|
num_fused_loras = 0
|
78
81
|
|
79
82
|
def load_lora_weights(
|
@@ -132,6 +135,7 @@ class LoraLoaderMixin:
|
|
132
135
|
)
|
133
136
|
|
134
137
|
@classmethod
|
138
|
+
@validate_hf_hub_args
|
135
139
|
def lora_state_dict(
|
136
140
|
cls,
|
137
141
|
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
@@ -174,7 +178,7 @@ class LoraLoaderMixin:
|
|
174
178
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
175
179
|
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
176
180
|
won't be downloaded from the Hub.
|
177
|
-
|
181
|
+
token (`str` or *bool*, *optional*):
|
178
182
|
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
179
183
|
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
180
184
|
revision (`str`, *optional*, defaults to `"main"`):
|
@@ -195,12 +199,12 @@ class LoraLoaderMixin:
|
|
195
199
|
"""
|
196
200
|
# Load the main state dict first which has the LoRA layers for either of
|
197
201
|
# UNet and text encoder or both.
|
198
|
-
cache_dir = kwargs.pop("cache_dir",
|
202
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
199
203
|
force_download = kwargs.pop("force_download", False)
|
200
204
|
resume_download = kwargs.pop("resume_download", False)
|
201
205
|
proxies = kwargs.pop("proxies", None)
|
202
|
-
local_files_only = kwargs.pop("local_files_only",
|
203
|
-
|
206
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
207
|
+
token = kwargs.pop("token", None)
|
204
208
|
revision = kwargs.pop("revision", None)
|
205
209
|
subfolder = kwargs.pop("subfolder", None)
|
206
210
|
weight_name = kwargs.pop("weight_name", None)
|
@@ -229,7 +233,9 @@ class LoraLoaderMixin:
|
|
229
233
|
# determine `weight_name`.
|
230
234
|
if weight_name is None:
|
231
235
|
weight_name = cls._best_guess_weight_name(
|
232
|
-
pretrained_model_name_or_path_or_dict,
|
236
|
+
pretrained_model_name_or_path_or_dict,
|
237
|
+
file_extension=".safetensors",
|
238
|
+
local_files_only=local_files_only,
|
233
239
|
)
|
234
240
|
model_file = _get_model_file(
|
235
241
|
pretrained_model_name_or_path_or_dict,
|
@@ -239,7 +245,7 @@ class LoraLoaderMixin:
|
|
239
245
|
resume_download=resume_download,
|
240
246
|
proxies=proxies,
|
241
247
|
local_files_only=local_files_only,
|
242
|
-
|
248
|
+
token=token,
|
243
249
|
revision=revision,
|
244
250
|
subfolder=subfolder,
|
245
251
|
user_agent=user_agent,
|
@@ -255,7 +261,7 @@ class LoraLoaderMixin:
|
|
255
261
|
if model_file is None:
|
256
262
|
if weight_name is None:
|
257
263
|
weight_name = cls._best_guess_weight_name(
|
258
|
-
pretrained_model_name_or_path_or_dict, file_extension=".bin"
|
264
|
+
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
259
265
|
)
|
260
266
|
model_file = _get_model_file(
|
261
267
|
pretrained_model_name_or_path_or_dict,
|
@@ -265,7 +271,7 @@ class LoraLoaderMixin:
|
|
265
271
|
resume_download=resume_download,
|
266
272
|
proxies=proxies,
|
267
273
|
local_files_only=local_files_only,
|
268
|
-
|
274
|
+
token=token,
|
269
275
|
revision=revision,
|
270
276
|
subfolder=subfolder,
|
271
277
|
user_agent=user_agent,
|
@@ -294,7 +300,12 @@ class LoraLoaderMixin:
|
|
294
300
|
return state_dict, network_alphas
|
295
301
|
|
296
302
|
@classmethod
|
297
|
-
def _best_guess_weight_name(
|
303
|
+
def _best_guess_weight_name(
|
304
|
+
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
305
|
+
):
|
306
|
+
if local_files_only or HF_HUB_OFFLINE:
|
307
|
+
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
308
|
+
|
298
309
|
targeted_files = []
|
299
310
|
|
300
311
|
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
@@ -391,6 +402,10 @@ class LoraLoaderMixin:
|
|
391
402
|
# their prefixes.
|
392
403
|
keys = list(state_dict.keys())
|
393
404
|
|
405
|
+
if all(key.startswith("unet.unet") for key in keys):
|
406
|
+
deprecation_message = "Keys starting with 'unet.unet' are deprecated."
|
407
|
+
deprecate("unet.unet keys", "0.27", deprecation_message)
|
408
|
+
|
394
409
|
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
|
395
410
|
# Load the layers corresponding to UNet.
|
396
411
|
logger.info(f"Loading {cls.unet_name}.")
|
@@ -407,8 +422,9 @@ class LoraLoaderMixin:
|
|
407
422
|
else:
|
408
423
|
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
|
409
424
|
# contain the module names of the `unet` as its keys WITHOUT any prefix.
|
410
|
-
|
411
|
-
|
425
|
+
if not USE_PEFT_BACKEND:
|
426
|
+
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
|
427
|
+
logger.warn(warn_message)
|
412
428
|
|
413
429
|
if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
|
414
430
|
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
@@ -648,6 +664,89 @@ class LoraLoaderMixin:
|
|
648
664
|
_pipeline.enable_sequential_cpu_offload()
|
649
665
|
# Unsafe code />
|
650
666
|
|
667
|
+
@classmethod
|
668
|
+
def load_lora_into_transformer(
|
669
|
+
cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
|
670
|
+
):
|
671
|
+
"""
|
672
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
673
|
+
|
674
|
+
Parameters:
|
675
|
+
state_dict (`dict`):
|
676
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
677
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
678
|
+
encoder lora layers.
|
679
|
+
network_alphas (`Dict[str, float]`):
|
680
|
+
See `LoRALinearLayer` for more details.
|
681
|
+
unet (`UNet2DConditionModel`):
|
682
|
+
The UNet model to load the LoRA layers into.
|
683
|
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
684
|
+
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
685
|
+
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
686
|
+
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
687
|
+
argument to `True` will raise an error.
|
688
|
+
adapter_name (`str`, *optional*):
|
689
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
690
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
691
|
+
"""
|
692
|
+
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
693
|
+
|
694
|
+
keys = list(state_dict.keys())
|
695
|
+
|
696
|
+
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
|
697
|
+
state_dict = {
|
698
|
+
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
|
699
|
+
}
|
700
|
+
|
701
|
+
if network_alphas is not None:
|
702
|
+
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
|
703
|
+
network_alphas = {
|
704
|
+
k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
705
|
+
}
|
706
|
+
|
707
|
+
if len(state_dict.keys()) > 0:
|
708
|
+
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
709
|
+
|
710
|
+
if adapter_name in getattr(transformer, "peft_config", {}):
|
711
|
+
raise ValueError(
|
712
|
+
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
713
|
+
)
|
714
|
+
|
715
|
+
rank = {}
|
716
|
+
for key, val in state_dict.items():
|
717
|
+
if "lora_B" in key:
|
718
|
+
rank[key] = val.shape[1]
|
719
|
+
|
720
|
+
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
|
721
|
+
lora_config = LoraConfig(**lora_config_kwargs)
|
722
|
+
|
723
|
+
# adapter_name
|
724
|
+
if adapter_name is None:
|
725
|
+
adapter_name = get_adapter_name(transformer)
|
726
|
+
|
727
|
+
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
728
|
+
# otherwise loading LoRA weights will lead to an error
|
729
|
+
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
730
|
+
|
731
|
+
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
|
732
|
+
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
|
733
|
+
|
734
|
+
if incompatible_keys is not None:
|
735
|
+
# check only for unexpected keys
|
736
|
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
737
|
+
if unexpected_keys:
|
738
|
+
logger.warning(
|
739
|
+
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
740
|
+
f" {unexpected_keys}. "
|
741
|
+
)
|
742
|
+
|
743
|
+
# Offload back.
|
744
|
+
if is_model_cpu_offload:
|
745
|
+
_pipeline.enable_model_cpu_offload()
|
746
|
+
elif is_sequential_cpu_offload:
|
747
|
+
_pipeline.enable_sequential_cpu_offload()
|
748
|
+
# Unsafe code />
|
749
|
+
|
651
750
|
@property
|
652
751
|
def lora_scale(self) -> float:
|
653
752
|
# property function that returns the lora scale which can be set at run time by the pipeline.
|
@@ -675,8 +774,7 @@ class LoraLoaderMixin:
|
|
675
774
|
|
676
775
|
@classmethod
|
677
776
|
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
|
678
|
-
|
679
|
-
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.25", LORA_DEPRECATION_MESSAGE)
|
777
|
+
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
|
680
778
|
|
681
779
|
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
682
780
|
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
@@ -704,8 +802,7 @@ class LoraLoaderMixin:
|
|
704
802
|
r"""
|
705
803
|
Monkey-patches the forward passes of attention modules of the text encoder.
|
706
804
|
"""
|
707
|
-
|
708
|
-
deprecate("_modify_text_encoder", "0.25", LORA_DEPRECATION_MESSAGE)
|
805
|
+
deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
|
709
806
|
|
710
807
|
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
|
711
808
|
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
|
@@ -775,6 +872,7 @@ class LoraLoaderMixin:
|
|
775
872
|
save_directory: Union[str, os.PathLike],
|
776
873
|
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
777
874
|
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
875
|
+
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
778
876
|
is_main_process: bool = True,
|
779
877
|
weight_name: str = None,
|
780
878
|
save_function: Callable = None,
|
@@ -802,29 +900,26 @@ class LoraLoaderMixin:
|
|
802
900
|
safe_serialization (`bool`, *optional*, defaults to `True`):
|
803
901
|
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
804
902
|
"""
|
805
|
-
# Create a flat dictionary.
|
806
903
|
state_dict = {}
|
807
904
|
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
905
|
+
def pack_weights(layers, prefix):
|
906
|
+
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
907
|
+
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
908
|
+
return layers_state_dict
|
909
|
+
|
910
|
+
if not (unet_lora_layers or text_encoder_lora_layers or transformer_lora_layers):
|
911
|
+
raise ValueError(
|
912
|
+
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `transformer_lora_layers`."
|
812
913
|
)
|
813
914
|
|
814
|
-
|
815
|
-
state_dict.update(
|
915
|
+
if unet_lora_layers:
|
916
|
+
state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
|
816
917
|
|
817
|
-
if text_encoder_lora_layers
|
818
|
-
|
819
|
-
text_encoder_lora_layers.state_dict()
|
820
|
-
if isinstance(text_encoder_lora_layers, torch.nn.Module)
|
821
|
-
else text_encoder_lora_layers
|
822
|
-
)
|
918
|
+
if text_encoder_lora_layers:
|
919
|
+
state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
|
823
920
|
|
824
|
-
|
825
|
-
|
826
|
-
}
|
827
|
-
state_dict.update(text_encoder_lora_state_dict)
|
921
|
+
if transformer_lora_layers:
|
922
|
+
state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
|
828
923
|
|
829
924
|
# Save the model
|
830
925
|
cls.write_lora_layers(
|
@@ -881,6 +976,8 @@ class LoraLoaderMixin:
|
|
881
976
|
>>> ...
|
882
977
|
```
|
883
978
|
"""
|
979
|
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
980
|
+
|
884
981
|
if not USE_PEFT_BACKEND:
|
885
982
|
if version.parse(__version__) > version.parse("0.23"):
|
886
983
|
logger.warn(
|
@@ -888,13 +985,13 @@ class LoraLoaderMixin:
|
|
888
985
|
"you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
|
889
986
|
)
|
890
987
|
|
891
|
-
for _, module in
|
988
|
+
for _, module in unet.named_modules():
|
892
989
|
if hasattr(module, "set_lora_layer"):
|
893
990
|
module.set_lora_layer(None)
|
894
991
|
else:
|
895
|
-
recurse_remove_peft_layers(
|
896
|
-
if hasattr(
|
897
|
-
del
|
992
|
+
recurse_remove_peft_layers(unet)
|
993
|
+
if hasattr(unet, "peft_config"):
|
994
|
+
del unet.peft_config
|
898
995
|
|
899
996
|
# Safe to call the following regardless of LoRA.
|
900
997
|
self._remove_text_encoder_monkey_patch()
|
@@ -905,6 +1002,7 @@ class LoraLoaderMixin:
|
|
905
1002
|
fuse_text_encoder: bool = True,
|
906
1003
|
lora_scale: float = 1.0,
|
907
1004
|
safe_fusing: bool = False,
|
1005
|
+
adapter_names: Optional[List[str]] = None,
|
908
1006
|
):
|
909
1007
|
r"""
|
910
1008
|
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
@@ -924,6 +1022,21 @@ class LoraLoaderMixin:
|
|
924
1022
|
Controls how much to influence the outputs with the LoRA parameters.
|
925
1023
|
safe_fusing (`bool`, defaults to `False`):
|
926
1024
|
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
1025
|
+
adapter_names (`List[str]`, *optional*):
|
1026
|
+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
1027
|
+
|
1028
|
+
Example:
|
1029
|
+
|
1030
|
+
```py
|
1031
|
+
from diffusers import DiffusionPipeline
|
1032
|
+
import torch
|
1033
|
+
|
1034
|
+
pipeline = DiffusionPipeline.from_pretrained(
|
1035
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
1036
|
+
).to("cuda")
|
1037
|
+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
1038
|
+
pipeline.fuse_lora(lora_scale=0.7)
|
1039
|
+
```
|
927
1040
|
"""
|
928
1041
|
if fuse_unet or fuse_text_encoder:
|
929
1042
|
self.num_fused_loras += 1
|
@@ -933,25 +1046,44 @@ class LoraLoaderMixin:
|
|
933
1046
|
)
|
934
1047
|
|
935
1048
|
if fuse_unet:
|
936
|
-
self.
|
1049
|
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
1050
|
+
unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
937
1051
|
|
938
1052
|
if USE_PEFT_BACKEND:
|
939
1053
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
940
1054
|
|
941
|
-
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
|
942
|
-
|
1055
|
+
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
1056
|
+
merge_kwargs = {"safe_merge": safe_fusing}
|
1057
|
+
|
943
1058
|
for module in text_encoder.modules():
|
944
1059
|
if isinstance(module, BaseTunerLayer):
|
945
1060
|
if lora_scale != 1.0:
|
946
1061
|
module.scale_layer(lora_scale)
|
947
1062
|
|
948
|
-
|
1063
|
+
# For BC with previous PEFT versions, we need to check the signature
|
1064
|
+
# of the `merge` method to see if it supports the `adapter_names` argument.
|
1065
|
+
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
1066
|
+
if "adapter_names" in supported_merge_kwargs:
|
1067
|
+
merge_kwargs["adapter_names"] = adapter_names
|
1068
|
+
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
1069
|
+
raise ValueError(
|
1070
|
+
"The `adapter_names` argument is not supported with your PEFT version. "
|
1071
|
+
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
1072
|
+
)
|
1073
|
+
|
1074
|
+
module.merge(**merge_kwargs)
|
949
1075
|
|
950
1076
|
else:
|
951
|
-
|
952
|
-
|
1077
|
+
deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
|
1078
|
+
|
1079
|
+
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
|
1080
|
+
if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
|
1081
|
+
raise ValueError(
|
1082
|
+
"The `adapter_names` argument is not supported in your environment. Please switch to PEFT "
|
1083
|
+
"backend to use this argument by installing latest PEFT and transformers."
|
1084
|
+
" `pip install -U peft transformers`"
|
1085
|
+
)
|
953
1086
|
|
954
|
-
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
|
955
1087
|
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
956
1088
|
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
957
1089
|
attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
|
@@ -966,9 +1098,9 @@ class LoraLoaderMixin:
|
|
966
1098
|
|
967
1099
|
if fuse_text_encoder:
|
968
1100
|
if hasattr(self, "text_encoder"):
|
969
|
-
fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
|
1101
|
+
fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing, adapter_names=adapter_names)
|
970
1102
|
if hasattr(self, "text_encoder_2"):
|
971
|
-
fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
|
1103
|
+
fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing, adapter_names=adapter_names)
|
972
1104
|
|
973
1105
|
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
|
974
1106
|
r"""
|
@@ -987,13 +1119,14 @@ class LoraLoaderMixin:
|
|
987
1119
|
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
988
1120
|
LoRA parameters then it won't have any effect.
|
989
1121
|
"""
|
1122
|
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
990
1123
|
if unfuse_unet:
|
991
1124
|
if not USE_PEFT_BACKEND:
|
992
|
-
|
1125
|
+
unet.unfuse_lora()
|
993
1126
|
else:
|
994
1127
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
995
1128
|
|
996
|
-
for module in
|
1129
|
+
for module in unet.modules():
|
997
1130
|
if isinstance(module, BaseTunerLayer):
|
998
1131
|
module.unmerge()
|
999
1132
|
|
@@ -1006,8 +1139,7 @@ class LoraLoaderMixin:
|
|
1006
1139
|
module.unmerge()
|
1007
1140
|
|
1008
1141
|
else:
|
1009
|
-
|
1010
|
-
deprecate("unfuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
|
1142
|
+
deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
|
1011
1143
|
|
1012
1144
|
def unfuse_text_encoder_lora(text_encoder):
|
1013
1145
|
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
@@ -1110,8 +1242,9 @@ class LoraLoaderMixin:
|
|
1110
1242
|
adapter_names: Union[List[str], str],
|
1111
1243
|
adapter_weights: Optional[List[float]] = None,
|
1112
1244
|
):
|
1245
|
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
1113
1246
|
# Handle the UNET
|
1114
|
-
|
1247
|
+
unet.set_adapters(adapter_names, adapter_weights)
|
1115
1248
|
|
1116
1249
|
# Handle the Text Encoder
|
1117
1250
|
if hasattr(self, "text_encoder"):
|
@@ -1124,7 +1257,8 @@ class LoraLoaderMixin:
|
|
1124
1257
|
raise ValueError("PEFT backend is required for this method.")
|
1125
1258
|
|
1126
1259
|
# Disable unet adapters
|
1127
|
-
self.unet.
|
1260
|
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
1261
|
+
unet.disable_lora()
|
1128
1262
|
|
1129
1263
|
# Disable text encoder adapters
|
1130
1264
|
if hasattr(self, "text_encoder"):
|
@@ -1137,7 +1271,8 @@ class LoraLoaderMixin:
|
|
1137
1271
|
raise ValueError("PEFT backend is required for this method.")
|
1138
1272
|
|
1139
1273
|
# Enable unet adapters
|
1140
|
-
self.unet.
|
1274
|
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
1275
|
+
unet.enable_lora()
|
1141
1276
|
|
1142
1277
|
# Enable text encoder adapters
|
1143
1278
|
if hasattr(self, "text_encoder"):
|
@@ -1159,7 +1294,8 @@ class LoraLoaderMixin:
|
|
1159
1294
|
adapter_names = [adapter_names]
|
1160
1295
|
|
1161
1296
|
# Delete unet adapters
|
1162
|
-
self.unet.
|
1297
|
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
1298
|
+
unet.delete_adapters(adapter_names)
|
1163
1299
|
|
1164
1300
|
for adapter_name in adapter_names:
|
1165
1301
|
# Delete text encoder adapters
|
@@ -1192,8 +1328,8 @@ class LoraLoaderMixin:
|
|
1192
1328
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
1193
1329
|
|
1194
1330
|
active_adapters = []
|
1195
|
-
|
1196
|
-
for module in
|
1331
|
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
1332
|
+
for module in unet.modules():
|
1197
1333
|
if isinstance(module, BaseTunerLayer):
|
1198
1334
|
active_adapters = module.active_adapters
|
1199
1335
|
break
|
@@ -1217,8 +1353,9 @@ class LoraLoaderMixin:
|
|
1217
1353
|
if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
|
1218
1354
|
set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
|
1219
1355
|
|
1220
|
-
if hasattr(self, "unet")
|
1221
|
-
|
1356
|
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
1357
|
+
if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"):
|
1358
|
+
set_adapters[self.unet_name] = list(self.unet.peft_config.keys())
|
1222
1359
|
|
1223
1360
|
return set_adapters
|
1224
1361
|
|
@@ -1239,7 +1376,8 @@ class LoraLoaderMixin:
|
|
1239
1376
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
1240
1377
|
|
1241
1378
|
# Handle the UNET
|
1242
|
-
|
1379
|
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
1380
|
+
for unet_module in unet.modules():
|
1243
1381
|
if isinstance(unet_module, BaseTunerLayer):
|
1244
1382
|
for adapter_name in adapter_names:
|
1245
1383
|
unet_module.lora_A[adapter_name].to(device)
|