diffusers 0.28.2__py3-none-any.whl → 0.29.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. diffusers/__init__.py +15 -1
  2. diffusers/commands/env.py +1 -5
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +2 -1
  5. diffusers/loaders/__init__.py +2 -2
  6. diffusers/loaders/lora.py +406 -140
  7. diffusers/loaders/lora_conversion_utils.py +7 -1
  8. diffusers/loaders/single_file.py +13 -1
  9. diffusers/loaders/single_file_model.py +15 -8
  10. diffusers/loaders/single_file_utils.py +267 -17
  11. diffusers/loaders/unet.py +307 -272
  12. diffusers/models/__init__.py +7 -3
  13. diffusers/models/attention.py +125 -1
  14. diffusers/models/attention_processor.py +169 -1
  15. diffusers/models/autoencoders/__init__.py +1 -0
  16. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  17. diffusers/models/autoencoders/autoencoder_kl.py +17 -6
  18. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
  19. diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
  20. diffusers/models/autoencoders/vq_model.py +182 -0
  21. diffusers/models/controlnet_sd3.py +418 -0
  22. diffusers/models/controlnet_xs.py +6 -6
  23. diffusers/models/embeddings.py +112 -84
  24. diffusers/models/model_loading_utils.py +55 -0
  25. diffusers/models/modeling_utils.py +138 -20
  26. diffusers/models/normalization.py +11 -6
  27. diffusers/models/transformers/__init__.py +1 -0
  28. diffusers/models/transformers/dual_transformer_2d.py +5 -4
  29. diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
  30. diffusers/models/transformers/prior_transformer.py +5 -5
  31. diffusers/models/transformers/transformer_2d.py +2 -2
  32. diffusers/models/transformers/transformer_sd3.py +353 -0
  33. diffusers/models/transformers/transformer_temporal.py +12 -10
  34. diffusers/models/unets/unet_1d.py +3 -3
  35. diffusers/models/unets/unet_2d.py +3 -3
  36. diffusers/models/unets/unet_2d_condition.py +4 -15
  37. diffusers/models/unets/unet_3d_condition.py +5 -17
  38. diffusers/models/unets/unet_i2vgen_xl.py +4 -4
  39. diffusers/models/unets/unet_motion_model.py +4 -4
  40. diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
  41. diffusers/models/vq_model.py +8 -165
  42. diffusers/pipelines/__init__.py +11 -0
  43. diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
  45. diffusers/pipelines/auto_pipeline.py +8 -0
  46. diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
  47. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
  48. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
  49. diffusers/pipelines/controlnet_sd3/__init__.py +53 -0
  50. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1062 -0
  51. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
  52. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  53. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
  54. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
  55. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
  56. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
  57. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
  58. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
  59. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
  60. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
  61. diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
  62. diffusers/pipelines/pia/pipeline_pia.py +4 -3
  63. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  64. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  65. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
  72. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
  73. diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
  74. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  75. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +904 -0
  76. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +941 -0
  77. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
  78. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
  79. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
  80. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  81. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
  82. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
  83. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
  84. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
  85. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
  86. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
  87. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
  88. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
  89. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  90. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
  91. diffusers/schedulers/__init__.py +2 -0
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  93. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
  94. diffusers/schedulers/scheduling_edm_euler.py +2 -4
  95. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
  96. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  97. diffusers/training_utils.py +4 -4
  98. diffusers/utils/__init__.py +3 -0
  99. diffusers/utils/constants.py +2 -0
  100. diffusers/utils/dummy_pt_objects.py +60 -0
  101. diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
  102. diffusers/utils/dynamic_modules_utils.py +15 -13
  103. diffusers/utils/hub_utils.py +106 -0
  104. diffusers/utils/import_utils.py +0 -1
  105. diffusers/utils/logging.py +3 -1
  106. diffusers/utils/state_dict_utils.py +2 -0
  107. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/METADATA +3 -3
  108. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/RECORD +112 -112
  109. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/WHEEL +1 -1
  110. diffusers/models/dual_transformer_2d.py +0 -20
  111. diffusers/models/prior_transformer.py +0 -12
  112. diffusers/models/t5_film_transformer.py +0 -70
  113. diffusers/models/transformer_2d.py +0 -25
  114. diffusers/models/transformer_temporal.py +0 -34
  115. diffusers/models/unet_1d.py +0 -26
  116. diffusers/models/unet_1d_blocks.py +0 -203
  117. diffusers/models/unet_2d.py +0 -27
  118. diffusers/models/unet_2d_blocks.py +0 -375
  119. diffusers/models/unet_2d_condition.py +0 -25
  120. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/LICENSE +0 -0
  121. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/entry_points.txt +0 -0
  122. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/top_level.txt +0 -0
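The headline addition in this release is Stable Diffusion 3 support (transformer_sd3.py, pipeline_stable_diffusion_3.py, the new controlnet_sd3 pipeline, and scheduling_flow_match_euler_discrete.py in the list above). For orientation, a minimal sketch of how the new pipeline is typically driven follows; the checkpoint id, prompt, and step count are illustrative assumptions and are not part of this diff.

```python
# Illustrative sketch (not part of the diff): driving the SD3 pipeline added in 0.29.x.
# The checkpoint id below is an assumption; SD3 weights are gated and fetched separately.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed model id
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# The pipeline is built around the new FlowMatchEulerDiscreteScheduler added in this release.
image = pipe(
    prompt="a photo of an astronaut riding a horse on the moon",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3_sample.png")
```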
diffusers/loaders/lora.py CHANGED
@@ -22,17 +22,14 @@ import torch
  from huggingface_hub import model_info
  from huggingface_hub.constants import HF_HUB_OFFLINE
  from huggingface_hub.utils import validate_hf_hub_args
- from packaging import version
  from torch import nn

- from .. import __version__
- from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
+ from ..models.modeling_utils import load_state_dict
  from ..utils import (
  USE_PEFT_BACKEND,
  _get_model_file,
  convert_state_dict_to_diffusers,
  convert_state_dict_to_peft,
- convert_unet_state_dict_to_peft,
  delete_adapter_layers,
  get_adapter_name,
  get_peft_kwargs,
@@ -119,13 +116,10 @@ class LoraLoaderMixin:
  if not is_correct_format:
  raise ValueError("Invalid LoRA checkpoint.")

- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-
  self.load_lora_into_unet(
  state_dict,
  network_alphas=network_alphas,
  unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
- low_cpu_mem_usage=low_cpu_mem_usage,
  adapter_name=adapter_name,
  _pipeline=self,
  )
@@ -136,7 +130,6 @@ class LoraLoaderMixin:
  if not hasattr(self, "text_encoder")
  else self.text_encoder,
  lora_scale=self.lora_scale,
- low_cpu_mem_usage=low_cpu_mem_usage,
  adapter_name=adapter_name,
  _pipeline=self,
  )
@@ -193,16 +186,8 @@ class LoraLoaderMixin:
  allowed by Git.
  subfolder (`str`, *optional*, defaults to `""`):
  The subfolder location of a model file within a larger model repository on the Hub or locally.
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Speed up model loading only loading the pretrained weights and not initializing the weights. This also
- tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
- Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
- argument to `True` will raise an error.
- mirror (`str`, *optional*):
- Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
- guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
- information.
-
+ weight_name (`str`, *optional*, defaults to None):
+ Name of the serialized state dict file.
  """
  # Load the main state dict first which has the LoRA layers for either of
  # UNet and text encoder or both.
@@ -383,9 +368,7 @@ class LoraLoaderMixin:
  return (is_model_cpu_offload, is_sequential_cpu_offload)

  @classmethod
- def load_lora_into_unet(
- cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
- ):
+ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
  """
  This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -395,14 +378,11 @@ class LoraLoaderMixin:
  into the unet or prefixed with an additional `unet` which can be used to distinguish between text
  encoder lora layers.
  network_alphas (`Dict[str, float]`):
- See `LoRALinearLayer` for more details.
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
  unet (`UNet2DConditionModel`):
  The UNet model to load the LoRA layers into.
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Speed up model loading only loading the pretrained weights and not initializing the weights. This also
- tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
- Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
- argument to `True` will raise an error.
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
@@ -410,94 +390,18 @@ class LoraLoaderMixin:
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
- low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
  # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
  # their prefixes.
  keys = list(state_dict.keys())
+ only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)

- if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
+ if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder:
  # Load the layers corresponding to UNet.
  logger.info(f"Loading {cls.unet_name}.")
-
- unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
- state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
-
- if network_alphas is not None:
- alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
- network_alphas = {
- k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
- }
-
- else:
- # Otherwise, we're dealing with the old format. This means the `state_dict` should only
- # contain the module names of the `unet` as its keys WITHOUT any prefix.
- if not USE_PEFT_BACKEND:
- warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
- logger.warning(warn_message)
-
- if len(state_dict.keys()) > 0:
- if adapter_name in getattr(unet, "peft_config", {}):
- raise ValueError(
- f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
- )
-
- state_dict = convert_unet_state_dict_to_peft(state_dict)
-
- if network_alphas is not None:
- # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
- # `convert_unet_state_dict_to_peft` method.
- network_alphas = convert_unet_state_dict_to_peft(network_alphas)
-
- rank = {}
- for key, val in state_dict.items():
- if "lora_B" in key:
- rank[key] = val.shape[1]
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"]:
- if is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<", "0.9.0"):
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(unet)
-
- # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
- # otherwise loading LoRA weights will lead to an error
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
- incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
-
- if incompatible_keys is not None:
- # check only for unexpected keys
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- logger.warning(
- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
- f" {unexpected_keys}. "
- )
-
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
-
- unet.load_attn_procs(
- state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
- )
+ unet.load_attn_procs(
+ state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+ )

  @classmethod
  def load_lora_into_text_encoder(
@@ -507,7 +411,6 @@ class LoraLoaderMixin:
  text_encoder,
  prefix=None,
  lora_scale=1.0,
- low_cpu_mem_usage=None,
  adapter_name=None,
  _pipeline=None,
  ):
@@ -527,11 +430,6 @@ class LoraLoaderMixin:
  lora_scale (`float`):
  How much to scale the output of the lora linear layer before it is added with the output of the regular
  lora layer.
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Speed up model loading only loading the pretrained weights and not initializing the weights. This also
- tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
- Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
- argument to `True` will raise an error.
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
@@ -541,8 +439,6 @@ class LoraLoaderMixin:

  from peft import LoraConfig

- low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
  # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
  # their prefixes.
@@ -625,9 +521,7 @@ class LoraLoaderMixin:
  # Unsafe code />

  @classmethod
- def load_lora_into_transformer(
- cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
- ):
+ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
  """
  This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -640,19 +534,12 @@ class LoraLoaderMixin:
  See `LoRALinearLayer` for more details.
  unet (`UNet2DConditionModel`):
  The UNet model to load the LoRA layers into.
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Speed up model loading only loading the pretrained weights and not initializing the weights. This also
- tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
- Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
- argument to `True` will raise an error.
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
  """
  from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

- low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
  keys = list(state_dict.keys())

  transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
@@ -846,22 +733,11 @@ class LoraLoaderMixin:
  >>> ...
  ```
  """
- unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-
  if not USE_PEFT_BACKEND:
- if version.parse(__version__) > version.parse("0.23"):
- logger.warning(
- "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
- "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
- )
+ raise ValueError("PEFT backend is required for this method.")

- for _, module in unet.named_modules():
- if hasattr(module, "set_lora_layer"):
- module.set_lora_layer(None)
- else:
- recurse_remove_peft_layers(unet)
- if hasattr(unet, "peft_config"):
- del unet.peft_config
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ unet.unload_lora()

  # Safe to call the following regardless of LoRA.
  self._remove_text_encoder_monkey_patch()
@@ -1461,3 +1337,393 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
  if getattr(self.text_encoder_2, "peft_config", None) is not None:
  del self.text_encoder_2.peft_config
  self.text_encoder_2._hf_peft_config_loaded = None
+
+
+ class SD3LoraLoaderMixin:
+ r"""
+ Load LoRA layers into [`SD3Transformer2DModel`].
+ """
+
+ transformer_name = TRANSFORMER_NAME
+ num_fused_loras = 0
+
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+
+ See [`~loaders.LoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is loaded
+ into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # UNet and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ model_file = None
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ # Let's first try to load .safetensors weights
+ if (use_safetensors and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ except (IOError, safetensors.SafetensorError) as e:
+ if not allow_pickle:
+ raise e
+ # try loading non-safetensors weights
+ model_file = None
+ pass
+
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ return state_dict
+
+ @classmethod
+ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`SD3Transformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+ keys = list(state_dict.keys())
+
+ transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
+ state_dict = {
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+ }
+
+ if len(state_dict.keys()) > 0:
+ if adapter_name in getattr(transformer, "peft_config", {}):
+ raise ValueError(
+ f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+ )
+
+ rank = {}
+ for key, val in state_dict.items():
+ if "lora_B" in key:
+ rank[key] = val.shape[1]
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(transformer)
+
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+ # otherwise loading LoRA weights will lead to an error
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ def pack_weights(layers, prefix):
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+ return layers_state_dict
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ @staticmethod
+ def write_lora_layers(
+ state_dict: Dict[str, torch.Tensor],
+ save_directory: str,
+ is_main_process: bool,
+ weight_name: str,
+ save_function: Callable,
+ safe_serialization: bool,
+ ):
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ if save_function is None:
+ if safe_serialization:
+
+ def save_function(weights, filename):
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+ else:
+ save_function = torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if weight_name is None:
+ if safe_serialization:
+ weight_name = LORA_WEIGHT_NAME_SAFE
+ else:
+ weight_name = LORA_WEIGHT_NAME
+
+ save_path = Path(save_directory, weight_name).as_posix()
+ save_function(state_dict, save_path)
+ logger.info(f"Model weights saved in {save_path}")
+
+ def unload_lora_weights(self):
+ """
+ Unloads the LoRA parameters.
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+ >>> pipeline.unload_lora_weights()
+ >>> ...
+ ```
+ """
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ recurse_remove_peft_layers(transformer)
+ if hasattr(transformer, "peft_config"):
+ del transformer.peft_config
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
+ def _optionally_disable_offloading(cls, _pipeline):
+ """
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+
+ Args:
+ _pipeline (`DiffusionPipeline`):
+ The pipeline to disable offloading for.
+
+ Returns:
+ tuple:
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+ """
+ is_model_cpu_offload = False
+ is_sequential_cpu_offload = False
+
+ if _pipeline is not None and _pipeline.hf_device_map is None:
+ for _, component in _pipeline.components.items():
+ if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+ if not is_model_cpu_offload:
+ is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+ if not is_sequential_cpu_offload:
+ is_sequential_cpu_offload = (
+ isinstance(component._hf_hook, AlignDevicesHook)
+ or hasattr(component._hf_hook, "hooks")
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+ )
+
+ logger.info(
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+ )
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+ return (is_model_cpu_offload, is_sequential_cpu_offload)
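The new SD3LoraLoaderMixin mirrors the LoraLoaderMixin surface but targets only the transformer component and drops the low_cpu_mem_usage plumbing removed above. A minimal usage sketch follows, assuming the mixin is exposed through StableDiffusion3Pipeline (as the new stable_diffusion_3 pipeline module suggests); the LoRA repository id, weight file name, and adapter name are placeholders, not part of this diff.

```python
# Sketch only: exercising the SD3LoraLoaderMixin methods shown above via an SD3 pipeline.
# Requires the PEFT backend; load_lora_weights() raises ValueError without it.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed base checkpoint
    torch_dtype=torch.float16,
).to("cuda")

# Loads the "transformer."-prefixed LoRA layers into pipe.transformer (SD3Transformer2DModel).
pipe.load_lora_weights(
    "your-account/your-sd3-lora",  # placeholder repo id or local directory
    weight_name="pytorch_lora_weights.safetensors",  # the default safetensors weight name
    adapter_name="style",
)

image = pipe(prompt="a watercolor fox in a misty forest", num_inference_steps=28).images[0]

# Removes the injected PEFT layers again (recurse_remove_peft_layers under the hood).
pipe.unload_lora_weights()
```

Because load_lora_into_transformer injects PEFT layers with inject_adapter_in_model and set_peft_model_state_dict, unload_lora_weights removes them via recurse_remove_peft_layers rather than swapping attention processors as the UNet path does.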