diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268):
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -51,6 +51,9 @@ if is_accelerate_available():
51
51
 
52
52
  logger = logging.get_logger(__name__)
53
53
 
54
# Default file names for LoRA weights: the `.bin` name is used by the non-safetensors
# (pickle-allowed) path, the `.safetensors` name by the safetensors path.
LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE = (
    "pytorch_lora_weights.bin",
    "pytorch_lora_weights.safetensors",
)
54
57
 
55
58
  def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
56
59
  """
@@ -181,6 +184,119 @@ def _remove_text_encoder_monkey_patch(text_encoder):
181
184
  text_encoder._hf_peft_config_loaded = None
182
185
 
183
186
 
187
+ def _fetch_state_dict(
188
+ pretrained_model_name_or_path_or_dict,
189
+ weight_name,
190
+ use_safetensors,
191
+ local_files_only,
192
+ cache_dir,
193
+ force_download,
194
+ proxies,
195
+ token,
196
+ revision,
197
+ subfolder,
198
+ user_agent,
199
+ allow_pickle,
200
+ ):
201
+ model_file = None
202
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
203
+ # Let's first try to load .safetensors weights
204
+ if (use_safetensors and weight_name is None) or (
205
+ weight_name is not None and weight_name.endswith(".safetensors")
206
+ ):
207
+ try:
208
+ # Here we're relaxing the loading check to enable more Inference API
209
+ # friendliness where sometimes, it's not at all possible to automatically
210
+ # determine `weight_name`.
211
+ if weight_name is None:
212
+ weight_name = _best_guess_weight_name(
213
+ pretrained_model_name_or_path_or_dict,
214
+ file_extension=".safetensors",
215
+ local_files_only=local_files_only,
216
+ )
217
+ model_file = _get_model_file(
218
+ pretrained_model_name_or_path_or_dict,
219
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
220
+ cache_dir=cache_dir,
221
+ force_download=force_download,
222
+ proxies=proxies,
223
+ local_files_only=local_files_only,
224
+ token=token,
225
+ revision=revision,
226
+ subfolder=subfolder,
227
+ user_agent=user_agent,
228
+ )
229
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
230
+ except (IOError, safetensors.SafetensorError) as e:
231
+ if not allow_pickle:
232
+ raise e
233
+ # try loading non-safetensors weights
234
+ model_file = None
235
+ pass
236
+
237
+ if model_file is None:
238
+ if weight_name is None:
239
+ weight_name = _best_guess_weight_name(
240
+ pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
241
+ )
242
+ model_file = _get_model_file(
243
+ pretrained_model_name_or_path_or_dict,
244
+ weights_name=weight_name or LORA_WEIGHT_NAME,
245
+ cache_dir=cache_dir,
246
+ force_download=force_download,
247
+ proxies=proxies,
248
+ local_files_only=local_files_only,
249
+ token=token,
250
+ revision=revision,
251
+ subfolder=subfolder,
252
+ user_agent=user_agent,
253
+ )
254
+ state_dict = load_state_dict(model_file)
255
+ else:
256
+ state_dict = pretrained_model_name_or_path_or_dict
257
+
258
+ return state_dict
259
+
260
+
261
+ def _best_guess_weight_name(
262
+ pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
263
+ ):
264
+ if local_files_only or HF_HUB_OFFLINE:
265
+ raise ValueError("When using the offline mode, you must specify a `weight_name`.")
266
+
267
+ targeted_files = []
268
+
269
+ if os.path.isfile(pretrained_model_name_or_path_or_dict):
270
+ return
271
+ elif os.path.isdir(pretrained_model_name_or_path_or_dict):
272
+ targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
273
+ else:
274
+ files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
275
+ targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
276
+ if len(targeted_files) == 0:
277
+ return
278
+
279
+ # "scheduler" does not correspond to a LoRA checkpoint.
280
+ # "optimizer" does not correspond to a LoRA checkpoint
281
+ # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
282
+ unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
283
+ targeted_files = list(
284
+ filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
285
+ )
286
+
287
+ if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
288
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
289
+ elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
290
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
291
+
292
+ if len(targeted_files) > 1:
293
+ raise ValueError(
294
+ f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
295
+ )
296
+ weight_name = targeted_files[0]
297
+ return weight_name
298
+
299
+
184
300
  class LoraBaseMixin:
185
301
  """Utility class for handling LoRAs."""
186
302
 
@@ -234,124 +350,16 @@ class LoraBaseMixin:
234
350
  return (is_model_cpu_offload, is_sequential_cpu_offload)
235
351
 
236
352
  @classmethod
237
- def _fetch_state_dict(
238
- cls,
239
- pretrained_model_name_or_path_or_dict,
240
- weight_name,
241
- use_safetensors,
242
- local_files_only,
243
- cache_dir,
244
- force_download,
245
- proxies,
246
- token,
247
- revision,
248
- subfolder,
249
- user_agent,
250
- allow_pickle,
251
- ):
252
- from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
253
-
254
- model_file = None
255
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
256
- # Let's first try to load .safetensors weights
257
- if (use_safetensors and weight_name is None) or (
258
- weight_name is not None and weight_name.endswith(".safetensors")
259
- ):
260
- try:
261
- # Here we're relaxing the loading check to enable more Inference API
262
- # friendliness where sometimes, it's not at all possible to automatically
263
- # determine `weight_name`.
264
- if weight_name is None:
265
- weight_name = cls._best_guess_weight_name(
266
- pretrained_model_name_or_path_or_dict,
267
- file_extension=".safetensors",
268
- local_files_only=local_files_only,
269
- )
270
- model_file = _get_model_file(
271
- pretrained_model_name_or_path_or_dict,
272
- weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
273
- cache_dir=cache_dir,
274
- force_download=force_download,
275
- proxies=proxies,
276
- local_files_only=local_files_only,
277
- token=token,
278
- revision=revision,
279
- subfolder=subfolder,
280
- user_agent=user_agent,
281
- )
282
- state_dict = safetensors.torch.load_file(model_file, device="cpu")
283
- except (IOError, safetensors.SafetensorError) as e:
284
- if not allow_pickle:
285
- raise e
286
- # try loading non-safetensors weights
287
- model_file = None
288
- pass
289
-
290
- if model_file is None:
291
- if weight_name is None:
292
- weight_name = cls._best_guess_weight_name(
293
- pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
294
- )
295
- model_file = _get_model_file(
296
- pretrained_model_name_or_path_or_dict,
297
- weights_name=weight_name or LORA_WEIGHT_NAME,
298
- cache_dir=cache_dir,
299
- force_download=force_download,
300
- proxies=proxies,
301
- local_files_only=local_files_only,
302
- token=token,
303
- revision=revision,
304
- subfolder=subfolder,
305
- user_agent=user_agent,
306
- )
307
- state_dict = load_state_dict(model_file)
308
- else:
309
- state_dict = pretrained_model_name_or_path_or_dict
310
-
311
- return state_dict
353
+ def _fetch_state_dict(cls, *args, **kwargs):
354
+ deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
355
+ deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
356
+ return _fetch_state_dict(*args, **kwargs)
312
357
 
313
358
  @classmethod
314
- def _best_guess_weight_name(
315
- cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
316
- ):
317
- from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
318
-
319
- if local_files_only or HF_HUB_OFFLINE:
320
- raise ValueError("When using the offline mode, you must specify a `weight_name`.")
321
-
322
- targeted_files = []
323
-
324
- if os.path.isfile(pretrained_model_name_or_path_or_dict):
325
- return
326
- elif os.path.isdir(pretrained_model_name_or_path_or_dict):
327
- targeted_files = [
328
- f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
329
- ]
330
- else:
331
- files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
332
- targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
333
- if len(targeted_files) == 0:
334
- return
335
-
336
- # "scheduler" does not correspond to a LoRA checkpoint.
337
- # "optimizer" does not correspond to a LoRA checkpoint
338
- # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
339
- unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
340
- targeted_files = list(
341
- filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
342
- )
343
-
344
- if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
345
- targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
346
- elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
347
- targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
348
-
349
- if len(targeted_files) > 1:
350
- raise ValueError(
351
- f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
352
- )
353
- weight_name = targeted_files[0]
354
- return weight_name
359
+ def _best_guess_weight_name(cls, *args, **kwargs):
360
+ deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
361
+ deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
362
+ return _best_guess_weight_name(*args, **kwargs)
355
363
 
356
364
  def unload_lora_weights(self):
357
365
  """
