diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the published artifacts exactly as released.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2222 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +1 -12
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +262 -2
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1795 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +319 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +1 -4
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +19 -16
  210. diffusers/utils/loading_utils.py +76 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/single_file_utils.py CHANGED
@@ -74,6 +74,12 @@ CHECKPOINT_KEY_NAMES = {
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
     "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+    "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
+    "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
+    "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
+    "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
+    "animatediff_rgb": "controlnet_cond_embedding.weight",
+    "flux": "double_blocks.0.img_attn.norm.key_norm.scale",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
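Each new entry above is a fingerprint: a state-dict key whose presence (or tensor shape) identifies the checkpoint family during single-file loading. A minimal sketch of how such a key is probed; the checkpoint file name here is hypothetical, not taken from the diff:

```py
# Sketch only: probing a local checkpoint with the "flux" fingerprint key above.
from safetensors.torch import load_file

FLUX_KEY = "double_blocks.0.img_attn.norm.key_norm.scale"

state_dict = load_file("flux1-dev.safetensors")  # hypothetical local file
if FLUX_KEY in state_dict:
    print("Looks like a Flux transformer checkpoint")
```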
@@ -103,6 +109,14 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "sd3": {
         "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
     },
+    "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
+    "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
+    "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
+    "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
+    "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
+    "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
+    "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
+    "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
 }
 
 # Use to configure model sample size when original config is provided
@@ -306,7 +320,6 @@ def _is_model_weights_in_cached_folder(cached_folder, name):
 
 def load_single_file_checkpoint(
     pretrained_model_link_or_path,
-    resume_download=False,
     force_download=False,
     proxies=None,
     token=None,
@@ -324,7 +337,6 @@ def load_single_file_checkpoint(
             weights_name=weights_name,
             force_download=force_download,
             cache_dir=cache_dir,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
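Both hunks drop the deprecated `resume_download` argument; `huggingface_hub` now resumes interrupted downloads automatically, so callers simply omit it. A hedged sketch of the updated call, assuming the private helper is imported directly (the checkpoint URL is illustrative):

```py
# Sketch: the helper no longer accepts resume_download; other kwargs are unchanged.
from diffusers.loaders.single_file_utils import load_single_file_checkpoint

checkpoint = load_single_file_checkpoint(
    "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors",
    force_download=False,
    local_files_only=False,
)
```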
@@ -485,6 +497,30 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
         model_type = "sd3"
 
+    elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
+        if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
+            model_type = "animatediff_scribble"
+
+        elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
+            model_type = "animatediff_rgb"
+
+        elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
+            model_type = "animatediff_v2"
+
+        elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
+            model_type = "animatediff_sdxl_beta"
+
+        elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
+            model_type = "animatediff_v1"
+
+        else:
+            model_type = "animatediff_v3"
+
+    elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint:
+        if "guidance_in.in_layer.bias" in checkpoint:
+            model_type = "flux-dev"
+        else:
+            model_type = "flux-schnell"
+
     else:
         model_type = "v1"
 
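Ordering in the new branch is significant: SparseCtrl checkpoints also contain the generic motion-module key, so the scribble/rgb fingerprints are tested first, and tensor shapes disambiguate the remaining AnimateDiff variants. The same dispatch restated as a standalone sketch; key names are copied from the hunks above, and a `beta_key in sd` guard is added here (it is not in the original) so the toy dict works:

```py
import torch

PE_KEY = "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe"

def classify_animatediff(sd):
    # Order mirrors the hunk above: most specific fingerprints first.
    if "controlnet_cond_embedding.conv_in.weight" in sd:
        return "animatediff_scribble"
    if "controlnet_cond_embedding.weight" in sd:
        return "animatediff_rgb"
    if "mid_block.motion_modules.0.temporal_transformer.norm.bias" in sd:
        return "animatediff_v2"
    beta_key = "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight"
    if beta_key in sd and sd[beta_key].shape[-1] == 320:
        return "animatediff_sdxl_beta"
    if sd[PE_KEY].shape[1] == 24:  # positional-encoding length distinguishes v1 from v3
        return "animatediff_v1"
    return "animatediff_v3"

# Hypothetical toy state dict that lands in the v1 branch:
toy = {PE_KEY: torch.zeros(1, 24, 320)}
assert classify_animatediff(toy) == "animatediff_v1"
```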
@@ -1808,4 +1844,228 @@ def create_diffusers_t5_model_from_checkpoint(
 
     else:
         model.load_state_dict(diffusers_format_checkpoint)
+
+    use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
+    if use_keep_in_fp32_modules:
+        keep_in_fp32_modules = model._keep_in_fp32_modules
+    else:
+        keep_in_fp32_modules = []
+
+    if keep_in_fp32_modules is not None:
+        for name, param in model.named_parameters():
+            if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
+                # param = param.to(torch.float32) does not work here as only in the local scope.
+                param.data = param.data.to(torch.float32)
+
     return model
+
+
+def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    for k, v in checkpoint.items():
+        if "pos_encoder" in k:
+            continue
+
+        else:
+            converted_state_dict[
+                k.replace(".norms.0", ".norm1")
+                .replace(".norms.1", ".norm2")
+                .replace(".ff_norm", ".norm3")
+                .replace(".attention_blocks.0", ".attn1")
+                .replace(".attention_blocks.1", ".attn2")
+                .replace(".temporal_transformer", "")
+            ] = v
+
+    return converted_state_dict
+
+
+def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
+    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
+    mlp_ratio = 4.0
+    inner_dim = 3072
+
+    # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+    # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    ## time_text_embed.timestep_embedder <- time_in
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+        "time_in.in_layer.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+        "time_in.out_layer.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
+
+    ## time_text_embed.text_embedder <- vector_in
+    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
+    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+        "vector_in.out_layer.weight"
+    )
+    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
+
+    # guidance
+    has_guidance = any("guidance" in k for k in checkpoint)
+    if has_guidance:
+        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+            "guidance_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+            "guidance_in.in_layer.bias"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+            "guidance_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+            "guidance_in.out_layer.bias"
+        )
+
+    # context_embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
+    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
+
+    # x_embedder
+    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
+    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        # norms.
+        ## norm1
+        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mod.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mod.lin.bias"
+        )
+        ## norm1_context
+        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mod.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mod.lin.bias"
+        )
+        # Q, K, V
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
+        context_q, context_k, context_v = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+        )
+        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+        )
+        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+        # ff img_mlp
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
+        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
+        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.bias"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.bias"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.bias"
+        )
+
+    # single transfomer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+        # norm.linear <- single_blocks.0.modulation.lin
+        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.modulation.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+            f"single_blocks.{i}.modulation.lin.bias"
+        )
+        # Q, K, V, mlp
+        mlp_hidden_dim = int(inner_dim * mlp_ratio)
+        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+        q_bias, k_bias, v_bias, mlp_bias = torch.split(
+            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
+        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
+
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
+    )
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
+    )
+
+    return converted_state_dict
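Two mechanical transforms dominate `convert_flux_transformer_checkpoint_to_diffusers`: fused projections are split along dim 0, and AdaLayerNorm parameters are reordered from (shift, scale) to (scale, shift). Both in isolation, with toy sizes standing in for Flux's `inner_dim = 3072`:

```py
import torch

inner_dim, mlp_ratio = 8, 4.0  # toy stand-ins; Flux uses inner_dim = 3072
mlp_hidden_dim = int(inner_dim * mlp_ratio)

# single_blocks.*.linear1 packs Q, K, V and the MLP input projection in one matrix.
linear1_weight = torch.randn(3 * inner_dim + mlp_hidden_dim, inner_dim)
q, k, v, mlp = torch.split(linear1_weight, (inner_dim, inner_dim, inner_dim, mlp_hidden_dim), dim=0)
assert q.shape == k.shape == v.shape == (inner_dim, inner_dim)
assert mlp.shape == (mlp_hidden_dim, inner_dim)

# swap_scale_shift reorders AdaLayerNorm parameters from (shift, scale) to (scale, shift).
def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

w = torch.arange(4.0)       # [shift0, shift1, scale0, scale1]
print(swap_scale_shift(w))  # tensor([2., 3., 0., 1.])
```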
diffusers/loaders/textual_inversion.py CHANGED
@@ -38,7 +38,6 @@ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
 def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
     cache_dir = kwargs.pop("cache_dir", None)
     force_download = kwargs.pop("force_download", False)
-    resume_download = kwargs.pop("resume_download", None)
     proxies = kwargs.pop("proxies", None)
     local_files_only = kwargs.pop("local_files_only", None)
     token = kwargs.pop("token", None)
@@ -72,7 +71,6 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
                     weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
@@ -93,7 +91,6 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
                 weights_name=weight_name or TEXT_INVERSION_NAME,
                 cache_dir=cache_dir,
                 force_download=force_download,
-                resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 token=token,
@@ -308,9 +305,7 @@ class TextualInversionLoaderMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
+
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
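The same `resume_download` removal is applied to textual-inversion loading; user code needs no change beyond dropping the kwarg. A minimal sketch (repo ids are illustrative examples, not taken from this diff):

```py
# Sketch: loading a textual-inversion embedding without the removed kwarg.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # illustrative model
)
pipe.load_textual_inversion("sd-concepts-library/cat-toy")  # no resume_download=...
```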
diffusers/loaders/unet.py CHANGED
@@ -11,13 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 import os
 from collections import defaultdict
 from contextlib import nullcontext
-from functools import partial
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, Dict, Union
 
 import safetensors
 import torch
@@ -38,18 +36,14 @@ from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_unet_state_dict_to_peft,
-    delete_adapter_layers,
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
     is_peft_version,
     is_torch_version,
     logging,
-    set_adapter_layers,
-    set_weights_and_activate_adapters,
 )
-from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
-from .unet_loader_utils import _maybe_expand_lora_scales
+from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .utils import AttnProcsLayers
 
 
@@ -97,9 +91,7 @@ class UNet2DConditionLoadersMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
+
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -140,7 +132,6 @@ class UNet2DConditionLoadersMixin:
         """
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
@@ -174,7 +165,6 @@ class UNet2DConditionLoadersMixin:
                     weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
@@ -194,7 +184,6 @@ class UNet2DConditionLoadersMixin:
                 weights_name=weight_name or LORA_WEIGHT_NAME,
                 cache_dir=cache_dir,
                 force_download=force_download,
-                resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 token=token,
@@ -362,7 +351,7 @@ class UNet2DConditionLoadersMixin:
         return is_model_cpu_offload, is_sequential_cpu_offload
 
     @classmethod
-    # Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
+    # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
     def _optionally_disable_offloading(cls, _pipeline):
         """
         Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
@@ -457,6 +446,15 @@ class UNet2DConditionLoadersMixin:
         )
         if is_custom_diffusion:
             state_dict = self._get_custom_diffusion_state_dict()
+            if save_function is None and safe_serialization:
+                # safetensors does not support saving dicts with non-tensor values
+                empty_state_dict = {k: v for k, v in state_dict.items() if not isinstance(v, torch.Tensor)}
+                if len(empty_state_dict) > 0:
+                    logger.warning(
+                        f"Safetensors does not support saving dicts with non-tensor values. "
+                        f"The following keys will be ignored: {empty_state_dict.keys()}"
+                    )
+                state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
         else:
             if not USE_PEFT_BACKEND:
                 raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
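The new guard exists because Custom Diffusion state dicts may carry non-tensor entries, which `safetensors` cannot serialize; they are now warned about and filtered out rather than crashing the save. The same filtering pattern in isolation:

```py
import torch
from safetensors.torch import save_file

state_dict = {"proj.weight": torch.ones(2, 2), "meta": "not a tensor"}  # toy dict

# Warn about, then drop, anything safetensors cannot serialize.
non_tensors = {k: v for k, v in state_dict.items() if not isinstance(v, torch.Tensor)}
if non_tensors:
    print(f"Ignoring non-tensor keys: {list(non_tensors)}")

save_file({k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}, "weights.safetensors")
```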
@@ -515,194 +513,6 @@ class UNet2DConditionLoadersMixin:
 
         return state_dict
 
-    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for `fuse_lora()`.")
-
-        self.lora_scale = lora_scale
-        self._safe_fusing = safe_fusing
-        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
-
-    def _fuse_lora_apply(self, module, adapter_names=None):
-        from peft.tuners.tuners_utils import BaseTunerLayer
-
-        merge_kwargs = {"safe_merge": self._safe_fusing}
-
-        if isinstance(module, BaseTunerLayer):
-            if self.lora_scale != 1.0:
-                module.scale_layer(self.lora_scale)
-
-            # For BC with prevous PEFT versions, we need to check the signature
-            # of the `merge` method to see if it supports the `adapter_names` argument.
-            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
-            if "adapter_names" in supported_merge_kwargs:
-                merge_kwargs["adapter_names"] = adapter_names
-            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
-                raise ValueError(
-                    "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
-                    " to the latest version of PEFT. `pip install -U peft`"
-                )
-
-            module.merge(**merge_kwargs)
-
-    def unfuse_lora(self):
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
-        self.apply(self._unfuse_lora_apply)
-
-    def _unfuse_lora_apply(self, module):
-        from peft.tuners.tuners_utils import BaseTunerLayer
-
-        if isinstance(module, BaseTunerLayer):
-            module.unmerge()
-
-    def unload_lora(self):
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for `unload_lora()`.")
-
-        from ..utils import recurse_remove_peft_layers
-
-        recurse_remove_peft_layers(self)
-        if hasattr(self, "peft_config"):
-            del self.peft_config
-
-    def set_adapters(
-        self,
-        adapter_names: Union[List[str], str],
-        weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
-    ):
-        """
-        Set the currently active adapters for use in the UNet.
-
-        Args:
-            adapter_names (`List[str]` or `str`):
-                The names of the adapters to use.
-            adapter_weights (`Union[List[float], float]`, *optional*):
-                The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
-                adapters.
-
-        Example:
-
-        ```py
-        from diffusers import AutoPipelineForText2Image
-        import torch
-
-        pipeline = AutoPipelineForText2Image.from_pretrained(
-            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
-        ).to("cuda")
-        pipeline.load_lora_weights(
-            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
-        )
-        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
-        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
-        ```
-        """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for `set_adapters()`.")
-
-        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
-
-        # Expand weights into a list, one entry per adapter
-        # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
-        if not isinstance(weights, list):
-            weights = [weights] * len(adapter_names)
-
-        if len(adapter_names) != len(weights):
-            raise ValueError(
-                f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
-            )
-
-        # Set None values to default of 1.0
-        # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
-        weights = [w if w is not None else 1.0 for w in weights]
-
-        # e.g. [{...}, 7] -> [{expanded dict...}, 7]
-        weights = _maybe_expand_lora_scales(self, weights)
-
-        set_weights_and_activate_adapters(self, adapter_names, weights)
-
-    def disable_lora(self):
-        """
-        Disable the UNet's active LoRA layers.
-
-        Example:
-
-        ```py
-        from diffusers import AutoPipelineForText2Image
-        import torch
-
-        pipeline = AutoPipelineForText2Image.from_pretrained(
-            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
-        ).to("cuda")
-        pipeline.load_lora_weights(
-            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
-        )
-        pipeline.disable_lora()
-        ```
-        """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for this method.")
-        set_adapter_layers(self, enabled=False)
-
-    def enable_lora(self):
-        """
-        Enable the UNet's active LoRA layers.
-
-        Example:
-
-        ```py
-        from diffusers import AutoPipelineForText2Image
-        import torch
-
-        pipeline = AutoPipelineForText2Image.from_pretrained(
-            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
-        ).to("cuda")
-        pipeline.load_lora_weights(
-            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
-        )
-        pipeline.enable_lora()
-        ```
-        """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for this method.")
-        set_adapter_layers(self, enabled=True)
-
-    def delete_adapters(self, adapter_names: Union[List[str], str]):
-        """
-        Delete an adapter's LoRA layers from the UNet.
-
-        Args:
-            adapter_names (`Union[List[str], str]`):
-                The names (single string or list of strings) of the adapter to delete.
-
-        Example:
-
-        ```py
-        from diffusers import AutoPipelineForText2Image
-        import torch
-
-        pipeline = AutoPipelineForText2Image.from_pretrained(
-            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
-        ).to("cuda")
-        pipeline.load_lora_weights(
-            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
-        )
-        pipeline.delete_adapters("cinematic")
-        ```
-        """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for this method.")
-
-        if isinstance(adapter_names, str):
-            adapter_names = [adapter_names]
-
-        for adapter_name in adapter_names:
-            delete_adapter_layers(self, adapter_name)
-
-            # Pop also the corresponding adapter from the config
-            if hasattr(self, "peft_config"):
-                self.peft_config.pop(adapter_name, None)
-
     def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
         if low_cpu_mem_usage:
             if is_accelerate_available():
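The roughly 190 lines removed above are a relocation, not a feature removal: judging by the new modules in this release (`loaders/lora_base.py`, `loaders/lora_pipeline.py`, and the expanded `loaders/peft.py` in the file list), adapter management now lives on the shared PEFT/LoRA base classes. The pipeline-level API is unchanged; the example carried over from the removed `set_adapters` docstring still works as-is:

```py
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
```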
@@ -922,8 +732,6 @@ class UNet2DConditionLoadersMixin:
 
     def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
        from ..models.attention_processor import (
-            AttnProcessor,
-            AttnProcessor2_0,
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
         )
@@ -963,9 +771,7 @@ class UNet2DConditionLoadersMixin:
                 hidden_size = self.config.block_out_channels[block_id]
 
             if cross_attention_dim is None or "motion_modules" in name:
-                attn_processor_class = (
-                    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
-                )
+                attn_processor_class = self.attn_processors[name].__class__
                 attn_procs[name] = attn_processor_class()
 
             else:
@@ -1017,6 +823,15 @@ class UNet2DConditionLoadersMixin:
     def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]
+
+        # Kolors Unet already has a `encoder_hid_proj`
+        if (
+            self.encoder_hid_proj is not None
+            and self.config.encoder_hid_dim_type == "text_proj"
+            and not hasattr(self, "text_encoder_hid_proj")
+        ):
+            self.text_encoder_hid_proj = self.encoder_hid_proj
+
         # Set encoder_hid_proj after loading ip_adapter weights,
         # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None
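The Kolors special case above exists because the Kolors UNet ships with `encoder_hid_proj` already occupied by its text projection; the new branch stashes it as `text_encoder_hid_proj` so the text path survives when the loader sets `encoder_hid_proj = None` below. A hedged sketch of a load that would exercise this branch; the repo and weight names are illustrative, not taken from this diff:

```py
# Sketch: loading an IP-Adapter into the new Kolors pipeline (added in this release).
import torch
from diffusers import KolorsPipeline

pipe = KolorsPipeline.from_pretrained("Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16)
pipe.load_ip_adapter(
    "Kwai-Kolors/Kolors-IP-Adapter-Plus",       # illustrative repo id
    subfolder="",
    weight_name="ip_adapter_plus_general.bin",  # illustrative weight file
)
# The UNet's original text projection survives on `pipe.unet.text_encoder_hid_proj`.
```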