diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/lora.py CHANGED
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import inspect
 import os
 from pathlib import Path
@@ -25,7 +26,7 @@ from packaging import version
 from torch import nn
 
 from .. import __version__
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
@@ -36,6 +37,7 @@ from ..utils import (
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
+    is_peft_version,
     is_transformers_available,
     logging,
     recurse_remove_peft_layers,
@@ -113,7 +115,7 @@ class LoraLoaderMixin:
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
        if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
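Note on the relaxed check above: DoRA checkpoints carry an extra `dora_scale` tensor per module, and that key does not contain the substring "lora", so the old all-keys-contain-"lora" test would reject them. A minimal sketch of the predicate on hypothetical key names:

    import torch

    # Hypothetical illustrative keys; the `dora_scale` entry appears only in DoRA checkpoints
    # and contains no "lora" substring, hence the added `or` clause.
    state_dict = {
        "unet.mid_block.attentions.0.to_q.lora_A.weight": torch.zeros(4, 320),
        "unet.mid_block.attentions.0.to_q.lora_B.weight": torch.zeros(320, 4),
        "unet.mid_block.attentions.0.to_q.dora_scale": torch.ones(320, 1),
    }
    assert all("lora" in key or "dora_scale" in key for key in state_dict)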
@@ -174,9 +176,9 @@ class LoraLoaderMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -206,7 +208,7 @@ class LoraLoaderMixin:
         # UNet and text encoder or both.
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
+        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
@@ -281,7 +283,7 @@ class LoraLoaderMixin:
                 subfolder=subfolder,
                 user_agent=user_agent,
             )
-            state_dict = torch.load(model_file, map_location="cpu")
+            state_dict = load_state_dict(model_file)
         else:
             state_dict = pretrained_model_name_or_path_or_dict
 
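The switch from a bare `torch.load` to the `load_state_dict` helper means locally referenced LoRA files load whether they are pickled or safetensors. A rough sketch of the dispatch the helper performs (simplified; the real implementation in `diffusers.models.modeling_utils` handles more cases):

    import torch
    from safetensors.torch import load_file

    def load_state_dict_sketch(checkpoint_file: str):
        # Simplified: choose the deserializer from the file extension.
        if checkpoint_file.endswith(".safetensors"):
            return load_file(checkpoint_file, device="cpu")
        return torch.load(checkpoint_file, map_location="cpu")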
@@ -361,13 +363,17 @@ class LoraLoaderMixin:
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
 
-        if _pipeline is not None:
+        if _pipeline is not None and _pipeline.hf_device_map is None:
             for _, component in _pipeline.components.items():
                 if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                     if not is_model_cpu_offload:
                         is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                     if not is_sequential_cpu_offload:
-                        is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
+                        is_sequential_cpu_offload = (
+                            isinstance(component._hf_hook, AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )
 
             logger.info(
                 "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
@@ -451,6 +457,15 @@ class LoraLoaderMixin:
                 rank[key] = val.shape[1]
 
         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+        if "use_dora" in lora_config_kwargs:
+            if lora_config_kwargs["use_dora"]:
+                if is_peft_version("<", "0.9.0"):
+                    raise ValueError(
+                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                    )
+            else:
+                if is_peft_version("<", "0.9.0"):
+                    lora_config_kwargs.pop("use_dora")
         lora_config = LoraConfig(**lora_config_kwargs)
 
         # adapter_name
@@ -572,6 +587,15 @@ class LoraLoaderMixin:
         }
 
         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+        if "use_dora" in lora_config_kwargs:
+            if lora_config_kwargs["use_dora"]:
+                if is_peft_version("<", "0.9.0"):
+                    raise ValueError(
+                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                    )
+            else:
+                if is_peft_version("<", "0.9.0"):
+                    lora_config_kwargs.pop("use_dora")
         lora_config = LoraConfig(**lora_config_kwargs)
 
         # adapter_name
@@ -654,6 +678,13 @@ class LoraLoaderMixin:
                 rank[key] = val.shape[1]
 
         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
+        if "use_dora" in lora_config_kwargs:
+            if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+                raise ValueError(
+                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                )
+            else:
+                lora_config_kwargs.pop("use_dora")
         lora_config = LoraConfig(**lora_config_kwargs)
 
         # adapter_name
@@ -959,7 +990,7 @@ class LoraLoaderMixin:
         self,
         adapter_names: Union[List[str], str],
         text_encoder: Optional["PreTrainedModel"] = None,  # noqa: F821
-        text_encoder_weights: List[float] = None,
+        text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
     ):
         """
         Sets the adapter layers for the text encoder.
@@ -977,15 +1008,20 @@ class LoraLoaderMixin:
             raise ValueError("PEFT backend is required for this method.")
 
         def process_weights(adapter_names, weights):
-            if weights is None:
-                weights = [1.0] * len(adapter_names)
-            elif isinstance(weights, float):
-                weights = [weights]
+            # Expand weights into a list, one entry per adapter
+            # e.g. for 2 adapters:  7 -> [7,7] ; [3, None] -> [3, None]
+            if not isinstance(weights, list):
+                weights = [weights] * len(adapter_names)
 
             if len(adapter_names) != len(weights):
                 raise ValueError(
                     f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
                 )
+
+            # Set None values to default of 1.0
+            # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
+            weights = [w if w is not None else 1.0 for w in weights]
+
             return weights
 
         adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
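The rewritten helper accepts a scalar, a list, or a list containing `None`; `None` becomes the default scale of 1.0 only after the length check. A standalone restatement with the expected outcomes:

    def process_weights(adapter_names, weights):
        # Expand a scalar (or None) into one entry per adapter.
        if not isinstance(weights, list):
            weights = [weights] * len(adapter_names)
        if len(adapter_names) != len(weights):
            raise ValueError("adapter_names and weights must have equal length")
        # None means "use the default scale of 1.0".
        return [w if w is not None else 1.0 for w in weights]

    assert process_weights(["a", "b"], 0.5) == [0.5, 0.5]
    assert process_weights(["a", "b"], [0.3, None]) == [0.3, 1.0]
    assert process_weights(["a", "b"], None) == [1.0, 1.0]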
@@ -1033,17 +1069,77 @@ class LoraLoaderMixin:
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
-        adapter_weights: Optional[List[float]] = None,
+        adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+        adapter_weights = copy.deepcopy(adapter_weights)
+
+        # Expand weights into a list, one entry per adapter
+        if not isinstance(adapter_weights, list):
+            adapter_weights = [adapter_weights] * len(adapter_names)
+
+        if len(adapter_names) != len(adapter_weights):
+            raise ValueError(
+                f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
+            )
+
+        # Decompose weights into weights for unet, text_encoder and text_encoder_2
+        unet_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], []
+
+        list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
+        all_adapters = {
+            adapter for adapters in list_adapters.values() for adapter in adapters
+        }  # eg ["adapter1", "adapter2"]
+        invert_list_adapters = {
+            adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
+            for adapter in all_adapters
+        }  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
+
+        for adapter_name, weights in zip(adapter_names, adapter_weights):
+            if isinstance(weights, dict):
+                unet_lora_weight = weights.pop("unet", None)
+                text_encoder_lora_weight = weights.pop("text_encoder", None)
+                text_encoder_2_lora_weight = weights.pop("text_encoder_2", None)
+
+                if len(weights) > 0:
+                    raise ValueError(
+                        f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}."
+                    )
+
+                if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"):
+                    logger.warning(
+                        "Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
+                    )
+
+                # warn if adapter doesn't have parts specified by adapter_weights
+                for part_weight, part_name in zip(
+                    [unet_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight],
+                    ["unet", "text_encoder", "text_encoder_2"],
+                ):
+                    if part_weight is not None and part_name not in invert_list_adapters[adapter_name]:
+                        logger.warning(
+                            f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
+                        )
+
+            else:
+                unet_lora_weight = weights
+                text_encoder_lora_weight = weights
+                text_encoder_2_lora_weight = weights
+
+            unet_lora_weights.append(unet_lora_weight)
+            text_encoder_lora_weights.append(text_encoder_lora_weight)
+            text_encoder_2_lora_weights.append(text_encoder_2_lora_weight)
+
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         # Handle the UNET
-        unet.set_adapters(adapter_names, adapter_weights)
+        unet.set_adapters(adapter_names, unet_lora_weights)
 
         # Handle the Text Encoder
         if hasattr(self, "text_encoder"):
-            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights)
         if hasattr(self, "text_encoder_2"):
-            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights)
 
     def disable_lora(self):
         if not USE_PEFT_BACKEND:
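This is the user-facing payoff of the hunk: `set_adapters` now takes per-component weights, where a dict scales `unet`, `text_encoder`, and `text_encoder_2` independently and a plain float applies to all components. A hypothetical call (the adapter names and the pipeline object are assumptions, not part of this diff):

    # `pipe` is assumed to be a pipeline on which two LoRAs were previously loaded
    # via load_lora_weights(..., adapter_name="style") and (..., adapter_name="detail").
    pipe.set_adapters(
        ["style", "detail"],
        adapter_weights=[
            {"unet": 0.8, "text_encoder": 0.4},  # per-component scales for "style"
            0.6,  # one scale for every component of "detail"
        ],
    )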
@@ -1175,6 +1271,11 @@ class LoraLoaderMixin:
                 for adapter_name in adapter_names:
                     unet_module.lora_A[adapter_name].to(device)
                     unet_module.lora_B[adapter_name].to(device)
+                    # this is a param, not a module, so device placement is not in-place -> re-assign
+                    if hasattr(unet_module, "lora_magnitude_vector") and unet_module.lora_magnitude_vector is not None:
+                        unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[
+                            adapter_name
+                        ].to(device)
 
         # Handle the text encoder
         modules_to_process = []
@@ -1191,6 +1292,14 @@ class LoraLoaderMixin:
                 for adapter_name in adapter_names:
                     text_encoder_module.lora_A[adapter_name].to(device)
                     text_encoder_module.lora_B[adapter_name].to(device)
+                    # this is a param, not a module, so device placement is not in-place -> re-assign
+                    if (
+                        hasattr(text_encoder_module, "lora_magnitude_vector")
+                        and text_encoder_module.lora_magnitude_vector is not None
+                    ):
+                        text_encoder_module.lora_magnitude_vector[
+                            adapter_name
+                        ] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)
 
 
 class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
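The re-assignment in both hunks above is needed because `lora_magnitude_vector` holds `nn.Parameter`s rather than submodules: `Module.to()` converts registered parameters in place, while `.to()` on a bare parameter only returns a converted copy. A self-contained illustration of the asymmetry (using dtype so it runs anywhere):

    import torch
    import torch.nn as nn

    linear = nn.Linear(2, 2)
    linear.to(torch.float64)  # modules convert their parameters in place
    assert linear.weight.dtype == torch.float64

    p = nn.Parameter(torch.zeros(2))
    p.to(torch.float64)  # no effect on p: returns a new tensor instead
    assert p.dtype == torch.float32
    p = p.to(torch.float64)  # hence the explicit re-assignment in the diff
    assert p.dtype == torch.float64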
@@ -1243,7 +1352,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
             unet_config=self.unet.config,
             **kwargs,
         )
-        is_correct_format = all("lora" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
@@ -1297,6 +1406,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
             text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                 encoder LoRA state dict because it comes from 🤗 Transformers.
+            text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
+                encoder LoRA state dict because it comes from 🤗 Transformers.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -1323,8 +1435,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         if unet_lora_layers:
             state_dict.update(pack_weights(unet_lora_layers, "unet"))
 
-        if text_encoder_lora_layers and text_encoder_2_lora_layers:
+        if text_encoder_lora_layers:
             state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+
+        if text_encoder_2_lora_layers:
             state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
 
         cls.write_lora_layers(
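Before this fix, the `and` on the old line 1323 meant that passing only one of the two text-encoder LoRA state dicts silently saved neither. A small sketch of the corrected packing behavior, with `pack_weights` restated for illustration (it namespaces each key under the component prefix):

    import torch

    def pack_weights(layers, prefix):
        # Restated for illustration: prefix every key with the component name.
        weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
        return {f"{prefix}.{k}": v for k, v in weights.items()}

    state_dict = {}
    text_encoder_lora_layers = {"q_proj.lora_A.weight": torch.zeros(4, 768)}
    text_encoder_2_lora_layers = None  # absent; the first encoder's layers are still saved

    if text_encoder_lora_layers:
        state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
    if text_encoder_2_lora_layers:
        state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

    assert list(state_dict) == ["text_encoder.q_proj.lora_A.weight"]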
diffusers/loaders/lora_conversion_utils.py CHANGED
@@ -14,7 +14,7 @@
 
 import re
 
-from ..utils import logging
+from ..utils import is_peft_version, logging
 
 
 logger = logging.get_logger(__name__)
@@ -128,6 +128,15 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
     te_state_dict = {}
     te2_state_dict = {}
     network_alphas = {}
+    is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
+    is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
+    is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
+
+    if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
+        if is_peft_version("<", "0.9.0"):
+            raise ValueError(
+                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+            )
 
     # every down weight has a corresponding up weight and potentially an alpha weight
     lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
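The three flags are computed on the raw Kohya key names before any renaming happens. On a hypothetical checkpoint where only the UNet adapters are DoRA:

    # Hypothetical Kohya-format keys (values elided to integers for brevity).
    state_dict = {
        "lora_unet_mid_block_attn1_to_q.lora_down.weight": 0,
        "lora_unet_mid_block_attn1_to_q.dora_scale": 0,
        "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight": 0,
    }
    is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
    is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
    is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
    assert (is_unet_dora_lora, is_te_dora_lora, is_te2_dora_lora) == (True, False, False)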
@@ -198,46 +207,19 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
             unet_state_dict[diffusers_name] = state_dict.pop(key)
             unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
-        elif lora_name.startswith("lora_te_"):
-            diffusers_name = key.replace("lora_te_", "").replace("_", ".")
-            diffusers_name = diffusers_name.replace("text.model", "text_model")
-            diffusers_name = diffusers_name.replace("self.attn", "self_attn")
-            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-            if "self_attn" in diffusers_name:
-                te_state_dict[diffusers_name] = state_dict.pop(key)
-                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "mlp" in diffusers_name:
-                # Be aware that this is the new diffusers convention and the rest of the code might
-                # not utilize it yet.
-                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                te_state_dict[diffusers_name] = state_dict.pop(key)
-                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+            if is_unet_dora_lora:
+                dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
+                unet_state_dict[
+                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
+                ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
 
-        # (sayakpaul): Duplicate code. Needs to be cleaned.
-        elif lora_name.startswith("lora_te1_"):
-            diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
-            diffusers_name = diffusers_name.replace("text.model", "text_model")
-            diffusers_name = diffusers_name.replace("self.attn", "self_attn")
-            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-            if "self_attn" in diffusers_name:
-                te_state_dict[diffusers_name] = state_dict.pop(key)
-                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "mlp" in diffusers_name:
-                # Be aware that this is the new diffusers convention and the rest of the code might
-                # not utilize it yet.
-                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                te_state_dict[diffusers_name] = state_dict.pop(key)
-                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+        elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            if lora_name.startswith(("lora_te_", "lora_te1_")):
+                key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
+            else:
+                key_to_replace = "lora_te2_"
 
-        # (sayakpaul): Duplicate code. Needs to be cleaned.
-        elif lora_name.startswith("lora_te2_"):
-            diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
+            diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
             diffusers_name = diffusers_name.replace("text.model", "text_model")
             diffusers_name = diffusers_name.replace("self.attn", "self_attn")
             diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
@@ -245,14 +227,35 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
             diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
             diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
             if "self_attn" in diffusers_name:
-                te2_state_dict[diffusers_name] = state_dict.pop(key)
-                te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    te_state_dict[diffusers_name] = state_dict.pop(key)
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                else:
+                    te2_state_dict[diffusers_name] = state_dict.pop(key)
+                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
             elif "mlp" in diffusers_name:
                 # Be aware that this is the new diffusers convention and the rest of the code might
                 # not utilize it yet.
                 diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                te2_state_dict[diffusers_name] = state_dict.pop(key)
-                te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    te_state_dict[diffusers_name] = state_dict.pop(key)
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                else:
+                    te2_state_dict[diffusers_name] = state_dict.pop(key)
+                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+            if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+                dora_scale_key_to_replace_te = (
+                    "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
+                )
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    te_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                elif lora_name.startswith("lora_te2_"):
+                    te2_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
 
         # Rename the alphas so that they can be mapped appropriately.
         if lora_name_alpha in state_dict:
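Net effect of the DoRA branches in this hunk: each Kohya `dora_scale` tensor is re-keyed next to its converted LoRA pair under a `.lora_magnitude_vector.` name, which is where peft >= 0.9.0 looks for DoRA magnitudes. An illustrative rename on a hypothetical converted name:

    # Hypothetical converted diffusers name for one attention projection.
    diffusers_name = "mid_block.attentions.0.to_q_lora.down.weight"
    dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
    renamed = diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
    assert renamed == "mid_block.attentions.0.to_q.lora_magnitude_vector.weight"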
diffusers/loaders/peft.py CHANGED
@@ -20,7 +20,8 @@ from ..utils import MIN_PEFT_VERSION, check_peft_version, is_peft_available
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
-    more details about adapters and injecting them in a transformer-based model, check out the PEFT [documentation](https://huggingface.co/docs/peft/index).
+    more details about adapters and injecting them in a transformer-based model, check out the PEFT
+    [documentation](https://huggingface.co/docs/peft/index).
 
     Install the latest version of PEFT, and use this mixin to:
 
@@ -143,8 +144,8 @@ class PeftAdapterMixin:
 
     def enable_adapters(self) -> None:
         """
-        Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the
-        list of adapters to enable.
+        Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
+        adapters to enable.
 
         If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
         [documentation](https://huggingface.co/docs/peft).