diffusers 0.24.0__py3-none-any.whl → 0.25.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (174) hide show
  1. diffusers/__init__.py +11 -1
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +12 -8
  4. diffusers/dependency_versions_table.py +3 -2
  5. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  6. diffusers/image_processor.py +286 -46
  7. diffusers/loaders/ip_adapter.py +11 -9
  8. diffusers/loaders/lora.py +198 -60
  9. diffusers/loaders/single_file.py +24 -18
  10. diffusers/loaders/textual_inversion.py +10 -14
  11. diffusers/loaders/unet.py +130 -37
  12. diffusers/models/__init__.py +18 -12
  13. diffusers/models/activations.py +9 -6
  14. diffusers/models/attention.py +137 -16
  15. diffusers/models/attention_processor.py +133 -46
  16. diffusers/models/autoencoders/__init__.py +5 -0
  17. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +4 -4
  18. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +45 -6
  19. diffusers/models/{autoencoder_kl_temporal_decoder.py → autoencoders/autoencoder_kl_temporal_decoder.py} +8 -8
  20. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +4 -4
  21. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +14 -14
  22. diffusers/models/{vae.py → autoencoders/vae.py} +9 -5
  23. diffusers/models/downsampling.py +338 -0
  24. diffusers/models/embeddings.py +112 -29
  25. diffusers/models/modeling_flax_utils.py +12 -7
  26. diffusers/models/modeling_utils.py +10 -10
  27. diffusers/models/normalization.py +108 -2
  28. diffusers/models/resnet.py +15 -699
  29. diffusers/models/transformer_2d.py +2 -2
  30. diffusers/models/unet_2d_condition.py +37 -0
  31. diffusers/models/{unet_kandi3.py → unet_kandinsky3.py} +105 -159
  32. diffusers/models/upsampling.py +454 -0
  33. diffusers/models/uvit_2d.py +471 -0
  34. diffusers/models/vq_model.py +9 -2
  35. diffusers/pipelines/__init__.py +81 -73
  36. diffusers/pipelines/amused/__init__.py +62 -0
  37. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  38. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  39. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +38 -10
  41. diffusers/pipelines/auto_pipeline.py +17 -13
  42. diffusers/pipelines/controlnet/pipeline_controlnet.py +27 -10
  43. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +47 -5
  44. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +25 -8
  45. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +4 -6
  46. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +26 -10
  47. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +4 -3
  48. diffusers/pipelines/deprecated/__init__.py +153 -0
  49. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  50. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +91 -18
  51. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +91 -18
  52. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  53. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  54. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  55. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  56. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  57. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  58. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  59. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  60. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  61. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  62. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  63. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +4 -4
  64. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  65. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  66. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  67. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  68. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +7 -7
  69. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  70. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +16 -11
  71. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +6 -6
  72. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +11 -11
  73. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +16 -11
  74. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +10 -10
  75. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +13 -13
  76. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  77. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  78. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  79. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +54 -11
  80. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  81. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +6 -6
  82. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +6 -6
  83. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +6 -6
  84. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  85. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  86. diffusers/pipelines/kandinsky3/__init__.py +4 -4
  87. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  88. diffusers/pipelines/kandinsky3/{kandinsky3_pipeline.py → pipeline_kandinsky3.py} +172 -35
  89. diffusers/pipelines/kandinsky3/{kandinsky3img2img_pipeline.py → pipeline_kandinsky3_img2img.py} +228 -34
  90. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +46 -5
  91. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +47 -6
  92. diffusers/pipelines/onnx_utils.py +8 -5
  93. diffusers/pipelines/pipeline_flax_utils.py +7 -6
  94. diffusers/pipelines/pipeline_utils.py +32 -31
  95. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +51 -2
  96. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  97. diffusers/pipelines/stable_diffusion/__init__.py +1 -72
  98. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +67 -75
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +92 -8
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -8
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +138 -10
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +57 -7
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +6 -0
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  107. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  108. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +5 -2
  109. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  110. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +2 -3
  111. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  112. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +2 -2
  113. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +3 -3
  114. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  115. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +6 -1
  116. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  117. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +50 -7
  118. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  119. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +56 -8
  120. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  121. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  122. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +67 -10
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +97 -15
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +97 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +7 -5
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +12 -9
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +6 -0
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -0
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +5 -0
  131. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +331 -9
  132. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +468 -9
  133. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  134. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  135. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  136. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +4 -0
  137. diffusers/schedulers/__init__.py +2 -0
  138. diffusers/schedulers/scheduling_amused.py +162 -0
  139. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  140. diffusers/schedulers/scheduling_ddim_inverse.py +1 -4
  141. diffusers/schedulers/scheduling_ddpm.py +46 -0
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +46 -0
  143. diffusers/schedulers/scheduling_deis_multistep.py +13 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +13 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +13 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -0
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -1
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -0
  149. diffusers/schedulers/scheduling_euler_discrete.py +62 -3
  150. diffusers/schedulers/scheduling_heun_discrete.py +2 -0
  151. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -0
  152. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -0
  153. diffusers/schedulers/scheduling_lms_discrete.py +2 -0
  154. diffusers/schedulers/scheduling_unipc_multistep.py +13 -1
  155. diffusers/schedulers/scheduling_utils.py +3 -1
  156. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  157. diffusers/training_utils.py +1 -1
  158. diffusers/utils/__init__.py +0 -2
  159. diffusers/utils/constants.py +2 -5
  160. diffusers/utils/dummy_pt_objects.py +30 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
  162. diffusers/utils/dynamic_modules_utils.py +14 -18
  163. diffusers/utils/hub_utils.py +24 -36
  164. diffusers/utils/logging.py +1 -1
  165. diffusers/utils/state_dict_utils.py +8 -0
  166. diffusers/utils/testing_utils.py +199 -1
  167. diffusers/utils/torch_utils.py +3 -3
  168. {diffusers-0.24.0.dist-info → diffusers-0.25.1.dist-info}/METADATA +55 -53
  169. {diffusers-0.24.0.dist-info → diffusers-0.25.1.dist-info}/RECORD +174 -155
  170. {diffusers-0.24.0.dist-info → diffusers-0.25.1.dist-info}/WHEEL +1 -1
  171. {diffusers-0.24.0.dist-info → diffusers-0.25.1.dist-info}/entry_points.txt +0 -1
  172. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  173. {diffusers-0.24.0.dist-info → diffusers-0.25.1.dist-info}/LICENSE +0 -0
  174. {diffusers-0.24.0.dist-info → diffusers-0.25.1.dist-info}/top_level.txt +0 -0
