diffusers 0.28.2__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff shows the content of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
Files changed (118)
  1. diffusers/__init__.py +9 -1
  2. diffusers/commands/env.py +1 -5
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +2 -1
  5. diffusers/loaders/__init__.py +2 -2
  6. diffusers/loaders/lora.py +406 -140
  7. diffusers/loaders/lora_conversion_utils.py +7 -1
  8. diffusers/loaders/single_file.py +1 -1
  9. diffusers/loaders/single_file_model.py +5 -0
  10. diffusers/loaders/single_file_utils.py +242 -2
  11. diffusers/loaders/unet.py +307 -272
  12. diffusers/models/__init__.py +5 -3
  13. diffusers/models/attention.py +125 -1
  14. diffusers/models/attention_processor.py +169 -1
  15. diffusers/models/autoencoders/__init__.py +1 -0
  16. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  17. diffusers/models/autoencoders/autoencoder_kl.py +17 -6
  18. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
  19. diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
  20. diffusers/models/autoencoders/vq_model.py +182 -0
  21. diffusers/models/controlnet_xs.py +6 -6
  22. diffusers/models/embeddings.py +112 -84
  23. diffusers/models/model_loading_utils.py +55 -0
  24. diffusers/models/modeling_utils.py +128 -17
  25. diffusers/models/normalization.py +11 -6
  26. diffusers/models/transformers/__init__.py +1 -0
  27. diffusers/models/transformers/dual_transformer_2d.py +5 -4
  28. diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
  29. diffusers/models/transformers/prior_transformer.py +5 -5
  30. diffusers/models/transformers/transformer_2d.py +2 -2
  31. diffusers/models/transformers/transformer_sd3.py +344 -0
  32. diffusers/models/transformers/transformer_temporal.py +12 -10
  33. diffusers/models/unets/unet_1d.py +3 -3
  34. diffusers/models/unets/unet_2d.py +3 -3
  35. diffusers/models/unets/unet_2d_condition.py +4 -15
  36. diffusers/models/unets/unet_3d_condition.py +5 -17
  37. diffusers/models/unets/unet_i2vgen_xl.py +4 -4
  38. diffusers/models/unets/unet_motion_model.py +4 -4
  39. diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
  40. diffusers/models/vq_model.py +8 -165
  41. diffusers/pipelines/__init__.py +2 -0
  42. diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
  43. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
  44. diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
  45. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
  46. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
  47. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
  48. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  49. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
  50. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
  51. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
  52. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
  53. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
  54. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
  55. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
  56. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
  57. diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
  58. diffusers/pipelines/pia/pipeline_pia.py +4 -3
  59. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  60. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  61. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
  69. diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
  70. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  71. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +886 -0
  72. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +923 -0
  73. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
  74. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
  75. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
  76. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  77. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
  78. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
  79. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
  80. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
  81. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
  82. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
  83. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
  84. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
  85. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  86. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
  87. diffusers/schedulers/__init__.py +2 -0
  88. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  89. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
  90. diffusers/schedulers/scheduling_edm_euler.py +2 -4
  91. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
  92. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  93. diffusers/training_utils.py +4 -4
  94. diffusers/utils/__init__.py +3 -0
  95. diffusers/utils/constants.py +2 -0
  96. diffusers/utils/dummy_pt_objects.py +30 -0
  97. diffusers/utils/dummy_torch_and_transformers_objects.py +30 -0
  98. diffusers/utils/dynamic_modules_utils.py +15 -13
  99. diffusers/utils/hub_utils.py +106 -0
  100. diffusers/utils/import_utils.py +0 -1
  101. diffusers/utils/logging.py +3 -1
  102. diffusers/utils/state_dict_utils.py +2 -0
  103. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/METADATA +45 -45
  104. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/RECORD +108 -111
  105. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/WHEEL +1 -1
  106. diffusers/models/dual_transformer_2d.py +0 -20
  107. diffusers/models/prior_transformer.py +0 -12
  108. diffusers/models/t5_film_transformer.py +0 -70
  109. diffusers/models/transformer_2d.py +0 -25
  110. diffusers/models/transformer_temporal.py +0 -34
  111. diffusers/models/unet_1d.py +0 -26
  112. diffusers/models/unet_1d_blocks.py +0 -203
  113. diffusers/models/unet_2d.py +0 -27
  114. diffusers/models/unet_2d_blocks.py +0 -375
  115. diffusers/models/unet_2d_condition.py +0 -25
  116. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/LICENSE +0 -0
  117. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/entry_points.txt +0 -0
  118. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/top_level.txt +0 -0
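The headline additions in this release are the Stable Diffusion 3 pipeline (`stable_diffusion_3/`, `transformer_sd3.py`) and the new `scheduling_flow_match_euler_discrete.py` scheduler. A minimal sketch of how these pieces are used together (the model id and prompt below are illustrative, not taken from this diff):

    # Hedged sketch, assuming the SD3 checkpoint is accessible on the Hub.
    import torch
    from diffusers import StableDiffusion3Pipeline

    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative model id
        torch_dtype=torch.float16,
    ).to("cuda")

    # SD3Transformer2DModel and FlowMatchEulerDiscreteScheduler are wired up
    # internally by the pipeline; 28 steps is a common setting for SD3.
    image = pipe("a corgi wearing a space suit", num_inference_steps=28).images[0]
    image.save("sd3.png")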
diffusers/loaders/unet.py CHANGED
@@ -33,34 +33,32 @@ from ..models.embeddings import (
     IPAdapterPlusImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
+from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
+    convert_unet_state_dict_to_peft,
     delete_adapter_layers,
+    get_adapter_name,
+    get_peft_kwargs,
     is_accelerate_available,
+    is_peft_version,
     is_torch_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .unet_loader_utils import _maybe_expand_lora_scales
 from .utils import AttnProcsLayers


 if is_accelerate_available():
-    from accelerate import init_empty_weights
     from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

 logger = logging.get_logger(__name__)


-TEXT_ENCODER_NAME = "text_encoder"
-UNET_NAME = "unet"
-
-LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
-LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
-
 CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
 CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"

@@ -79,7 +77,8 @@ class UNet2DConditionLoadersMixin:
         Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
         defined in
         [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
-        and be a `torch.nn.Module` class.
+        and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install
+        `peft`: `pip install -U peft`.

         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -110,20 +109,20 @@ class UNet2DConditionLoadersMixin:
             token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-            mirror (`str`, *optional*):
-                Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
-                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
-                information.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+            adapter_name (`str`, *optional*, defaults to None):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            weight_name (`str`, *optional*, defaults to None):
+                Name of the serialized state dict file.

         Example:

@@ -139,9 +138,6 @@ class UNet2DConditionLoadersMixin:
         )
         ```
         """
-        from ..models.attention_processor import CustomDiffusionAttnProcessor
-        from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
-
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", None)
@@ -152,15 +148,9 @@ class UNet2DConditionLoadersMixin:
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
-        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
-        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
-        network_alphas = kwargs.pop("network_alphas", None)
-
+        adapter_name = kwargs.pop("adapter_name", None)
         _pipeline = kwargs.pop("_pipeline", None)
-
-        is_network_alphas_none = network_alphas is None
-
+        network_alphas = kwargs.pop("network_alphas", None)
         allow_pickle = False

         if use_safetensors is None:
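Taken together, the hunks above replace the removed `low_cpu_mem_usage` and `mirror` arguments with the now-documented `network_alphas`, `adapter_name`, and `weight_name` kwargs. A hedged sketch of a call against the new signature (paths and names are illustrative):

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"  # illustrative base model
    )
    # LoRA loading now requires the PEFT backend: pip install -U peft
    unet.load_attn_procs(
        "path/to/lora_dir",  # repo id, local directory, or an in-memory state dict
        weight_name="pytorch_lora_weights.safetensors",
        adapter_name="my_style",  # falls back to "default_{i}" when omitted
    )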
@@ -216,198 +206,196 @@ class UNet2DConditionLoadersMixin:
         else:
             state_dict = pretrained_model_name_or_path_or_dict

-        # fill attn processors
-        lora_layers_list = []
-
-        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
         is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
+        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False

-        if is_lora:
-            # correct keys
-            state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
+        if is_custom_diffusion:
+            attn_processors = self._process_custom_diffusion(state_dict=state_dict)
+        elif is_lora:
+            is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
+                state_dict=state_dict,
+                unet_identifier_key=self.unet_name,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+            )
+        else:
+            raise ValueError(
+                f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training."
+            )

-            if network_alphas is not None:
-                network_alphas_keys = list(network_alphas.keys())
-                used_network_alphas_keys = set()
-
-            lora_grouped_dict = defaultdict(dict)
-            mapped_network_alphas = {}
-
-            all_keys = list(state_dict.keys())
-            for key in all_keys:
-                value = state_dict.pop(key)
-                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
-                lora_grouped_dict[attn_processor_key][sub_key] = value
-
-                # Create another `mapped_network_alphas` dictionary so that we can properly map them.
-                if network_alphas is not None:
-                    for k in network_alphas_keys:
-                        if k.replace(".alpha", "") in key:
-                            mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
-                            used_network_alphas_keys.add(k)
-
-            if not is_network_alphas_none:
-                if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
-                    raise ValueError(
-                        f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
-                    )
+        # <Unsafe code
+        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
+        # Now we remove any existing hooks to `_pipeline`.

-            if len(state_dict) > 0:
-                raise ValueError(
-                    f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
-                )
+        # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
+        if is_custom_diffusion and _pipeline is not None:
+            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)

-            for key, value_dict in lora_grouped_dict.items():
-                attn_processor = self
-                for sub_key in key.split("."):
-                    attn_processor = getattr(attn_processor, sub_key)
-
-                # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
-                # or add_{k,v,q,out_proj}_proj_lora layers.
-                rank = value_dict["lora.down.weight"].shape[0]
-
-                if isinstance(attn_processor, LoRACompatibleConv):
-                    in_features = attn_processor.in_channels
-                    out_features = attn_processor.out_channels
-                    kernel_size = attn_processor.kernel_size
-
-                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
-                    with ctx():
-                        lora = LoRAConv2dLayer(
-                            in_features=in_features,
-                            out_features=out_features,
-                            rank=rank,
-                            kernel_size=kernel_size,
-                            stride=attn_processor.stride,
-                            padding=attn_processor.padding,
-                            network_alpha=mapped_network_alphas.get(key),
-                        )
-                elif isinstance(attn_processor, LoRACompatibleLinear):
-                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
-                    with ctx():
-                        lora = LoRALinearLayer(
-                            attn_processor.in_features,
-                            attn_processor.out_features,
-                            rank,
-                            mapped_network_alphas.get(key),
-                        )
-                else:
-                    raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
+            # only custom diffusion needs to set attn processors
+            self.set_attn_processor(attn_processors)
+            self.to(dtype=self.dtype, device=self.device)

-                value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
-                lora_layers_list.append((attn_processor, lora))
+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />

-                if low_cpu_mem_usage:
-                    device = next(iter(value_dict.values())).device
-                    dtype = next(iter(value_dict.values())).dtype
-                    load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
-                else:
-                    lora.load_state_dict(value_dict)
+    def _process_custom_diffusion(self, state_dict):
+        from ..models.attention_processor import CustomDiffusionAttnProcessor

-        elif is_custom_diffusion:
-            attn_processors = {}
-            custom_diffusion_grouped_dict = defaultdict(dict)
-            for key, value in state_dict.items():
-                if len(value) == 0:
-                    custom_diffusion_grouped_dict[key] = {}
+        attn_processors = {}
+        custom_diffusion_grouped_dict = defaultdict(dict)
+        for key, value in state_dict.items():
+            if len(value) == 0:
+                custom_diffusion_grouped_dict[key] = {}
+            else:
+                if "to_out" in key:
+                    attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
                 else:
-                    if "to_out" in key:
-                        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
-                    else:
-                        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
-                    custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+                    attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+                custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value

-            for key, value_dict in custom_diffusion_grouped_dict.items():
-                if len(value_dict) == 0:
-                    attn_processors[key] = CustomDiffusionAttnProcessor(
-                        train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
-                    )
-                else:
-                    cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
-                    hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
-                    train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
-                    attn_processors[key] = CustomDiffusionAttnProcessor(
-                        train_kv=True,
-                        train_q_out=train_q_out,
-                        hidden_size=hidden_size,
-                        cross_attention_dim=cross_attention_dim,
-                    )
-                    attn_processors[key].load_state_dict(value_dict)
-        elif USE_PEFT_BACKEND:
-            # In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`
-            # on the Unet
-            pass
-        else:
-            raise ValueError(
-                f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
-            )
+        for key, value_dict in custom_diffusion_grouped_dict.items():
+            if len(value_dict) == 0:
+                attn_processors[key] = CustomDiffusionAttnProcessor(
+                    train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+                )
+            else:
+                cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+                hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+                train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
+                attn_processors[key] = CustomDiffusionAttnProcessor(
+                    train_kv=True,
+                    train_q_out=train_q_out,
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                )
+                attn_processors[key].load_state_dict(value_dict)
+
+        return attn_processors
+
+    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
+        # This method does the following things:
+        # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
+        #    format. For legacy format no filtering is applied.
+        # 2. Converts the `state_dict` to the `peft` compatible format.
+        # 3. Creates a `LoraConfig` and then injects the converted `state_dict` into the UNet per the
+        #    `LoraConfig` specs.
+        # 4. It also reports if the underlying `_pipeline` has any kind of offloading inside of it.
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+        keys = list(state_dict.keys())
+
+        unet_keys = [k for k in keys if k.startswith(unet_identifier_key)]
+        unet_state_dict = {
+            k.replace(f"{unet_identifier_key}.", ""): v for k, v in state_dict.items() if k in unet_keys
+        }
+
+        if network_alphas is not None:
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(unet_identifier_key)]
+            network_alphas = {
+                k.replace(f"{unet_identifier_key}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+            }

-        # <Unsafe code
-        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
-        # Now we remove any existing hooks to
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

-        # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
-        if not USE_PEFT_BACKEND:
-            if _pipeline is not None:
-                for _, component in _pipeline.components.items():
-                    if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                        is_sequential_cpu_offload = (
-                            isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                            or hasattr(component._hf_hook, "hooks")
-                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                        )
+        if len(state_dict_to_be_used) > 0:
+            if adapter_name in getattr(self, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
+                )

-                        logger.info(
-                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                        )
-                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+            state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)

-        # only custom diffusion needs to set attn processors
-        if is_custom_diffusion:
-            self.set_attn_processor(attn_processors)
+            if network_alphas is not None:
+                # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
+                # `convert_unet_state_dict_to_peft` method.
+                network_alphas = convert_unet_state_dict_to_peft(network_alphas)
+
+            rank = {}
+            for key, val in state_dict.items():
+                if "lora_B" in key:
+                    rank[key] = val.shape[1]
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(self)
+
+            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+            # otherwise loading LoRA weights will lead to an error
+            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+
+            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
+            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+
+            if incompatible_keys is not None:
+                # check only for unexpected keys
+                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                if unexpected_keys:
+                    logger.warning(
+                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                        f" {unexpected_keys}. "
+                    )

-        # set lora layers
-        for target_module, lora_layer in lora_layers_list:
-            target_module.set_lora_layer(lora_layer)
+        return is_model_cpu_offload, is_sequential_cpu_offload

-        self.to(dtype=self.dtype, device=self.device)
-
-        # Offload back.
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
-        # Unsafe code />
+    @classmethod
+    # Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
+    def _optionally_disable_offloading(cls, _pipeline):
+        """
+        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.

-    def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
-        is_new_lora_format = all(
-            key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
-        )
-        if is_new_lora_format:
-            # Strip the `"unet"` prefix.
-            is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
-            if is_text_encoder_present:
-                warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
-                logger.warning(warn_message)
-            unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
-            state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+        Args:
+            _pipeline (`DiffusionPipeline`):
+                The pipeline to disable offloading for.

-        # change processor format to 'pure' LoRACompatibleLinear format
-        if any("processor" in k.split(".") for k in state_dict.keys()):
+        Returns:
+            tuple:
+                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+        """
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False

-            def format_to_lora_compatible(key):
-                if "processor" not in key.split("."):
-                    return key
-                return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")
+        if _pipeline is not None and _pipeline.hf_device_map is None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                    if not is_model_cpu_offload:
+                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                    if not is_sequential_cpu_offload:
+                        is_sequential_cpu_offload = (
+                            isinstance(component._hf_hook, AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )

-            state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

-        if network_alphas is not None:
-            network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
-        return state_dict, network_alphas
+        return (is_model_cpu_offload, is_sequential_cpu_offload)

     def save_attn_procs(
         self,
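The net effect of this hunk: `load_attn_procs` shrinks to a dispatcher, the hand-rolled `LoRACompatibleConv`/`LoRACompatibleLinear` surgery is deleted, and `_process_lora` delegates to PEFT. Stripped of the diffusers plumbing, the injection flow it implements looks roughly like this (the toy module, rank, and key names below are illustrative, not from the diff):

    # Hedged, self-contained sketch of the PEFT flow inside `_process_lora`.
    import torch
    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

    class TinyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.to_q = torch.nn.Linear(64, 64)
            self.to_k = torch.nn.Linear(64, 64)

    model = TinyBlock()  # stand-in for the UNet

    # `lora_B` weights have shape (out_features, rank), which is why the diff
    # reads the rank off `val.shape[1]`.
    checkpoint = {
        "to_q.lora_A.weight": torch.zeros(4, 64),
        "to_q.lora_B.weight": torch.zeros(64, 4),
        "to_k.lora_A.weight": torch.zeros(4, 64),
        "to_k.lora_B.weight": torch.zeros(64, 4),
    }
    rank = {k: v.shape[1] for k, v in checkpoint.items() if "lora_B" in k}

    # Inject empty adapter layers, then copy the checkpoint tensors into them.
    config = LoraConfig(r=next(iter(rank.values())), lora_alpha=4, target_modules=["to_q", "to_k"])
    inject_adapter_in_model(config, model, adapter_name="default_0")
    incompatible = set_peft_model_state_dict(model, checkpoint, adapter_name="default_0")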
@@ -460,6 +448,23 @@ class UNet2DConditionLoadersMixin:
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

+        is_custom_diffusion = any(
+            isinstance(
+                x,
+                (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
+            )
+            for (_, x) in self.attn_processors.items()
+        )
+        if is_custom_diffusion:
+            state_dict = self._get_custom_diffusion_state_dict()
+        else:
+            if not USE_PEFT_BACKEND:
+                raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
+
+            from peft.utils import get_peft_model_state_dict
+
+            state_dict = get_peft_model_state_dict(self)
+
         if save_function is None:
             if safe_serialization:

@@ -471,36 +476,6 @@ class UNet2DConditionLoadersMixin:

         os.makedirs(save_directory, exist_ok=True)

-        is_custom_diffusion = any(
-            isinstance(
-                x,
-                (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
-            )
-            for (_, x) in self.attn_processors.items()
-        )
-        if is_custom_diffusion:
-            model_to_save = AttnProcsLayers(
-                {
-                    y: x
-                    for (y, x) in self.attn_processors.items()
-                    if isinstance(
-                        x,
-                        (
-                            CustomDiffusionAttnProcessor,
-                            CustomDiffusionAttnProcessor2_0,
-                            CustomDiffusionXFormersAttnProcessor,
-                        ),
-                    )
-                }
-            )
-            state_dict = model_to_save.state_dict()
-            for name, attn in self.attn_processors.items():
-                if len(attn.state_dict()) == 0:
-                    state_dict[name] = {}
-        else:
-            model_to_save = AttnProcsLayers(self.attn_processors)
-            state_dict = model_to_save.state_dict()
-
         if weight_name is None:
             if safe_serialization:
                 weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
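`save_attn_procs` now branches once: Custom Diffusion processors are collected by the new `_get_custom_diffusion_state_dict` helper (added in the next hunk), while LoRA weights are exported with `peft.utils.get_peft_model_state_dict`. A hedged round-trip sketch (directory name illustrative):

    # Save the UNet's LoRA layers under the PEFT backend, then load them back.
    unet.save_attn_procs("./my_lora", weight_name="pytorch_lora_weights.safetensors")
    unet.load_attn_procs("./my_lora", weight_name="pytorch_lora_weights.safetensors")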
@@ -512,56 +487,84 @@ class UNet2DConditionLoadersMixin:
         save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")

+    def _get_custom_diffusion_state_dict(self):
+        from ..models.attention_processor import (
+            CustomDiffusionAttnProcessor,
+            CustomDiffusionAttnProcessor2_0,
+            CustomDiffusionXFormersAttnProcessor,
+        )
+
+        model_to_save = AttnProcsLayers(
+            {
+                y: x
+                for (y, x) in self.attn_processors.items()
+                if isinstance(
+                    x,
+                    (
+                        CustomDiffusionAttnProcessor,
+                        CustomDiffusionAttnProcessor2_0,
+                        CustomDiffusionXFormersAttnProcessor,
+                    ),
+                )
+            }
+        )
+        state_dict = model_to_save.state_dict()
+        for name, attn in self.attn_processors.items():
+            if len(attn.state_dict()) == 0:
+                state_dict[name] = {}
+
+        return state_dict
+
     def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `fuse_lora()`.")
+
         self.lora_scale = lora_scale
         self._safe_fusing = safe_fusing
         self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))

     def _fuse_lora_apply(self, module, adapter_names=None):
-        if not USE_PEFT_BACKEND:
-            if hasattr(module, "_fuse_lora"):
-                module._fuse_lora(self.lora_scale, self._safe_fusing)
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        merge_kwargs = {"safe_merge": self._safe_fusing}
+
+        if isinstance(module, BaseTunerLayer):
+            if self.lora_scale != 1.0:
+                module.scale_layer(self.lora_scale)

-            if adapter_names is not None:
+            # For BC with prevous PEFT versions, we need to check the signature
+            # of the `merge` method to see if it supports the `adapter_names` argument.
+            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+            if "adapter_names" in supported_merge_kwargs:
+                merge_kwargs["adapter_names"] = adapter_names
+            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
                 raise ValueError(
-                    "The `adapter_names` argument is not supported in your environment. Please switch"
-                    " to PEFT backend to use this argument by installing latest PEFT and transformers."
-                    " `pip install -U peft transformers`"
+                    "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
+                    " to the latest version of PEFT. `pip install -U peft`"
                 )
-        else:
-            from peft.tuners.tuners_utils import BaseTunerLayer
-
-            merge_kwargs = {"safe_merge": self._safe_fusing}
-
-            if isinstance(module, BaseTunerLayer):
-                if self.lora_scale != 1.0:
-                    module.scale_layer(self.lora_scale)
-
-                # For BC with prevous PEFT versions, we need to check the signature
-                # of the `merge` method to see if it supports the `adapter_names` argument.
-                supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
-                if "adapter_names" in supported_merge_kwargs:
-                    merge_kwargs["adapter_names"] = adapter_names
-                elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
-                    raise ValueError(
-                        "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
-                        " to the latest version of PEFT. `pip install -U peft`"
-                    )

-                module.merge(**merge_kwargs)
+            module.merge(**merge_kwargs)

     def unfuse_lora(self):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
         self.apply(self._unfuse_lora_apply)

     def _unfuse_lora_apply(self, module):
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        if isinstance(module, BaseTunerLayer):
+            module.unmerge()
+
+    def unload_lora(self):
         if not USE_PEFT_BACKEND:
-            if hasattr(module, "_unfuse_lora"):
-                module._unfuse_lora()
-        else:
-            from peft.tuners.tuners_utils import BaseTunerLayer
+            raise ValueError("PEFT backend is required for `unload_lora()`.")
+
+        from ..utils import recurse_remove_peft_layers

-            if isinstance(module, BaseTunerLayer):
-                module.unmerge()
+        recurse_remove_peft_layers(self)
+        if hasattr(self, "peft_config"):
+            del self.peft_config

     def set_adapters(
         self,
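`fuse_lora`, `unfuse_lora`, and the new `unload_lora` now hard-require the PEFT backend, with the merge itself delegated to PEFT's `BaseTunerLayer.merge`/`unmerge`. A hedged sketch of the lifecycle on a loaded pipeline (repo id illustrative):

    pipe.load_lora_weights("some-user/some-lora", adapter_name="style")  # illustrative
    pipe.unet.fuse_lora(lora_scale=0.8)  # merges scaled LoRA deltas into the base weights
    # ... inference now runs with no per-step LoRA overhead ...
    pipe.unet.unfuse_lora()   # unmerges, restoring the original base weights
    pipe.unet.unload_lora()   # removes the injected PEFT layers and the peft_config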
@@ -847,7 +850,12 @@ class UNet2DConditionLoadersMixin:
         embed_dims = state_dict["proj_in.weight"].shape[1]
         output_dims = state_dict["proj_out.weight"].shape[0]
         hidden_dims = state_dict["latents"].shape[2]
-        heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
+        attn_key_present = any("attn" in k for k in state_dict)
+        heads = (
+            state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
+            if attn_key_present
+            else state_dict["layers.0.0.to_q.weight"].shape[0] // 64
+        )

         with init_context():
             image_projection = IPAdapterPlusImageProjection(
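Both branches still infer the head count by dividing the `to_q` output dimension by a hard-coded head size of 64; the change only adds support for the newer `layers.N.attn.*` key layout. Illustratively (shapes made up):

    import torch

    to_q_weight = torch.empty(768, 1280)  # (heads * 64, embed_dims)
    heads = to_q_weight.shape[0] // 64    # -> 12, assuming head_dim == 64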
@@ -860,26 +868,53 @@ class UNet2DConditionLoadersMixin:

         for key, value in state_dict.items():
             diffusers_name = key.replace("0.to", "2.to")
-            diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
-            diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
-            diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
-            diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")

-            if "norm1" in diffusers_name:
-                updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
-            elif "norm2" in diffusers_name:
-                updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
-            elif "to_kv" in diffusers_name:
+            diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0")
+            diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1")
+            diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0")
+            diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1")
+            diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0")
+            diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1")
+            diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0")
+            diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1")
+
+            if "to_kv" in diffusers_name:
+                parts = diffusers_name.split(".")
+                parts[2] = "attn"
+                diffusers_name = ".".join(parts)
                 v_chunk = value.chunk(2, dim=0)
                 updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
                 updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+            elif "to_q" in diffusers_name:
+                parts = diffusers_name.split(".")
+                parts[2] = "attn"
+                diffusers_name = ".".join(parts)
+                updated_state_dict[diffusers_name] = value
             elif "to_out" in diffusers_name:
+                parts = diffusers_name.split(".")
+                parts[2] = "attn"
+                diffusers_name = ".".join(parts)
                 updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
             else:
+                diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0")
+                diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj")
+                diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2")
+
+                diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0")
+                diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj")
+                diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2")
+
+                diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0")
+                diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj")
+                diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2")
+
+                diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0")
+                diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj")
+                diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2")
                 updated_state_dict[diffusers_name] = value

         if not low_cpu_mem_usage:
-            image_projection.load_state_dict(updated_state_dict)
+            image_projection.load_state_dict(updated_state_dict, strict=True)
         else:
             load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)

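This last hunk rewrites the IP-Adapter-Plus checkpoint conversion to target the named `ln0`/`ln1`, `attn`, and `ff` submodules of `IPAdapterPlusImageProjection` instead of purely positional indices, and makes the final `load_state_dict` strict. A hedged illustration of the resulting renames (sample keys are made up but follow the layout the code expects; fused `to_kv` weights are additionally chunked into `to_k`/`to_v`):

    # Input key (checkpoint)     -> output key (diffusers module)
    expected_renames = {
        "layers.0.0.norm1.weight": "layers.0.ln0.weight",          # block norms -> ln0/ln1
        "layers.0.0.to_q.weight": "layers.0.attn.to_q.weight",     # parts[2] rewritten to "attn"
        "layers.0.1.1.weight": "layers.0.ff.1.net.0.proj.weight",  # feed-forward stack -> "ff"
    }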