diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. diffusers/__init__.py +34 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +170 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +35 -6
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +2 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
  53. diffusers/pipelines/cogview3/__init__.py +47 -0
  54. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  55. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  56. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  57. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  58. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  60. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  62. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  63. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  64. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  66. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  67. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  68. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  70. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  71. diffusers/pipelines/flux/__init__.py +10 -0
  72. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  73. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  74. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  76. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  77. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  78. diffusers/pipelines/free_noise_utils.py +365 -5
  79. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  80. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  81. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  82. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  83. diffusers/pipelines/kolors/tokenizer.py +4 -0
  84. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  86. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  87. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  89. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  90. diffusers/pipelines/pag/__init__.py +6 -0
  91. diffusers/pipelines/pag/pag_utils.py +8 -2
  92. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  96. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  97. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  98. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  100. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  101. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  102. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  103. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  106. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  107. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  108. diffusers/pipelines/pipeline_utils.py +123 -180
  109. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  111. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  117. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  120. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  121. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  122. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  126. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  127. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  130. diffusers/quantizers/__init__.py +16 -0
  131. diffusers/quantizers/auto.py +126 -0
  132. diffusers/quantizers/base.py +233 -0
  133. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  134. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  135. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  136. diffusers/quantizers/quantization_config.py +391 -0
  137. diffusers/schedulers/scheduling_ddim.py +4 -1
  138. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  140. diffusers/schedulers/scheduling_ddpm.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  142. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  143. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  145. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  146. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  147. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  149. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  150. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  151. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  152. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  153. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  154. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  155. diffusers/schedulers/scheduling_sasolver.py +78 -1
  156. diffusers/schedulers/scheduling_unclip.py +4 -1
  157. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  158. diffusers/training_utils.py +48 -18
  159. diffusers/utils/__init__.py +2 -1
  160. diffusers/utils/dummy_pt_objects.py +60 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
  162. diffusers/utils/hub_utils.py +16 -4
  163. diffusers/utils/import_utils.py +31 -8
  164. diffusers/utils/loading_utils.py +28 -4
  165. diffusers/utils/peft_utils.py +3 -3
  166. diffusers/utils/testing_utils.py +59 -0
  167. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  168. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
  169. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  170. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
  171. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  172. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
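Most of the churn shown below is in `diffusers/loaders/lora_pipeline.py`, which threads a new `low_cpu_mem_usage` option through every LoRA loader and adds Flux LoRA format conversion. As a rough usage sketch (not taken from the diff itself; the base-model and LoRA repository ids here are placeholders), the new kwarg can be passed straight to `load_lora_weights`:

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # placeholder base model id
        torch_dtype=torch.float16,
    )

    # New in 0.31.0 (per the diff below): skip random initialization of the
    # injected LoRA layers and load the pretrained LoRA weights directly.
    # Requires peft>=0.13.1 and, for text-encoder LoRAs, transformers>4.45.2;
    # otherwise a ValueError is raised.
    pipe.load_lora_weights(
        "some-user/some-sd15-lora",  # placeholder LoRA repo id
        weight_name="pytorch_lora_weights.safetensors",
        adapter_name="example",
        low_cpu_mem_usage=True,
    )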
@@ -25,13 +25,32 @@ from ..utils import (
  deprecate,
  get_adapter_name,
  get_peft_kwargs,
+ is_peft_available,
  is_peft_version,
+ is_torch_version,
  is_transformers_available,
+ is_transformers_version,
  logging,
  scale_lora_layers,
  )
  from .lora_base import LoraBaseMixin
- from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
+ from .lora_conversion_utils import (
+ _convert_kohya_flux_lora_to_diffusers,
+ _convert_non_diffusers_lora_to_diffusers,
+ _convert_xlabs_flux_lora_to_diffusers,
+ _maybe_map_sgm_blocks_to_diffusers,
+ )
+
+
+ _LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
+ if is_torch_version(">=", "1.9.0"):
+ if (
+ is_peft_available()
+ and is_peft_version(">=", "0.13.1")
+ and is_transformers_available()
+ and is_transformers_version(">", "4.45.2")
+ ):
+ _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True


  if is_transformers_available():
@@ -78,15 +97,24 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  Parameters:
  pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  # if a dict is passed, copy it instead of modifying it inplace
  if isinstance(pretrained_model_name_or_path_or_dict, dict):
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -94,7 +122,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
  state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

- is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+ is_correct_format = all("lora" in key for key in state_dict.keys())
  if not is_correct_format:
  raise ValueError("Invalid LoRA checkpoint.")

@@ -104,6 +132,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )
  self.load_lora_into_text_encoder(
  state_dict,
@@ -114,6 +143,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  @classmethod
@@ -206,6 +236,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  user_agent=user_agent,
  allow_pickle=allow_pickle,
  )
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

  network_alphas = None
  # TODO: replace it with a method from `state_dict_utils`
@@ -227,7 +262,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  return state_dict, network_alphas

  @classmethod
- def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+ def load_lora_into_unet(
+ cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
  """
  This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -245,10 +282,16 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
  # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
  # their prefixes.
@@ -258,7 +301,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  # Load the layers corresponding to UNet.
  logger.info(f"Loading {cls.unet_name}.")
  unet.load_attn_procs(
- state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+ state_dict,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  @classmethod
@@ -271,6 +318,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  lora_scale=1.0,
  adapter_name=None,
  _pipeline=None,
+ low_cpu_mem_usage=False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -280,7 +328,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  A standard state dict containing the lora layer parameters. The key should be prefixed with an
  additional `text_encoder` to distinguish between unet lora layers.
  network_alphas (`Dict[str, float]`):
- See `LoRALinearLayer` for more details.
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
  text_encoder (`CLIPTextModel`):
  The text encoder model to load the LoRA layers into.
  prefix (`str`):
@@ -291,10 +341,25 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ peft_kwargs = {}
+ if low_cpu_mem_usage:
+ if not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+ if not is_transformers_version(">", "4.45.2"):
+ # Note from sayakpaul: It's not in `transformers` stable yet.
+ # https://github.com/huggingface/transformers/pull/33725/
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+ )
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
  from peft import LoraConfig

  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -365,6 +430,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  adapter_state_dict=text_encoder_lora_state_dict,
  peft_config=lora_config,
+ **peft_kwargs,
  )

  # scale LoRA layers with `lora_scale`
@@ -535,12 +601,19 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
  kwargs (`dict`, *optional*):
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  # We could have accessed the unet config from `lora_state_dict()` too. We pass
  # it here explicitly to be able to tell that it's coming from an SDXL
  # pipeline.
@@ -555,12 +628,18 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  unet_config=self.unet.config,
  **kwargs,
  )
- is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
  if not is_correct_format:
  raise ValueError("Invalid LoRA checkpoint.")

  self.load_lora_into_unet(
- state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+ state_dict,
+ network_alphas=network_alphas,
+ unet=self.unet,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )
  text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
  if len(text_encoder_state_dict) > 0:
@@ -572,6 +651,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -584,6 +664,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  @classmethod
@@ -677,6 +758,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  user_agent=user_agent,
  allow_pickle=allow_pickle,
  )
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

  network_alphas = None
  # TODO: replace it with a method from `state_dict_utils`
@@ -699,7 +785,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):

  @classmethod
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
- def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+ def load_lora_into_unet(
+ cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
  """
  This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -717,10 +805,16 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
  # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
  # their prefixes.
@@ -730,7 +824,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  # Load the layers corresponding to UNet.
  logger.info(f"Loading {cls.unet_name}.")
  unet.load_attn_procs(
- state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+ state_dict,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  @classmethod
@@ -744,6 +842,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  lora_scale=1.0,
  adapter_name=None,
  _pipeline=None,
+ low_cpu_mem_usage=False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -753,7 +852,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  A standard state dict containing the lora layer parameters. The key should be prefixed with an
  additional `text_encoder` to distinguish between unet lora layers.
  network_alphas (`Dict[str, float]`):
- See `LoRALinearLayer` for more details.
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
  text_encoder (`CLIPTextModel`):
  The text encoder model to load the LoRA layers into.
  prefix (`str`):
@@ -764,10 +865,25 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ peft_kwargs = {}
+ if low_cpu_mem_usage:
+ if not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+ if not is_transformers_version(">", "4.45.2"):
+ # Note from sayakpaul: It's not in `transformers` stable yet.
+ # https://github.com/huggingface/transformers/pull/33725/
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+ )
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
  from peft import LoraConfig

  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -838,6 +954,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  adapter_state_dict=text_encoder_lora_state_dict,
  peft_config=lora_config,
+ **peft_kwargs,
  )

  # scale LoRA layers with `lora_scale`
@@ -1080,6 +1197,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  allow_pickle=allow_pickle,
  )

+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
  return state_dict

  def load_lora_weights(
@@ -1100,15 +1223,22 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  Parameters:
  pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  # if a dict is passed, copy it instead of modifying it inplace
  if isinstance(pretrained_model_name_or_path_or_dict, dict):
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -1116,7 +1246,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
  state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

- is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+ is_correct_format = all("lora" in key for key in state_dict.keys())
  if not is_correct_format:
  raise ValueError("Invalid LoRA checkpoint.")

@@ -1125,6 +1255,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
@@ -1137,6 +1268,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -1149,10 +1281,13 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  @classmethod
- def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
  """
  This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -1166,7 +1301,13 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
  """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

  keys = list(state_dict.keys())
@@ -1210,17 +1351,37 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  # otherwise loading LoRA weights will lead to an error
  is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

- inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+ peft_kwargs = {}
+ if is_peft_version(">=", "0.13.1"):
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)

+ warn_msg = ""
  if incompatible_keys is not None:
- # check only for unexpected keys
+ # Check only for unexpected keys.
  unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
  if unexpected_keys:
- logger.warning(
- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
- f" {unexpected_keys}. "
- )
+ lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+ if lora_unexpected_keys:
+ warn_msg = (
+ f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+ f" {', '.join(lora_unexpected_keys)}. "
+ )
+
+ # Filter missing keys specific to the current adapter.
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
+ if missing_keys:
+ lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+ if lora_missing_keys:
+ warn_msg += (
+ f"Loading adapter weights from state_dict led to missing keys in the model:"
+ f" {', '.join(lora_missing_keys)}."
+ )
+
+ if warn_msg:
+ logger.warning(warn_msg)

  # Offload back.
  if is_model_cpu_offload:
@@ -1240,6 +1401,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  lora_scale=1.0,
  adapter_name=None,
  _pipeline=None,
+ low_cpu_mem_usage=False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1249,7 +1411,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  A standard state dict containing the lora layer parameters. The key should be prefixed with an
  additional `text_encoder` to distinguish between unet lora layers.
  network_alphas (`Dict[str, float]`):
- See `LoRALinearLayer` for more details.
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
  text_encoder (`CLIPTextModel`):
  The text encoder model to load the LoRA layers into.
  prefix (`str`):
@@ -1260,10 +1424,25 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ peft_kwargs = {}
+ if low_cpu_mem_usage:
+ if not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+ if not is_transformers_version(">", "4.45.2"):
+ # Note from sayakpaul: It's not in `transformers` stable yet.
+ # https://github.com/huggingface/transformers/pull/33725/
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+ )
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
  from peft import LoraConfig

  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1334,6 +1513,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  adapter_state_dict=text_encoder_lora_state_dict,
  peft_config=lora_config,
+ **peft_kwargs,
  )

  # scale LoRA layers with `lora_scale`
@@ -1576,6 +1756,24 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  user_agent=user_agent,
  allow_pickle=allow_pickle,
  )
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
+ is_kohya = any(".lora_down.weight" in k for k in state_dict)
+ if is_kohya:
+ state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
+ # Kohya already takes care of scaling the LoRA parameters with alpha.
+ return (state_dict, None) if return_alphas else state_dict
+
+ is_xlabs = any("processor" in k for k in state_dict)
+ if is_xlabs:
+ state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
+ # xlabs doesn't use `alpha`.
+ return (state_dict, None) if return_alphas else state_dict

  # For state dicts like
  # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
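The hunk above teaches `FluxLoraLoaderMixin.lora_state_dict` to recognize Kohya-format checkpoints (keys containing `.lora_down.weight`) and XLabs-format checkpoints (keys containing `processor`) and to convert them with `_convert_kohya_flux_lora_to_diffusers` / `_convert_xlabs_flux_lora_to_diffusers` before loading. A minimal sketch of what that enables, assuming the public FLUX.1-dev checkpoint and a hypothetical LoRA repository id:

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )

    # A Kohya- or XLabs-style Flux LoRA is detected by its key names and
    # converted to the diffusers layout on the fly; no manual conversion step.
    pipe.load_lora_weights("some-user/flux-style-lora")  # placeholder repo id

    image = pipe(
        "a watercolor painting of a lighthouse at dusk",
        num_inference_steps=28,
        guidance_scale=3.5,
    ).images[0]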
@@ -1621,10 +1819,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  # if a dict is passed, copy it instead of modifying it inplace
  if isinstance(pretrained_model_name_or_path_or_dict, dict):
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -1634,7 +1839,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
  )

- is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+ is_correct_format = all("lora" in key for key in state_dict.keys())
  if not is_correct_format:
  raise ValueError("Invalid LoRA checkpoint.")

@@ -1644,6 +1849,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
@@ -1656,10 +1862,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  @classmethod
- def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
+ def load_lora_into_transformer(
+ cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
  """
  This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -1677,7 +1886,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
  """
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

  keys = list(state_dict.keys())
@@ -1726,17 +1941,37 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  # otherwise loading LoRA weights will lead to an error
  is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

- inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+ peft_kwargs = {}
+ if is_peft_version(">=", "0.13.1"):
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)

+ warn_msg = ""
  if incompatible_keys is not None:
- # check only for unexpected keys
+ # Check only for unexpected keys.
  unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
  if unexpected_keys:
- logger.warning(
- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
- f" {unexpected_keys}. "
- )
+ lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+ if lora_unexpected_keys:
+ warn_msg = (
+ f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+ f" {', '.join(lora_unexpected_keys)}. "
+ )
+
+ # Filter missing keys specific to the current adapter.
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
+ if missing_keys:
+ lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+ if lora_missing_keys:
+ warn_msg += (
+ f"Loading adapter weights from state_dict led to missing keys in the model:"
+ f" {', '.join(lora_missing_keys)}."
+ )
+
+ if warn_msg:
+ logger.warning(warn_msg)

  # Offload back.
  if is_model_cpu_offload:
@@ -1756,6 +1991,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  lora_scale=1.0,
  adapter_name=None,
  _pipeline=None,
+ low_cpu_mem_usage=False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1765,7 +2001,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  A standard state dict containing the lora layer parameters. The key should be prefixed with an
  additional `text_encoder` to distinguish between unet lora layers.
  network_alphas (`Dict[str, float]`):
- See `LoRALinearLayer` for more details.
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
  text_encoder (`CLIPTextModel`):
  The text encoder model to load the LoRA layers into.
  prefix (`str`):
@@ -1776,10 +2014,25 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ peft_kwargs = {}
+ if low_cpu_mem_usage:
+ if not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+ if not is_transformers_version(">", "4.45.2"):
+ # Note from sayakpaul: It's not in `transformers` stable yet.
+ # https://github.com/huggingface/transformers/pull/33725/
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+ )
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
  from peft import LoraConfig

  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1850,6 +2103,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  adapter_state_dict=text_encoder_lora_state_dict,
  peft_config=lora_config,
+ **peft_kwargs,
  )

  # scale LoRA layers with `lora_scale`
@@ -1998,7 +2252,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
1998
2252
  into the unet or prefixed with an additional `unet` which can be used to distinguish between text
1999
2253
  encoder lora layers.
2000
2254
  network_alphas (`Dict[str, float]`):
2001
- See `LoRALinearLayer` for more details.
2255
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
2256
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
2257
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
2002
2258
  unet (`UNet2DConditionModel`):
2003
2259
  The UNet model to load the LoRA layers into.
2004
2260
  adapter_name (`str`, *optional*):
@@ -2055,14 +2311,30 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2055
2311
  inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
2056
2312
  incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
2057
2313
 
2314
+ warn_msg = ""
2058
2315
  if incompatible_keys is not None:
2059
- # check only for unexpected keys
2316
+ # Check only for unexpected keys.
2060
2317
  unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
2061
2318
  if unexpected_keys:
2062
- logger.warning(
2063
- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
2064
- f" {unexpected_keys}. "
2065
- )
2319
+ lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
2320
+ if lora_unexpected_keys:
2321
+ warn_msg = (
2322
+ f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
2323
+ f" {', '.join(lora_unexpected_keys)}. "
2324
+ )
2325
+
2326
+ # Filter missing keys specific to the current adapter.
2327
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
2328
+ if missing_keys:
2329
+ lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
2330
+ if lora_missing_keys:
2331
+ warn_msg += (
2332
+ f"Loading adapter weights from state_dict led to missing keys in the model:"
2333
+ f" {', '.join(lora_missing_keys)}."
2334
+ )
2335
+
2336
+ if warn_msg:
2337
+ logger.warning(warn_msg)
2066
2338
 
2067
2339
  # Offload back.
2068
2340
  if is_model_cpu_offload:
@@ -2082,6 +2354,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2082
2354
  lora_scale=1.0,
2083
2355
  adapter_name=None,
2084
2356
  _pipeline=None,
2357
+ low_cpu_mem_usage=False,
2085
2358
  ):
2086
2359
  """
2087
2360
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2091,7 +2364,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2091
2364
  A standard state dict containing the lora layer parameters. The key should be prefixed with an
2092
2365
  additional `text_encoder` to distinguish between unet lora layers.
2093
2366
  network_alphas (`Dict[str, float]`):
2094
- See `LoRALinearLayer` for more details.
2367
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
2368
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
2369
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
2095
2370
  text_encoder (`CLIPTextModel`):
2096
2371
  The text encoder model to load the LoRA layers into.
2097
2372
  prefix (`str`):
@@ -2102,10 +2377,25 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2102
2377
  adapter_name (`str`, *optional*):
2103
2378
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2104
2379
  `default_{i}` where i is the total number of adapters being loaded.
2380
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
2105
2381
  """
2106
2382
  if not USE_PEFT_BACKEND:
2107
2383
  raise ValueError("PEFT backend is required for this method.")
2108
2384
 
2385
+ peft_kwargs = {}
2386
+ if low_cpu_mem_usage:
2387
+ if not is_peft_version(">=", "0.13.1"):
2388
+ raise ValueError(
2389
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
2390
+ )
2391
+ if not is_transformers_version(">", "4.45.2"):
2392
+ # Note from sayakpaul: It's not in `transformers` stable yet.
2393
+ # https://github.com/huggingface/transformers/pull/33725/
2394
+ raise ValueError(
2395
+ "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
2396
+ )
2397
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
2398
+
2109
2399
  from peft import LoraConfig
2110
2400
 
2111
2401
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -2176,6 +2466,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2176
2466
  adapter_name=adapter_name,
2177
2467
  adapter_state_dict=text_encoder_lora_state_dict,
2178
2468
  peft_config=lora_config,
2469
+ **peft_kwargs,
2179
2470
  )
2180
2471
 
2181
2472
  # scale LoRA layers with `lora_scale`
@@ -2245,6 +2536,381 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
         )


+class CogVideoXLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoX`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = cls._fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        return state_dict
+
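The method above only fetches and sanitizes a checkpoint; nothing is loaded into the pipeline yet. A minimal sketch of inspecting a checkpoint through it, assuming `CogVideoXPipeline` exposes this mixin; the repository id and weight name are placeholders, not shipped artifacts:

```py
# Placeholder repo id and weight name for illustration only.
from diffusers import CogVideoXPipeline

state_dict = CogVideoXPipeline.lora_state_dict(
    "your-account/cogvideox-lora",
    weight_name="pytorch_lora_weights.safetensors",
)
# Keys are expected to be prefixed with "transformer." and to contain "lora_A"/"lora_B".
print(len(state_dict), next(iter(state_dict)))
```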
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
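`load_lora_weights` is the entry point most users will call. A sketch of the typical flow, assuming `peft >= 0.13.0` is installed; the LoRA repository id and weight name are placeholders:

```py
# Sketch only: "your-account/cogvideox-lora" is a placeholder repository id.
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.load_lora_weights(
    "your-account/cogvideox-lora",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="my_adapter",
    low_cpu_mem_usage=True,  # skips random init of the LoRA layers before the weights land
)
```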
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`SD3Transformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+        keys = list(state_dict.keys())
+
+        transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
+        state_dict = {
+            k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+        }
+
+        if len(state_dict.keys()) > 0:
+            # check with first key if is not in peft format
+            first_key = next(iter(state_dict.keys()))
+            if "lora_A" not in first_key:
+                state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+            if adapter_name in getattr(transformer, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+                )
+
+            rank = {}
+            for key, val in state_dict.items():
+                if "lora_B" in key:
+                    rank[key] = val.shape[1]
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+                    raise ValueError(
+                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                    )
+                else:
+                    lora_config_kwargs.pop("use_dora")
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(transformer)
+
+            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+            # otherwise loading LoRA weights will lead to an error
+            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+            peft_kwargs = {}
+            if is_peft_version(">=", "0.13.1"):
+                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
+            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
+
+            warn_msg = ""
+            if incompatible_keys is not None:
+                # Check only for unexpected keys.
+                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                if unexpected_keys:
+                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                    if lora_unexpected_keys:
+                        warn_msg = (
+                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                            f" {', '.join(lora_unexpected_keys)}. "
+                        )
+
+                # Filter missing keys specific to the current adapter.
+                missing_keys = getattr(incompatible_keys, "missing_keys", None)
+                if missing_keys:
+                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                    if lora_missing_keys:
+                        warn_msg += (
+                            f"Loading adapter weights from state_dict led to missing keys in the model:"
+                            f" {', '.join(lora_missing_keys)}."
+                        )
+
+            if warn_msg:
+                logger.warning(warn_msg)
+
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
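The rank bookkeeping above reads the adapter rank off each `lora_B` weight, whose shape is `(out_features, rank)`. A toy illustration with made-up shapes:

```py
# Toy example: lora_B has shape (out_features, rank), so the rank is its second dimension.
import torch

state_dict = {
    "blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 1920),
    "blocks.0.attn.to_q.lora_B.weight": torch.zeros(1920, 16),
}
rank = {k: v.shape[1] for k, v in state_dict.items() if "lora_B" in k}
print(rank)  # {'blocks.0.attn.to_q.lora_B.weight': 16}
```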
+    @classmethod
+    # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
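A sketch of how these pieces fit after training, assuming the transformer already carries a peft adapter (for example one added with `transformer.add_adapter(...)`); the output directory is a placeholder:

```py
# Sketch only: collect the adapter weights from a LoRA-equipped transformer and save them
# in the diffusers layout (by default as pytorch_lora_weights.safetensors).
from peft.utils import get_peft_model_state_dict
from diffusers import CogVideoXPipeline


def save_adapter(transformer, output_dir: str = "./cogvideox-lora") -> None:
    # Extract only the LoRA parameters from the transformer ...
    transformer_lora_layers = get_peft_model_state_dict(transformer)
    # ... and write them under `output_dir` with safetensors serialization.
    CogVideoXPipeline.save_lora_weights(
        save_directory=output_dir,
        transformer_lora_layers=transformer_lora_layers,
        safe_serialization=True,
    )
```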
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer", "text_encoder"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters then it won't have any effect.
+        """
+        super().unfuse_lora(components=components)
+
+
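Putting the new mixin together, a minimal end-to-end sketch of fusing and unfusing a CogVideoX LoRA; since only the transformer is LoRA-loadable here, passing `components=["transformer"]` keeps the call explicit. The LoRA repository id and weight name are placeholders:

```py
# Fuse/unfuse round trip, sketched for CogVideoX with a placeholder LoRA repository.
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.load_lora_weights(
    "your-account/cogvideox-lora",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="my_adapter",
)

pipe.fuse_lora(components=["transformer"], lora_scale=0.8)  # fold the LoRA into the base weights
# ... run inference ...
pipe.unfuse_lora(components=["transformer"])  # restore the original transformer weights
```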
 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
         deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."