@@ -15,11 +15,10 @@ import os
15
15
  from typing import Dict, Union
16
16
 
17
17
  import torch
18
+ from huggingface_hub.utils import validate_hf_hub_args
18
19
  from safetensors import safe_open
19
20
 
20
21
  from ..utils import (
21
- DIFFUSERS_CACHE,
22
- HF_HUB_OFFLINE,
23
22
  _get_model_file,
24
23
  is_transformers_available,
25
24
  logging,
@@ -43,6 +42,7 @@ logger = logging.get_logger(__name__)
43
42
  class IPAdapterMixin:
44
43
  """Mixin for handling IP Adapters."""
45
44
 
45
+ @validate_hf_hub_args
46
46
  def load_ip_adapter(
47
47
  self,
48
48
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -77,7 +77,7 @@ class IPAdapterMixin:
77
77
  local_files_only (`bool`, *optional*, defaults to `False`):
78
78
  Whether to only load local model weights and configuration files or not. If set to `True`, the model
79
79
  won't be downloaded from the Hub.
80
- use_auth_token (`str` or *bool*, *optional*):
80
+ token (`str` or *bool*, *optional*):
81
81
  The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
82
82
  `diffusers-cli login` (stored in `~/.huggingface`) is used.
83
83
  revision (`str`, *optional*, defaults to `"main"`):
@@ -88,12 +88,12 @@ class IPAdapterMixin:
88
88
  """
89
89
 
90
90
  # Load the main state dict first.
91
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
91
+ cache_dir = kwargs.pop("cache_dir", None)
92
92
  force_download = kwargs.pop("force_download", False)
93
93
  resume_download = kwargs.pop("resume_download", False)
94
94
  proxies = kwargs.pop("proxies", None)
95
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
96
- use_auth_token = kwargs.pop("use_auth_token", None)
95
+ local_files_only = kwargs.pop("local_files_only", None)
96
+ token = kwargs.pop("token", None)
97
97
  revision = kwargs.pop("revision", None)
98
98
 
99
99
  user_agent = {
@@ -110,7 +110,7 @@ class IPAdapterMixin:
110
110
  resume_download=resume_download,
111
111
  proxies=proxies,
112
112
  local_files_only=local_files_only,
113
- use_auth_token=use_auth_token,
113
+ token=token,
114
114
  revision=revision,
115
115
  subfolder=subfolder,
116
116
  user_agent=user_agent,
@@ -149,9 +149,11 @@ class IPAdapterMixin:
149
149
  self.feature_extractor = CLIPImageProcessor()
150
150
 
151
151
  # load ip-adapter into unet
152
- self.unet._load_ip_adapter_weights(state_dict)
152
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
153
+ unet._load_ip_adapter_weights(state_dict)
153
154
 
154
155
  def set_ip_adapter_scale(self, scale):
155
- for attn_processor in self.unet.attn_processors.values():
156
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
157
+ for attn_processor in unet.attn_processors.values():
156
158
  if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
157
159
  attn_processor.scale = scale
diffusers/loaders/lora.py CHANGED
@@ -11,6 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+ import inspect
14
15
  import os
15
16
  from contextlib import nullcontext
16
17
  from typing import Callable, Dict, List, Optional, Union
@@ -18,14 +19,14 @@ from typing import Callable, Dict, List, Optional, Union
18
19
  import safetensors
19
20
  import torch
20
21
  from huggingface_hub import model_info
22
+ from huggingface_hub.constants import HF_HUB_OFFLINE
23
+ from huggingface_hub.utils import validate_hf_hub_args
21
24
  from packaging import version
22
25
  from torch import nn
23
26
 
24
27
  from .. import __version__
25
28
  from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
26
29
  from ..utils import (
27
- DIFFUSERS_CACHE,
28
- HF_HUB_OFFLINE,
29
30
  USE_PEFT_BACKEND,
30
31
  _get_model_file,
31
32
  convert_state_dict_to_diffusers,
@@ -59,6 +60,7 @@ logger = logging.get_logger(__name__)
59
60
 
60
61
  TEXT_ENCODER_NAME = "text_encoder"
61
62
  UNET_NAME = "unet"
63
+ TRANSFORMER_NAME = "transformer"
62
64
 
63
65
  LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
64
66
  LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
@@ -74,6 +76,7 @@ class LoraLoaderMixin:
74
76
 
75
77
  text_encoder_name = TEXT_ENCODER_NAME
76
78
  unet_name = UNET_NAME
79
+ transformer_name = TRANSFORMER_NAME
77
80
  num_fused_loras = 0
78
81
 
79
82
  def load_lora_weights(
@@ -132,6 +135,7 @@ class LoraLoaderMixin:
132
135
  )
133
136
 
134
137
  @classmethod
138
+ @validate_hf_hub_args
135
139
  def lora_state_dict(
136
140
  cls,
137
141
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -174,7 +178,7 @@ class LoraLoaderMixin:
174
178
  local_files_only (`bool`, *optional*, defaults to `False`):
175
179
  Whether to only load local model weights and configuration files or not. If set to `True`, the model
176
180
  won't be downloaded from the Hub.
177
- use_auth_token (`str` or *bool*, *optional*):
181
+ token (`str` or *bool*, *optional*):
178
182
  The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
179
183
  `diffusers-cli login` (stored in `~/.huggingface`) is used.
180
184
  revision (`str`, *optional*, defaults to `"main"`):
@@ -195,12 +199,12 @@ class LoraLoaderMixin:
195
199
  """
196
200
  # Load the main state dict first which has the LoRA layers for either of
197
201
  # UNet and text encoder or both.
198
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
202
+ cache_dir = kwargs.pop("cache_dir", None)
199
203
  force_download = kwargs.pop("force_download", False)
200
204
  resume_download = kwargs.pop("resume_download", False)
201
205
  proxies = kwargs.pop("proxies", None)
202
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
203
- use_auth_token = kwargs.pop("use_auth_token", None)
206
+ local_files_only = kwargs.pop("local_files_only", None)
207
+ token = kwargs.pop("token", None)
204
208
  revision = kwargs.pop("revision", None)
205
209
  subfolder = kwargs.pop("subfolder", None)
206
210
  weight_name = kwargs.pop("weight_name", None)
@@ -229,7 +233,9 @@ class LoraLoaderMixin:
229
233
  # determine `weight_name`.
230
234
  if weight_name is None:
231
235
  weight_name = cls._best_guess_weight_name(
232
- pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
236
+ pretrained_model_name_or_path_or_dict,
237
+ file_extension=".safetensors",
238
+ local_files_only=local_files_only,
233
239
  )
234
240
  model_file = _get_model_file(
235
241
  pretrained_model_name_or_path_or_dict,
@@ -239,7 +245,7 @@ class LoraLoaderMixin:
239
245
  resume_download=resume_download,
240
246
  proxies=proxies,
241
247
  local_files_only=local_files_only,
242
- use_auth_token=use_auth_token,
248
+ token=token,
243
249
  revision=revision,
244
250
  subfolder=subfolder,
245
251
  user_agent=user_agent,
@@ -255,7 +261,7 @@ class LoraLoaderMixin:
255
261
  if model_file is None:
256
262
  if weight_name is None:
257
263
  weight_name = cls._best_guess_weight_name(
258
- pretrained_model_name_or_path_or_dict, file_extension=".bin"
264
+ pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
259
265
  )
260
266
  model_file = _get_model_file(
261
267
  pretrained_model_name_or_path_or_dict,
@@ -265,7 +271,7 @@ class LoraLoaderMixin:
265
271
  resume_download=resume_download,
266
272
  proxies=proxies,
267
273
  local_files_only=local_files_only,
268
- use_auth_token=use_auth_token,
274
+ token=token,
269
275
  revision=revision,
270
276
  subfolder=subfolder,
271
277
  user_agent=user_agent,
@@ -294,7 +300,12 @@ class LoraLoaderMixin:
294
300
  return state_dict, network_alphas
295
301
 
296
302
  @classmethod
297
- def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
303
+ def _best_guess_weight_name(
304
+ cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
305
+ ):
306
+ if local_files_only or HF_HUB_OFFLINE:
307
+ raise ValueError("When using the offline mode, you must specify a `weight_name`.")
308
+
298
309
  targeted_files = []
299
310
 
300
311
  if os.path.isfile(pretrained_model_name_or_path_or_dict):
@@ -391,6 +402,10 @@ class LoraLoaderMixin:
391
402
  # their prefixes.
392
403
  keys = list(state_dict.keys())
393
404
 
405
+ if all(key.startswith("unet.unet") for key in keys):
406
+ deprecation_message = "Keys starting with 'unet.unet' are deprecated."
407
+ deprecate("unet.unet keys", "0.27", deprecation_message)
408
+
394
409
  if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
395
410
  # Load the layers corresponding to UNet.
396
411
  logger.info(f"Loading {cls.unet_name}.")
@@ -407,8 +422,9 @@ class LoraLoaderMixin:
407
422
  else:
408
423
  # Otherwise, we're dealing with the old format. This means the `state_dict` should only
409
424
  # contain the module names of the `unet` as its keys WITHOUT any prefix.
410
- warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
411
- logger.warn(warn_message)
425
+ if not USE_PEFT_BACKEND:
426
+ warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
427
+ logger.warn(warn_message)
412
428
 
413
429
  if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
414
430
  from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
@@ -648,6 +664,89 @@ class LoraLoaderMixin:
648
664
  _pipeline.enable_sequential_cpu_offload()
649
665
  # Unsafe code />
650
666
 
667
+ @classmethod
668
+ def load_lora_into_transformer(
669
+ cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
670
+ ):
671
+ """
672
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
673
+
674
+ Parameters:
675
+ state_dict (`dict`):
676
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
677
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
678
+ encoder lora layers.
679
+ network_alphas (`Dict[str, float]`):
680
+ See `LoRALinearLayer` for more details.
681
+ unet (`UNet2DConditionModel`):
682
+ The UNet model to load the LoRA layers into.
683
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
684
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
685
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
686
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
687
+ argument to `True` will raise an error.
688
+ adapter_name (`str`, *optional*):
689
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
690
+ `default_{i}` where i is the total number of adapters being loaded.
691
+ """
692
+ low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
693
+
694
+ keys = list(state_dict.keys())
695
+
696
+ transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
697
+ state_dict = {
698
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
699
+ }
700
+
701
+ if network_alphas is not None:
702
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
703
+ network_alphas = {
704
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
705
+ }
706
+
707
+ if len(state_dict.keys()) > 0:
708
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
709
+
710
+ if adapter_name in getattr(transformer, "peft_config", {}):
711
+ raise ValueError(
712
+ f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
713
+ )
714
+
715
+ rank = {}
716
+ for key, val in state_dict.items():
717
+ if "lora_B" in key:
718
+ rank[key] = val.shape[1]
719
+
720
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
721
+ lora_config = LoraConfig(**lora_config_kwargs)
722
+
723
+ # adapter_name
724
+ if adapter_name is None:
725
+ adapter_name = get_adapter_name(transformer)
726
+
727
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
728
+ # otherwise loading LoRA weights will lead to an error
729
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
730
+
731
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
732
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
733
+
734
+ if incompatible_keys is not None:
735
+ # check only for unexpected keys
736
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
737
+ if unexpected_keys:
738
+ logger.warning(
739
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
740
+ f" {unexpected_keys}. "
741
+ )
742
+
743
+ # Offload back.
744
+ if is_model_cpu_offload:
745
+ _pipeline.enable_model_cpu_offload()
746
+ elif is_sequential_cpu_offload:
747
+ _pipeline.enable_sequential_cpu_offload()
748
+ # Unsafe code />
749
+
651
750
  @property
652
751
  def lora_scale(self) -> float:
653
752
  # property function that returns the lora scale which can be set at run time by the pipeline.
@@ -675,8 +774,7 @@ class LoraLoaderMixin:
675
774
 
676
775
  @classmethod
677
776
  def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
678
- if version.parse(__version__) > version.parse("0.23"):
679
- deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.25", LORA_DEPRECATION_MESSAGE)
777
+ deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
680
778
 
681
779
  for _, attn_module in text_encoder_attn_modules(text_encoder):
682
780
  if isinstance(attn_module.q_proj, PatchedLoraProjection):
@@ -704,8 +802,7 @@ class LoraLoaderMixin:
704
802
  r"""
705
803
  Monkey-patches the forward passes of attention modules of the text encoder.
706
804
  """
707
- if version.parse(__version__) > version.parse("0.23"):
708
- deprecate("_modify_text_encoder", "0.25", LORA_DEPRECATION_MESSAGE)
805
+ deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
709
806
 
710
807
  def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
711
808
  linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
@@ -775,6 +872,7 @@ class LoraLoaderMixin:
775
872
  save_directory: Union[str, os.PathLike],
776
873
  unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
777
874
  text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
875
+ transformer_lora_layers: Dict[str, torch.nn.Module] = None,
778
876
  is_main_process: bool = True,
779
877
  weight_name: str = None,
780
878
  save_function: Callable = None,
@@ -802,29 +900,26 @@ class LoraLoaderMixin:
802
900
  safe_serialization (`bool`, *optional*, defaults to `True`):
803
901
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
804
902
  """
805
- # Create a flat dictionary.
806
903
  state_dict = {}
807
904
 
808
- # Populate the dictionary.
809
- if unet_lora_layers is not None:
810
- weights = (
811
- unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
905
+ def pack_weights(layers, prefix):
906
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
907
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
908
+ return layers_state_dict
909
+
910
+ if not (unet_lora_layers or text_encoder_lora_layers or transformer_lora_layers):
911
+ raise ValueError(
912
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `transformer_lora_layers`."
812
913
  )
813
914
 
814
- unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
815
- state_dict.update(unet_lora_state_dict)
915
+ if unet_lora_layers:
916
+ state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
816
917
 
817
- if text_encoder_lora_layers is not None:
818
- weights = (
819
- text_encoder_lora_layers.state_dict()
820
- if isinstance(text_encoder_lora_layers, torch.nn.Module)
821
- else text_encoder_lora_layers
822
- )
918
+ if text_encoder_lora_layers:
919
+ state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
823
920
 
824
- text_encoder_lora_state_dict = {
825
- f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
826
- }
827
- state_dict.update(text_encoder_lora_state_dict)
921
+ if transformer_lora_layers:
922
+ state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
828
923
 
829
924
  # Save the model
830
925
  cls.write_lora_layers(
@@ -881,6 +976,8 @@ class LoraLoaderMixin:
881
976
  >>> ...
882
977
  ```
883
978
  """
979
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
980
+
884
981
  if not USE_PEFT_BACKEND:
885
982
  if version.parse(__version__) > version.parse("0.23"):
886
983
  logger.warn(
@@ -888,13 +985,13 @@ class LoraLoaderMixin:
888
985
  "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
889
986
  )
890
987
 
891
- for _, module in self.unet.named_modules():
988
+ for _, module in unet.named_modules():
892
989
  if hasattr(module, "set_lora_layer"):
893
990
  module.set_lora_layer(None)
894
991
  else:
895
- recurse_remove_peft_layers(self.unet)
896
- if hasattr(self.unet, "peft_config"):
897
- del self.unet.peft_config
992
+ recurse_remove_peft_layers(unet)
993
+ if hasattr(unet, "peft_config"):
994
+ del unet.peft_config
898
995
 
899
996
  # Safe to call the following regardless of LoRA.
900
997
  self._remove_text_encoder_monkey_patch()
@@ -905,6 +1002,7 @@ class LoraLoaderMixin:
905
1002
  fuse_text_encoder: bool = True,
906
1003
  lora_scale: float = 1.0,
907
1004
  safe_fusing: bool = False,
1005
+ adapter_names: Optional[List[str]] = None,
908
1006
  ):
909
1007
  r"""
910
1008
  Fuses the LoRA parameters into the original parameters of the corresponding blocks.
@@ -924,6 +1022,21 @@ class LoraLoaderMixin:
924
1022
  Controls how much to influence the outputs with the LoRA parameters.
925
1023
  safe_fusing (`bool`, defaults to `False`):
926
1024
  Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
1025
+ adapter_names (`List[str]`, *optional*):
1026
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
1027
+
1028
+ Example:
1029
+
1030
+ ```py
1031
+ from diffusers import DiffusionPipeline
1032
+ import torch
1033
+
1034
+ pipeline = DiffusionPipeline.from_pretrained(
1035
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
1036
+ ).to("cuda")
1037
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
1038
+ pipeline.fuse_lora(lora_scale=0.7)
1039
+ ```
927
1040
  """
928
1041
  if fuse_unet or fuse_text_encoder:
929
1042
  self.num_fused_loras += 1
@@ -933,25 +1046,44 @@ class LoraLoaderMixin:
933
1046
  )
934
1047
 
935
1048
  if fuse_unet:
936
- self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
1049
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
1050
+ unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
937
1051
 
938
1052
  if USE_PEFT_BACKEND:
939
1053
  from peft.tuners.tuners_utils import BaseTunerLayer
940
1054
 
941
- def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
942
- # TODO(Patrick, Younes): enable "safe" fusing
1055
+ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
1056
+ merge_kwargs = {"safe_merge": safe_fusing}
1057
+
943
1058
  for module in text_encoder.modules():
944
1059
  if isinstance(module, BaseTunerLayer):
945
1060
  if lora_scale != 1.0:
946
1061
  module.scale_layer(lora_scale)
947
1062
 
948
- module.merge()
1063
+ # For BC with previous PEFT versions, we need to check the signature
1064
+ # of the `merge` method to see if it supports the `adapter_names` argument.
1065
+ supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
1066
+ if "adapter_names" in supported_merge_kwargs:
1067
+ merge_kwargs["adapter_names"] = adapter_names
1068
+ elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
1069
+ raise ValueError(
1070
+ "The `adapter_names` argument is not supported with your PEFT version. "
1071
+ "Please upgrade to the latest version of PEFT. `pip install -U peft`"
1072
+ )
1073
+
1074
+ module.merge(**merge_kwargs)
949
1075
 
950
1076
  else:
951
- if version.parse(__version__) > version.parse("0.23"):
952
- deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
1077
+ deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
1078
+
1079
+ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
1080
+ if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
1081
+ raise ValueError(
1082
+ "The `adapter_names` argument is not supported in your environment. Please switch to PEFT "
1083
+ "backend to use this argument by installing latest PEFT and transformers."
1084
+ " `pip install -U peft transformers`"
1085
+ )
953
1086
 
954
- def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
955
1087
  for _, attn_module in text_encoder_attn_modules(text_encoder):
956
1088
  if isinstance(attn_module.q_proj, PatchedLoraProjection):
957
1089
  attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
@@ -966,9 +1098,9 @@ class LoraLoaderMixin:
966
1098
 
967
1099
  if fuse_text_encoder:
968
1100
  if hasattr(self, "text_encoder"):
969
- fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
1101
+ fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing, adapter_names=adapter_names)
970
1102
  if hasattr(self, "text_encoder_2"):
971
- fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
1103
+ fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing, adapter_names=adapter_names)
972
1104
 
