diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214) hide show
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
diffusers/loaders/peft.py CHANGED
@@ -13,30 +13,103 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
  import inspect
16
+ import os
16
17
  from functools import partial
18
+ from pathlib import Path
17
19
  from typing import Dict, List, Optional, Union
18
20
 
21
+ import safetensors
22
+ import torch
23
+ import torch.nn as nn
24
+
19
25
  from ..utils import (
20
26
  MIN_PEFT_VERSION,
21
27
  USE_PEFT_BACKEND,
22
28
  check_peft_version,
29
+ convert_unet_state_dict_to_peft,
23
30
  delete_adapter_layers,
31
+ get_adapter_name,
32
+ get_peft_kwargs,
33
+ is_accelerate_available,
24
34
  is_peft_available,
35
+ is_peft_version,
36
+ logging,
25
37
  set_adapter_layers,
26
38
  set_weights_and_activate_adapters,
27
39
  )
40
+ from .lora_base import _fetch_state_dict
28
41
  from .unet_loader_utils import _maybe_expand_lora_scales
29
42
 
30
43
 
44
+ if is_accelerate_available():
45
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
46
+
47
+ logger = logging.get_logger(__name__)
48
+
31
49
# Per-model-class function applied to adapter weights/scales. The UNet classes route through
# `_maybe_expand_lora_scales`; every transformer class listed below returns `weights` unchanged.
_SET_ADAPTER_SCALE_FN_MAPPING = {
    "UNet2DConditionModel": _maybe_expand_lora_scales,
    "UNetMotionModel": _maybe_expand_lora_scales,
    **{
        model_cls_name: (lambda model_cls, weights: weights)
        for model_cls_name in (
            "SD3Transformer2DModel",
            "FluxTransformer2DModel",
            "CogVideoXTransformer3DModel",
            "MochiTransformer3DModel",
            "HunyuanVideoTransformer3DModel",
            "LTXVideoTransformer3DModel",
            "SanaTransformer2DModel",
        )
    },
}
38
60
 
39
61
 
62
+ def _maybe_adjust_config(config):
63
+ """
64
+ We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
65
+ (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
66
+ method removes the ambiguity by following what is described here:
67
+ https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
68
+ """
69
+ rank_pattern = config["rank_pattern"].copy()
70
+ target_modules = config["target_modules"]
71
+ original_r = config["r"]
72
+
73
+ for key in list(rank_pattern.keys()):
74
+ key_rank = rank_pattern[key]
75
+
76
+ # try to detect ambiguity
77
+ # `target_modules` can also be a str, in which case this loop would loop
78
+ # over the chars of the str. The technically correct way to match LoRA keys
79
+ # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
80
+ # But this cuts it for now.
81
+ exact_matches = [mod for mod in target_modules if mod == key]
82
+ substring_matches = [mod for mod in target_modules if key in mod and mod != key]
83
+ ambiguous_key = key
84
+
85
+ if exact_matches and substring_matches:
86
+ # if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
87
+ config["r"] = key_rank
88
+ # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
89
+ del config["rank_pattern"][key]
90
+ for mod in substring_matches:
91
+ # avoid overwriting if the module already has a specific rank
92
+ if mod not in config["rank_pattern"]:
93
+ config["rank_pattern"][mod] = original_r
94
+
95
+ # update the rest of the keys with the `original_r`
96
+ for mod in target_modules:
97
+ if mod != ambiguous_key and mod not in config["rank_pattern"]:
98
+ config["rank_pattern"][mod] = original_r
99
+
100
+ # handle alphas to deal with cases like
101
+ # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
102
+ has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
103
+ if has_different_ranks:
104
+ config["lora_alpha"] = config["r"]
105
+ alpha_pattern = {}
106
+ for module_name, rank in config["rank_pattern"].items():
107
+ alpha_pattern[module_name] = rank
108
+ config["alpha_pattern"] = alpha_pattern
109
+
110
+ return config
111
+
112
+
40
113
  class PeftAdapterMixin:
41
114
  """
42
115
  A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -53,6 +126,312 @@ class PeftAdapterMixin:
53
126
 
54
127
  _hf_peft_config_loaded = False
55
128
 
129
@classmethod
# Mirrors diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading; keep the two in sync.
def _optionally_disable_offloading(cls, _pipeline):
    """
    Remove accelerate offloading hooks from the pipeline's modules, reporting which kind was found.

    Args:
        _pipeline (`DiffusionPipeline` or `None`):
            The pipeline whose components' accelerate hooks should be removed. Nothing happens when it is `None`
            or when its `hf_device_map` is set.

    Returns:
        tuple:
            `(is_model_cpu_offload, is_sequential_cpu_offload)`. The caller can use these flags to re-enable the
            matching kind of offloading afterwards.
    """
    found_model_offload = False
    found_sequential_offload = False

    if _pipeline is None or _pipeline.hf_device_map is not None:
        return (found_model_offload, found_sequential_offload)

    for module in _pipeline.components.values():
        if not (isinstance(module, nn.Module) and hasattr(module, "_hf_hook")):
            continue

        hook = module._hf_hook
        if not found_model_offload:
            found_model_offload = isinstance(hook, CpuOffload)
        if not found_sequential_offload:
            # Either a direct AlignDevicesHook, or a hook container whose first hook is one.
            found_sequential_offload = isinstance(hook, AlignDevicesHook) or (
                hasattr(hook, "hooks") and isinstance(hook.hooks[0], AlignDevicesHook)
            )

        logger.info(
            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
        )
        remove_hook_from_module(module, recurse=found_sequential_offload)

    return (found_model_offload, found_sequential_offload)
164
+
165
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
    r"""
    Loads a LoRA adapter into the underlying model.

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

        prefix (`str`, *optional*, defaults to `"transformer"`):
            Prefix used to filter the state dict. If any key starts with `f"{prefix}."`, only those keys are kept
            and the leading prefix is stripped from them; otherwise the state dict is used unchanged. Must not be
            `None` when `network_alphas` is passed.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        subfolder (`str`, *optional*, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        weight_name (`str`, *optional*):
            Name of the weight file to load; forwarded to `_fetch_state_dict`.
        use_safetensors (`bool`, *optional*):
            Forwarded to `_fetch_state_dict`.
        adapter_name (`str`, *optional*):
            Name to register the adapter under. If `None`, a name is generated with `get_adapter_name`.
        network_alphas (`Dict[str, float]`, *optional*):
            The value of the network alpha used for stable learning and preventing underflow. This value has the
            same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
            link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            Its keys are filtered by `prefix` in the same way as the state dict.
        low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
            Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
            weights. Requires a `peft` version newer than 0.13.0.
        _pipeline (*private*, *optional*):
            Pipeline whose accelerate offloading hooks are removed while loading and re-enabled afterwards,
            including when loading fails.

    Raises:
        ValueError: If `low_cpu_mem_usage=True` with an unsupported `peft` version, if `network_alphas` is given
            while `prefix` is `None`, if `adapter_name` is already in use, or if the LoRA needs a newer `peft`
            (DoRA or `lora_bias`).
    """
    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
    from peft.tuners.tuners_utils import BaseTunerLayer

    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)
    adapter_name = kwargs.pop("adapter_name", None)
    network_alphas = kwargs.pop("network_alphas", None)
    _pipeline = kwargs.pop("_pipeline", None)
    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
    allow_pickle = False

    if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
        raise ValueError(
            "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
        )

    # Validate arguments before fetching (and possibly downloading) the weights. `network_alphas` keys are filtered
    # by `prefix` below, so the combination "alphas given, no prefix" cannot be handled.
    if network_alphas is not None and prefix is None:
        raise ValueError("`network_alphas` must be None when `prefix` is None.")

    user_agent = {
        "file_type": "attn_procs_weights",
        "framework": "pytorch",
    }

    state_dict = _fetch_state_dict(
        pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
        weight_name=weight_name,
        use_safetensors=use_safetensors,
        local_files_only=local_files_only,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        token=token,
        revision=revision,
        subfolder=subfolder,
        user_agent=user_agent,
        allow_pickle=allow_pickle,
    )

    if prefix is not None:
        prefix_with_dot = f"{prefix}."
        # Strip only the *leading* prefix; `str.replace` would also remove later occurrences inside a key.
        prefixed_state_dict = {
            k[len(prefix_with_dot) :]: v for k, v in state_dict.items() if k.startswith(prefix_with_dot)
        }
        # Un-prefixed state dicts (no key carries the prefix) are used as-is.
        if prefixed_state_dict:
            state_dict = prefixed_state_dict

    if not state_dict:
        return

    if adapter_name in getattr(self, "peft_config", {}):
        raise ValueError(
            f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
        )

    # Only the first key is inspected to decide whether conversion to the PEFT (`lora_A`/`lora_B`) layout is needed.
    first_key = next(iter(state_dict))
    if "lora_A" not in first_key:
        state_dict = convert_unet_state_dict_to_peft(state_dict)

    # Infer each layer's rank from its `lora_B` weight. Tensors with fewer than two dimensions (e.g. LoRA biases)
    # carry no rank information and are skipped.
    rank = {key: val.shape[1] for key, val in state_dict.items() if "lora_B" in key and val.ndim > 1}

    if network_alphas is not None and len(network_alphas) >= 1:
        # `prefix` is guaranteed to be non-None here (validated above).
        alpha_prefix = f"{prefix}."
        network_alphas = {
            k[len(alpha_prefix) :]: v for k, v in network_alphas.items() if k.startswith(alpha_prefix)
        }

    lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
    lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)

    # DoRA needs peft >= 0.9.0. When DoRA is disabled, the flag is dropped on older versions (presumably their
    # `LoraConfig` does not accept it).
    if "use_dora" in lora_config_kwargs:
        if lora_config_kwargs["use_dora"]:
            if is_peft_version("<", "0.9.0"):
                raise ValueError(
                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                )
        else:
            if is_peft_version("<", "0.9.0"):
                lora_config_kwargs.pop("use_dora")

    # Same treatment for `lora_bias`, which needs peft >= 0.14.0.
    if "lora_bias" in lora_config_kwargs:
        if lora_config_kwargs["lora_bias"]:
            if is_peft_version("<=", "0.13.2"):
                raise ValueError(
                    "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
                )
        else:
            if is_peft_version("<=", "0.13.2"):
                lora_config_kwargs.pop("lora_bias")

    lora_config = LoraConfig(**lora_config_kwargs)
    if adapter_name is None:
        adapter_name = get_adapter_name(self)

    # If the pipeline was offloaded with accelerate, temporarily remove the hooks; otherwise loading LoRA weights
    # leads to an error. They are re-enabled in the `finally` clause below, whether or not loading succeeds.
    is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)

    peft_kwargs = {}
    if is_peft_version(">=", "0.13.1"):
        peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

    try:
        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
    except Exception as e:
        # Roll back the partially-injected adapter so `adapter_name` can be reused. `peft_config` may not exist
        # (or may not contain `adapter_name`) if injection failed early; the rollback must not mask the error.
        if hasattr(self, "peft_config"):
            for module in self.modules():
                if isinstance(module, BaseTunerLayer):
                    for active_adapter in module.active_adapters:
                        if adapter_name in active_adapter:
                            module.delete_adapter(adapter_name)
            self.peft_config.pop(adapter_name, None)
        logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
        raise
    else:
        # Only report LoRA keys that belong to the adapter being loaded.
        warn_msg = ""
        if incompatible_keys is not None:
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
                if lora_unexpected_keys:
                    warn_msg = (
                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
                        f" {', '.join(lora_unexpected_keys)}. "
                    )

            missing_keys = getattr(incompatible_keys, "missing_keys", None)
            if missing_keys:
                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
                if lora_missing_keys:
                    warn_msg += (
                        f"Loading adapter weights from state_dict led to missing keys in the model:"
                        f" {', '.join(lora_missing_keys)}."
                    )

        if warn_msg:
            logger.warning(warn_msg)
    finally:
        # Offload back.
        if is_model_cpu_offload:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()
371
+
372
def save_lora_adapter(
    self,
    save_directory,
    adapter_name: Optional[str] = "default",
    upcast_before_saving: bool = False,
    safe_serialization: bool = True,
    weight_name: Optional[str] = None,
):
    """
    Save the LoRA parameters corresponding to the underlying model.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to save LoRA parameters to. Will be created if it doesn't exist.
        adapter_name (`str`, *optional*, defaults to `"default"`):
            The name of the adapter to serialize. Useful when the underlying model has multiple adapters loaded. If
            `None` and the model holds exactly one adapter, that adapter is saved.
        upcast_before_saving (`bool`, defaults to `False`):
            Whether to cast the LoRA parameters to `torch.float32` before serialization. Only the extracted
            floating-point adapter tensors are cast; the model itself is not modified.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        weight_name (`str`, *optional*, defaults to `None`):
            Name of the file to serialize the state dict with. Defaults to `LORA_WEIGHT_NAME_SAFE` when
            `safe_serialization` is `True`, else `LORA_WEIGHT_NAME`.

    Raises:
        ValueError: If the adapter is not loaded in the model, or if `save_directory` is an existing file.
    """
    from peft.utils import get_peft_model_state_dict

    from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

    peft_config = getattr(self, "peft_config", {})
    if adapter_name is None:
        # `get_adapter_name` is what the load path uses to name a *new* adapter, so it may not refer to an existing
        # one. Prefer the only loaded adapter when that is unambiguous.
        if len(peft_config) == 1:
            adapter_name = next(iter(peft_config))
        else:
            adapter_name = get_adapter_name(self)

    if adapter_name not in peft_config:
        raise ValueError(f"Adapter name {adapter_name} not found in the model.")

    if os.path.isfile(save_directory):
        raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")

    lora_layers_to_save = get_peft_model_state_dict(self, adapter_name=adapter_name)
    if upcast_before_saving:
        # Upcast only the extracted adapter tensors. Calling `self.to(dtype=...)` would cast the whole model in place
        # and leave it in float32 after saving.
        lora_layers_to_save = {
            name: tensor.to(torch.float32) if tensor.is_floating_point() else tensor
            for name, tensor in lora_layers_to_save.items()
        }

    if safe_serialization:

        def save_function(weights, filename):
            return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

    else:
        save_function = torch.save

    os.makedirs(save_directory, exist_ok=True)

    if weight_name is None:
        weight_name = LORA_WEIGHT_NAME_SAFE if safe_serialization else LORA_WEIGHT_NAME

    # TODO: we could consider saving the `peft_config` as well.
    save_path = Path(save_directory, weight_name).as_posix()
    save_function(lora_layers_to_save, save_path)
    logger.info(f"Model weights saved in {save_path}")
434
+
56
435
  def set_adapters(
57
436
  self,
58
437
  adapter_names: Union[List[str], str],
@@ -17,16 +17,23 @@ import re
17
17
  from contextlib import nullcontext
18
18
  from typing import Optional
19
19
 
20
+ import torch
20
21
  from huggingface_hub.utils import validate_hf_hub_args
21
22
 
23
+ from ..quantizers import DiffusersAutoQuantizer
22
24
  from ..utils import deprecate, is_accelerate_available, logging
23
25
  from .single_file_utils import (
24
26
  SingleFileComponentError,
25
27
  convert_animatediff_checkpoint_to_diffusers,
28
+ convert_autoencoder_dc_checkpoint_to_diffusers,
26
29
  convert_controlnet_checkpoint,
27
30
  convert_flux_transformer_checkpoint_to_diffusers,
31
+ convert_hunyuan_video_transformer_to_diffusers,
28
32
  convert_ldm_unet_checkpoint,
29
33
  convert_ldm_vae_checkpoint,
34
+ convert_ltx_transformer_checkpoint_to_diffusers,
35
+ convert_ltx_vae_checkpoint_to_diffusers,
36
+ convert_mochi_transformer_checkpoint_to_diffusers,
30
37
  convert_sd3_transformer_checkpoint_to_diffusers,
31
38
  convert_stable_cascade_unet_single_file_to_diffusers,
32
39
  create_controlnet_diffusers_config_from_ldm,
@@ -82,6 +89,23 @@ SINGLE_FILE_LOADABLE_CLASSES = {
82
89
  "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
83
90
  "default_subfolder": "transformer",
84
91
  },
92
+ "LTXVideoTransformer3DModel": {
93
+ "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
94
+ "default_subfolder": "transformer",
95
+ },
96
+ "AutoencoderKLLTXVideo": {
97
+ "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
98
+ "default_subfolder": "vae",
99
+ },
100
+ "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
101
+ "MochiTransformer3DModel": {
102
+ "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
103
+ "default_subfolder": "transformer",
104
+ },
105
+ "HunyuanVideoTransformer3DModel": {
106
+ "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
107
+ "default_subfolder": "transformer",
108
+ },
85
109
  }
86
110
 
87
111
 
@@ -201,7 +225,10 @@ class FromOriginalModelMixin:
201
225
  local_files_only = kwargs.pop("local_files_only", None)
202
226
  subfolder = kwargs.pop("subfolder", None)
203
227
  revision = kwargs.pop("revision", None)
228
+ config_revision = kwargs.pop("config_revision", None)
204
229
  torch_dtype = kwargs.pop("torch_dtype", None)
230
+ quantization_config = kwargs.pop("quantization_config", None)
231
+ device = kwargs.pop("device", None)
205
232
 
206
233
  if isinstance(pretrained_model_link_or_path_or_dict, dict):
207
234
  checkpoint = pretrained_model_link_or_path_or_dict
@@ -215,11 +242,17 @@ class FromOriginalModelMixin:
215
242
  local_files_only=local_files_only,
216
243
  revision=revision,
217
244
  )
245
+ if quantization_config is not None:
246
+ hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
247
+ hf_quantizer.validate_environment()
248
+
249
+ else:
250
+ hf_quantizer = None
218
251
 
219
252
  mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
220
253
 
221
254
  checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
222
- if original_config:
255
+ if original_config is not None:
223
256
  if "config_mapping_fn" in mapping_functions:
224
257
  config_mapping_fn = mapping_functions["config_mapping_fn"]
225
258
  else:
@@ -243,7 +276,7 @@ class FromOriginalModelMixin:
243
276
  original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
244
277
  )
245
278
  else:
246
- if config:
279
+ if config is not None:
247
280
  if isinstance(config, str):
248
281
  default_pretrained_model_config_name = config
249
282
  else:
@@ -269,6 +302,8 @@ class FromOriginalModelMixin:
269
302
  pretrained_model_name_or_path=default_pretrained_model_config_name,
270
303
  subfolder=subfolder,
271
304
  local_files_only=local_files_only,
305
+ token=token,
306
+ revision=config_revision,
272
307
  )
273
308
  expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
274
309
 
@@ -295,8 +330,36 @@ class FromOriginalModelMixin:
295
330
  with ctx():
296
331
  model = cls.from_config(diffusers_model_config)
297
332
 
333
+ # Check if `_keep_in_fp32_modules` is not None
334
+ use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
335
+ (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
336
+ )
337
+ if use_keep_in_fp32_modules:
338
+ keep_in_fp32_modules = cls._keep_in_fp32_modules
339
+ if not isinstance(keep_in_fp32_modules, list):
340
+ keep_in_fp32_modules = [keep_in_fp32_modules]
341
+
342
+ else:
343
+ keep_in_fp32_modules = []
344
+
345
+ if hf_quantizer is not None:
346
+ hf_quantizer.preprocess_model(
347
+ model=model,
348
+ device_map=None,
349
+ state_dict=diffusers_format_checkpoint,
350
+ keep_in_fp32_modules=keep_in_fp32_modules,
351
+ )
352
+
298
353
  if is_accelerate_available():
299
- unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
354
+ param_device = torch.device(device) if device else torch.device("cpu")
355
+ unexpected_keys = load_model_dict_into_meta(
356
+ model,
357
+ diffusers_format_checkpoint,
358
+ dtype=torch_dtype,
359
+ device=param_device,
360
+ hf_quantizer=hf_quantizer,
361
+ keep_in_fp32_modules=keep_in_fp32_modules,
362
+ )
300
363
 
301
364
  else:
302
365
  _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -310,7 +373,11 @@ class FromOriginalModelMixin:
310
373
  f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
311
374
  )
312
375
 
313
- if torch_dtype is not None:
376
+ if hf_quantizer is not None:
377
+ hf_quantizer.postprocess_model(model)
378
+ model.hf_quantizer = hf_quantizer
379
+
380
+ if torch_dtype is not None and hf_quantizer is None:
314
381
  model.to(torch_dtype)
315
382
 
316
383
  model.eval()