@@ -532,13 +540,19 @@ class LoraBaseMixin:
532
540
  )
533
541
 
534
542
  list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
535
- all_adapters = {
536
- adapter for adapters in list_adapters.values() for adapter in adapters
537
- } # eg ["adapter1", "adapter2"]
543
+ # eg ["adapter1", "adapter2"]
544
+ all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
545
+ missing_adapters = set(adapter_names) - all_adapters
546
+ if len(missing_adapters) > 0:
547
+ raise ValueError(
548
+ f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
549
+ )
550
+
551
+ # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
538
552
  invert_list_adapters = {
539
553
  adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
540
554
  for adapter in all_adapters
541
- } # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
555
+ }
542
556
 
543
557
  # Decompose weights into weights for denoiser and text encoders.
544
558
  _component_adapter_weights = {}
@@ -699,9 +713,10 @@ class LoraBaseMixin:
699
713
  module.lora_B[adapter_name].to(device)
700
714
  # this is a param, not a module, so device placement is not in-place -> re-assign
701
715
  if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
702
- module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
703
- adapter_name
704
- ].to(device)
716
+ if adapter_name in module.lora_magnitude_vector:
717
+ module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
718
+ adapter_name
719
+ ].to(device)
705
720
 
706
721
  @staticmethod
707
722
  def pack_weights(layers, prefix):
@@ -718,8 +733,6 @@ class LoraBaseMixin:
718
733
  save_function: Callable,
719
734
  safe_serialization: bool,
720
735
  ):
721
- from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
722
-
723
736
  if os.path.isfile(save_directory):
724
737
  logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
725
738
  return