973
1105
  def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
974
1106
  r"""
@@ -987,13 +1119,14 @@ class LoraLoaderMixin:
987
1119
  Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
988
1120
  LoRA parameters then it won't have any effect.
989
1121
  """
1122
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
990
1123
  if unfuse_unet:
991
1124
  if not USE_PEFT_BACKEND:
992
- self.unet.unfuse_lora()
1125
+ unet.unfuse_lora()
993
1126
  else:
994
1127
  from peft.tuners.tuners_utils import BaseTunerLayer
995
1128
 
996
- for module in self.unet.modules():
1129
+ for module in unet.modules():
997
1130
  if isinstance(module, BaseTunerLayer):
998
1131
  module.unmerge()
999
1132
 
@@ -1006,8 +1139,7 @@ class LoraLoaderMixin:
1006
1139
  module.unmerge()
1007
1140
 
1008
1141
  else:
1009
- if version.parse(__version__) > version.parse("0.23"):
1010
- deprecate("unfuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
1142
+ deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
1011
1143
 
1012
1144
  def unfuse_text_encoder_lora(text_encoder):
1013
1145
  for _, attn_module in text_encoder_attn_modules(text_encoder):
@@ -1110,8 +1242,9 @@ class LoraLoaderMixin:
1110
1242
  adapter_names: Union[List[str], str],
1111
1243
  adapter_weights: Optional[List[float]] = None,
1112
1244
  ):
1245
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
1113
1246
  # Handle the UNET
1114
- self.unet.set_adapters(adapter_names, adapter_weights)
1247
+ unet.set_adapters(adapter_names, adapter_weights)
1115
1248
 
1116
1249
  # Handle the Text Encoder
1117
1250
  if hasattr(self, "text_encoder"):
@@ -1124,7 +1257,8 @@ class LoraLoaderMixin:
1124
1257
  raise ValueError("PEFT backend is required for this method.")
1125
1258
 
1126
1259
  # Disable unet adapters
1127
- self.unet.disable_lora()
1260
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
1261
+ unet.disable_lora()
1128
1262
 
1129
1263
  # Disable text encoder adapters
1130
1264
  if hasattr(self, "text_encoder"):
@@ -1137,7 +1271,8 @@ class LoraLoaderMixin:
1137
1271
  raise ValueError("PEFT backend is required for this method.")
1138
1272
 
1139
1273
  # Enable unet adapters
1140
- self.unet.enable_lora()
1274
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
1275
+ unet.enable_lora()
1141
1276
 
1142
1277
  # Enable text encoder adapters
1143
1278
  if hasattr(self, "text_encoder"):
@@ -1159,7 +1294,8 @@ class LoraLoaderMixin:
1159
1294
  adapter_names = [adapter_names]
1160
1295
 
1161
1296
  # Delete unet adapters
1162
- self.unet.delete_adapters(adapter_names)
1297
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
1298
+ unet.delete_adapters(adapter_names)
1163
1299
 
1164
1300
  for adapter_name in adapter_names:
1165
1301
  # Delete text encoder adapters
@@ -1192,8 +1328,8 @@ class LoraLoaderMixin:
1192
1328
  from peft.tuners.tuners_utils import BaseTunerLayer
1193
1329
 
1194
1330
  active_adapters = []
1195
-
1196
- for module in self.unet.modules():
1331
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
1332
+ for module in unet.modules():
1197
1333
  if isinstance(module, BaseTunerLayer):
1198
1334
  active_adapters = module.active_adapters
1199
1335
  break
@@ -1217,8 +1353,9 @@ class LoraLoaderMixin:
1217
1353
  if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
1218
1354
  set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
1219
1355
 
1220
- if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
1221
- set_adapters["unet"] = list(self.unet.peft_config.keys())
1356
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
1357
+ if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"):
1358
+ set_adapters[self.unet_name] = list(self.unet.peft_config.keys())
1222
1359
 
1223
1360
  return set_adapters
1224
1361
 
@@ -1239,7 +1376,8 @@ class LoraLoaderMixin:
1239
1376
  from peft.tuners.tuners_utils import BaseTunerLayer
1240
1377
 
1241
1378
  # Handle the UNET
1242
- for unet_module in self.unet.modules():
1379
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
1380
+ for unet_module in unet.modules():
1243
1381
  if isinstance(unet_module, BaseTunerLayer):
1244
1382
  for adapter_name in adapter_names:
1245
1383
  unet_module.lora_A[adapter_name].to(device)