diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. diffusers/__init__.py +26 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +33 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +8 -0
  21. diffusers/models/activations.py +23 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +475 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +363 -32
  35. diffusers/models/model_loading_utils.py +177 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_outputs.py +14 -0
  39. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  40. diffusers/models/modeling_utils.py +175 -99
  41. diffusers/models/normalization.py +2 -1
  42. diffusers/models/resnet.py +18 -23
  43. diffusers/models/transformer_temporal.py +3 -3
  44. diffusers/models/transformers/__init__.py +3 -0
  45. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  46. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  47. diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
  48. diffusers/models/transformers/pixart_transformer_2d.py +336 -0
  49. diffusers/models/transformers/prior_transformer.py +7 -7
  50. diffusers/models/transformers/t5_film_transformer.py +17 -19
  51. diffusers/models/transformers/transformer_2d.py +292 -184
  52. diffusers/models/transformers/transformer_temporal.py +10 -10
  53. diffusers/models/unets/unet_1d.py +5 -5
  54. diffusers/models/unets/unet_1d_blocks.py +29 -29
  55. diffusers/models/unets/unet_2d.py +6 -6
  56. diffusers/models/unets/unet_2d_blocks.py +137 -128
  57. diffusers/models/unets/unet_2d_condition.py +19 -15
  58. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  59. diffusers/models/unets/unet_3d_blocks.py +79 -77
  60. diffusers/models/unets/unet_3d_condition.py +13 -9
  61. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  62. diffusers/models/unets/unet_kandinsky3.py +1 -1
  63. diffusers/models/unets/unet_motion_model.py +114 -14
  64. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  65. diffusers/models/unets/unet_stable_cascade.py +16 -13
  66. diffusers/models/upsampling.py +17 -20
  67. diffusers/models/vq_model.py +16 -15
  68. diffusers/pipelines/__init__.py +27 -3
  69. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  70. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  71. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  72. diffusers/pipelines/animatediff/__init__.py +2 -0
  73. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  74. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  75. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  76. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  77. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  78. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  79. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  80. diffusers/pipelines/auto_pipeline.py +21 -17
  81. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  82. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  83. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  84. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  85. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  86. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  87. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  88. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  89. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  90. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  91. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  92. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  93. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  94. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  95. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  96. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  97. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  98. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  99. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  100. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  101. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  102. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  103. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  104. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  105. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  106. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  107. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  108. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  109. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  110. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  111. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  112. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  113. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  114. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  115. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  116. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  117. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  118. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  119. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  120. diffusers/pipelines/free_init_utils.py +39 -38
  121. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  122. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
  123. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  124. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  125. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  126. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  127. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  128. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  129. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  130. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  131. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  132. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  133. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  134. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  135. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  136. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  137. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  138. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  139. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  140. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  141. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  142. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  143. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  144. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  145. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  146. diffusers/pipelines/marigold/__init__.py +50 -0
  147. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  148. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  149. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  150. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  151. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  152. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  153. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  154. diffusers/pipelines/pipeline_loading_utils.py +269 -23
  155. diffusers/pipelines/pipeline_utils.py +266 -37
  156. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
  158. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  159. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  160. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  161. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  162. diffusers/pipelines/shap_e/renderer.py +1 -1
  163. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  164. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  165. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  166. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  167. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  168. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  169. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  172. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  173. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  174. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  175. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  176. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  177. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  178. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  179. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  180. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  181. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  182. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  183. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  184. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  185. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  186. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  187. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  188. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  189. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  190. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  191. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  192. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  193. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  194. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  195. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  196. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  197. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  198. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  199. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  200. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  201. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  202. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  203. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  204. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  205. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  206. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  207. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  208. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  209. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  210. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  211. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  212. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  213. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  214. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  215. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  216. diffusers/schedulers/__init__.py +2 -2
  217. diffusers/schedulers/deprecated/__init__.py +1 -1
  218. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  219. diffusers/schedulers/scheduling_amused.py +5 -5
  220. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  221. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  222. diffusers/schedulers/scheduling_ddim.py +22 -24
  223. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  224. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  225. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  226. diffusers/schedulers/scheduling_ddpm.py +20 -22
  227. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  228. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  229. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  230. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  231. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  232. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  236. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  237. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  238. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  239. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  240. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  241. diffusers/schedulers/scheduling_ipndm.py +8 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  244. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  245. diffusers/schedulers/scheduling_lcm.py +21 -23
  246. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  247. diffusers/schedulers/scheduling_pndm.py +20 -20
  248. diffusers/schedulers/scheduling_repaint.py +20 -20
  249. diffusers/schedulers/scheduling_sasolver.py +55 -54
  250. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  251. diffusers/schedulers/scheduling_tcd.py +39 -30
  252. diffusers/schedulers/scheduling_unclip.py +15 -15
  253. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  254. diffusers/schedulers/scheduling_utils.py +14 -5
  255. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  256. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  257. diffusers/training_utils.py +56 -1
  258. diffusers/utils/__init__.py +7 -0
  259. diffusers/utils/doc_utils.py +1 -0
  260. diffusers/utils/dummy_pt_objects.py +75 -0
  261. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  262. diffusers/utils/dynamic_modules_utils.py +24 -11
  263. diffusers/utils/hub_utils.py +3 -2
  264. diffusers/utils/import_utils.py +91 -0
  265. diffusers/utils/loading_utils.py +2 -2
  266. diffusers/utils/logging.py +1 -1
  267. diffusers/utils/peft_utils.py +32 -5
  268. diffusers/utils/state_dict_utils.py +11 -2
  269. diffusers/utils/testing_utils.py +71 -6
  270. diffusers/utils/torch_utils.py +1 -0
  271. diffusers/video_processor.py +113 -0
  272. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
  273. diffusers-0.28.1.dist-info/RECORD +419 -0
  274. diffusers-0.27.2.dist-info/RECORD +0 -399
  275. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
  276. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
  277. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
  278. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
@@ -92,6 +92,51 @@ class ControlNetModel(metaclass=DummyObject):
92
92
  requires_backends(cls, ["torch"])
93
93
 
94
94
 
95
+ class ControlNetXSAdapter(metaclass=DummyObject):
96
+ _backends = ["torch"]
97
+
98
+ def __init__(self, *args, **kwargs):
99
+ requires_backends(self, ["torch"])
100
+
101
+ @classmethod
102
+ def from_config(cls, *args, **kwargs):
103
+ requires_backends(cls, ["torch"])
104
+
105
+ @classmethod
106
+ def from_pretrained(cls, *args, **kwargs):
107
+ requires_backends(cls, ["torch"])
108
+
109
+
110
+ class DiTTransformer2DModel(metaclass=DummyObject):
111
+ _backends = ["torch"]
112
+
113
+ def __init__(self, *args, **kwargs):
114
+ requires_backends(self, ["torch"])
115
+
116
+ @classmethod
117
+ def from_config(cls, *args, **kwargs):
118
+ requires_backends(cls, ["torch"])
119
+
120
+ @classmethod
121
+ def from_pretrained(cls, *args, **kwargs):
122
+ requires_backends(cls, ["torch"])
123
+
124
+
125
+ class HunyuanDiT2DModel(metaclass=DummyObject):
126
+ _backends = ["torch"]
127
+
128
+ def __init__(self, *args, **kwargs):
129
+ requires_backends(self, ["torch"])
130
+
131
+ @classmethod
132
+ def from_config(cls, *args, **kwargs):
133
+ requires_backends(cls, ["torch"])
134
+
135
+ @classmethod
136
+ def from_pretrained(cls, *args, **kwargs):
137
+ requires_backends(cls, ["torch"])
138
+
139
+
95
140
  class I2VGenXLUNet(metaclass=DummyObject):
96
141
  _backends = ["torch"]
97
142
 
@@ -167,6 +212,21 @@ class MultiAdapter(metaclass=DummyObject):
167
212
  requires_backends(cls, ["torch"])
168
213
 
169
214
 
215
+ class PixArtTransformer2DModel(metaclass=DummyObject):
216
+ _backends = ["torch"]
217
+
218
+ def __init__(self, *args, **kwargs):
219
+ requires_backends(self, ["torch"])
220
+
221
+ @classmethod
222
+ def from_config(cls, *args, **kwargs):
223
+ requires_backends(cls, ["torch"])
224
+
225
+ @classmethod
226
+ def from_pretrained(cls, *args, **kwargs):
227
+ requires_backends(cls, ["torch"])
228
+
229
+
170
230
  class PriorTransformer(metaclass=DummyObject):
171
231
  _backends = ["torch"]
172
232
 
@@ -287,6 +347,21 @@ class UNet3DConditionModel(metaclass=DummyObject):
287
347
  requires_backends(cls, ["torch"])
288
348
 
289
349
 
350
+ class UNetControlNetXSModel(metaclass=DummyObject):
351
+ _backends = ["torch"]
352
+
353
+ def __init__(self, *args, **kwargs):
354
+ requires_backends(self, ["torch"])
355
+
356
+ @classmethod
357
+ def from_config(cls, *args, **kwargs):
358
+ requires_backends(cls, ["torch"])
359
+
360
+ @classmethod
361
+ def from_pretrained(cls, *args, **kwargs):
362
+ requires_backends(cls, ["torch"])
363
+
364
+
290
365
  class UNetMotionModel(metaclass=DummyObject):
291
366
  _backends = ["torch"]
292
367
 
@@ -92,6 +92,21 @@ class AnimateDiffPipeline(metaclass=DummyObject):
92
92
  requires_backends(cls, ["torch", "transformers"])
93
93
 
94
94
 
95
+ class AnimateDiffSDXLPipeline(metaclass=DummyObject):
96
+ _backends = ["torch", "transformers"]
97
+
98
+ def __init__(self, *args, **kwargs):
99
+ requires_backends(self, ["torch", "transformers"])
100
+
101
+ @classmethod
102
+ def from_config(cls, *args, **kwargs):
103
+ requires_backends(cls, ["torch", "transformers"])
104
+
105
+ @classmethod
106
+ def from_pretrained(cls, *args, **kwargs):
107
+ requires_backends(cls, ["torch", "transformers"])
108
+
109
+
95
110
  class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
96
111
  _backends = ["torch", "transformers"]
97
112
 
@@ -197,6 +212,21 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
197
212
  requires_backends(cls, ["torch", "transformers"])
198
213
 
199
214
 
215
+ class HunyuanDiTPipeline(metaclass=DummyObject):
216
+ _backends = ["torch", "transformers"]
217
+
218
+ def __init__(self, *args, **kwargs):
219
+ requires_backends(self, ["torch", "transformers"])
220
+
221
+ @classmethod
222
+ def from_config(cls, *args, **kwargs):
223
+ requires_backends(cls, ["torch", "transformers"])
224
+
225
+ @classmethod
226
+ def from_pretrained(cls, *args, **kwargs):
227
+ requires_backends(cls, ["torch", "transformers"])
228
+
229
+
200
230
  class I2VGenXLPipeline(metaclass=DummyObject):
201
231
  _backends = ["torch", "transformers"]
202
232
 
@@ -677,6 +707,36 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
677
707
  requires_backends(cls, ["torch", "transformers"])
678
708
 
679
709
 
710
+ class MarigoldDepthPipeline(metaclass=DummyObject):
711
+ _backends = ["torch", "transformers"]
712
+
713
+ def __init__(self, *args, **kwargs):
714
+ requires_backends(self, ["torch", "transformers"])
715
+
716
+ @classmethod
717
+ def from_config(cls, *args, **kwargs):
718
+ requires_backends(cls, ["torch", "transformers"])
719
+
720
+ @classmethod
721
+ def from_pretrained(cls, *args, **kwargs):
722
+ requires_backends(cls, ["torch", "transformers"])
723
+
724
+
725
+ class MarigoldNormalsPipeline(metaclass=DummyObject):
726
+ _backends = ["torch", "transformers"]
727
+
728
+ def __init__(self, *args, **kwargs):
729
+ requires_backends(self, ["torch", "transformers"])
730
+
731
+ @classmethod
732
+ def from_config(cls, *args, **kwargs):
733
+ requires_backends(cls, ["torch", "transformers"])
734
+
735
+ @classmethod
736
+ def from_pretrained(cls, *args, **kwargs):
737
+ requires_backends(cls, ["torch", "transformers"])
738
+
739
+
680
740
  class MusicLDMPipeline(metaclass=DummyObject):
681
741
  _backends = ["torch", "transformers"]
682
742
 
@@ -737,6 +797,21 @@ class PixArtAlphaPipeline(metaclass=DummyObject):
737
797
  requires_backends(cls, ["torch", "transformers"])
738
798
 
739
799
 
800
+ class PixArtSigmaPipeline(metaclass=DummyObject):
801
+ _backends = ["torch", "transformers"]
802
+
803
+ def __init__(self, *args, **kwargs):
804
+ requires_backends(self, ["torch", "transformers"])
805
+
806
+ @classmethod
807
+ def from_config(cls, *args, **kwargs):
808
+ requires_backends(cls, ["torch", "transformers"])
809
+
810
+ @classmethod
811
+ def from_pretrained(cls, *args, **kwargs):
812
+ requires_backends(cls, ["torch", "transformers"])
813
+
814
+
740
815
  class SemanticStableDiffusionPipeline(metaclass=DummyObject):
741
816
  _backends = ["torch", "transformers"]
742
817
 
@@ -902,6 +977,21 @@ class StableDiffusionControlNetPipeline(metaclass=DummyObject):
902
977
  requires_backends(cls, ["torch", "transformers"])
903
978
 
904
979
 
980
+ class StableDiffusionControlNetXSPipeline(metaclass=DummyObject):
981
+ _backends = ["torch", "transformers"]
982
+
983
+ def __init__(self, *args, **kwargs):
984
+ requires_backends(self, ["torch", "transformers"])
985
+
986
+ @classmethod
987
+ def from_config(cls, *args, **kwargs):
988
+ requires_backends(cls, ["torch", "transformers"])
989
+
990
+ @classmethod
991
+ def from_pretrained(cls, *args, **kwargs):
992
+ requires_backends(cls, ["torch", "transformers"])
993
+
994
+
905
995
  class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject):
906
996
  _backends = ["torch", "transformers"]
907
997
 
@@ -1247,6 +1337,21 @@ class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
1247
1337
  requires_backends(cls, ["torch", "transformers"])
1248
1338
 
1249
1339
 
1340
+ class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject):
1341
+ _backends = ["torch", "transformers"]
1342
+
1343
+ def __init__(self, *args, **kwargs):
1344
+ requires_backends(self, ["torch", "transformers"])
1345
+
1346
+ @classmethod
1347
+ def from_config(cls, *args, **kwargs):
1348
+ requires_backends(cls, ["torch", "transformers"])
1349
+
1350
+ @classmethod
1351
+ def from_pretrained(cls, *args, **kwargs):
1352
+ requires_backends(cls, ["torch", "transformers"])
1353
+
1354
+
1250
1355
  class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
1251
1356
  _backends = ["torch", "transformers"]
1252
1357
 
@@ -201,7 +201,7 @@ def get_cached_module_file(
201
201
  module_file: str,
202
202
  cache_dir: Optional[Union[str, os.PathLike]] = None,
203
203
  force_download: bool = False,
204
- resume_download: bool = False,
204
+ resume_download: Optional[bool] = None,
205
205
  proxies: Optional[Dict[str, str]] = None,
206
206
  token: Optional[Union[bool, str]] = None,
207
207
  revision: Optional[str] = None,
@@ -228,9 +228,9 @@ def get_cached_module_file(
228
228
  cache should not be used.
229
229
  force_download (`bool`, *optional*, defaults to `False`):
230
230
  Whether or not to force to (re-)download the configuration files and override the cached versions if they
231
- exist.
232
- resume_download (`bool`, *optional*, defaults to `False`):
233
- Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
231
+ exist. resume_download:
232
+ Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
233
+ of Diffusers.
234
234
  proxies (`Dict[str, str]`, *optional*):
235
235
  A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
236
236
  'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
@@ -246,8 +246,8 @@ def get_cached_module_file(
246
246
 
247
247
  <Tip>
248
248
 
249
- You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
250
- or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
249
+ You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
250
+ [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
251
251
 
252
252
  </Tip>
253
253
 
@@ -329,6 +329,11 @@ def get_cached_module_file(
329
329
  # The only reason we do the copy is to avoid putting too many folders in sys.path.
330
330
  shutil.copy(resolved_module_file, submodule_path / module_file)
331
331
  for module_needed in modules_needed:
332
+ if len(module_needed.split(".")) == 2:
333
+ module_needed = "/".join(module_needed.split("."))
334
+ module_folder = module_needed.split("/")[0]
335
+ if not os.path.exists(submodule_path / module_folder):
336
+ os.makedirs(submodule_path / module_folder)
332
337
  module_needed = f"{module_needed}.py"
333
338
  shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
334
339
  else:
@@ -343,9 +348,16 @@ def get_cached_module_file(
343
348
  create_dynamic_module(full_submodule)
344
349
 
345
350
  if not (submodule_path / module_file).exists():
351
+ if len(module_file.split("/")) == 2:
352
+ module_folder = module_file.split("/")[0]
353
+ if not os.path.exists(submodule_path / module_folder):
354
+ os.makedirs(submodule_path / module_folder)
346
355
  shutil.copy(resolved_module_file, submodule_path / module_file)
356
+
347
357
  # Make sure we also have every file with relative
348
358
  for module_needed in modules_needed:
359
+ if len(module_needed.split(".")) == 2:
360
+ module_needed = "/".join(module_needed.split("."))
349
361
  if not (submodule_path / module_needed).exists():
350
362
  get_cached_module_file(
351
363
  pretrained_model_name_or_path,
@@ -368,7 +380,7 @@ def get_class_from_dynamic_module(
368
380
  class_name: Optional[str] = None,
369
381
  cache_dir: Optional[Union[str, os.PathLike]] = None,
370
382
  force_download: bool = False,
371
- resume_download: bool = False,
383
+ resume_download: Optional[bool] = None,
372
384
  proxies: Optional[Dict[str, str]] = None,
373
385
  token: Optional[Union[bool, str]] = None,
374
386
  revision: Optional[str] = None,
@@ -405,8 +417,9 @@ def get_class_from_dynamic_module(
405
417
  force_download (`bool`, *optional*, defaults to `False`):
406
418
  Whether or not to force to (re-)download the configuration files and override the cached versions if they
407
419
  exist.
408
- resume_download (`bool`, *optional*, defaults to `False`):
409
- Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
420
+ resume_download:
421
+ Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 of
422
+ Diffusers.
410
423
  proxies (`Dict[str, str]`, *optional*):
411
424
  A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
412
425
  'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
@@ -422,8 +435,8 @@ def get_class_from_dynamic_module(
422
435
 
423
436
  <Tip>
424
437
 
425
- You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
426
- or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
438
+ You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
439
+ [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
427
440
 
428
441
  </Tip>
429
442
 
@@ -112,7 +112,8 @@ def load_or_create_model_card(
112
112
  repo_id_or_path (`str`):
113
113
  The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card.
114
114
  token (`str`, *optional*):
115
- Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
115
+ Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
116
+ details.
116
117
  is_pipeline (`bool`):
117
118
  Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
118
119
  from_training: (`bool`): Boolean flag to denote if the model card is being created from a training script.
@@ -282,7 +283,7 @@ def _get_model_file(
282
283
  cache_dir: Optional[str] = None,
283
284
  force_download: bool = False,
284
285
  proxies: Optional[Dict] = None,
285
- resume_download: bool = False,
286
+ resume_download: Optional[bool] = None,
286
287
  local_files_only: bool = False,
287
288
  token: Optional[str] = None,
288
289
  user_agent: Optional[Union[Dict, str]] = None,
@@ -295,6 +295,46 @@ try:
295
295
  except importlib_metadata.PackageNotFoundError:
296
296
  _torchvision_available = False
297
297
 
298
+ _matplotlib_available = importlib.util.find_spec("matplotlib") is not None
299
+ try:
300
+ _matplotlib_version = importlib_metadata.version("matplotlib")
301
+ logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
302
+ except importlib_metadata.PackageNotFoundError:
303
+ _matplotlib_available = False
304
+
305
+ _timm_available = importlib.util.find_spec("timm") is not None
306
+ if _timm_available:
307
+ try:
308
+ _timm_version = importlib_metadata.version("timm")
309
+ logger.info(f"Timm version {_timm_version} available.")
310
+ except importlib_metadata.PackageNotFoundError:
311
+ _timm_available = False
312
+
313
+
314
+ def is_timm_available():
315
+ return _timm_available
316
+
317
+
318
+ _bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
319
+ try:
320
+ _bitsandbytes_version = importlib_metadata.version("bitsandbytes")
321
+ logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
322
+ except importlib_metadata.PackageNotFoundError:
323
+ _bitsandbytes_available = False
324
+
325
+ # Taken from `huggingface_hub`.
326
+ _is_notebook = False
327
+ try:
328
+ shell_class = get_ipython().__class__ # type: ignore # noqa: F821
329
+ for parent_class in shell_class.__mro__: # e.g. "is subclass of"
330
+ if parent_class.__name__ == "ZMQInteractiveShell":
331
+ _is_notebook = True # Jupyter notebook, Google colab or qtconsole
332
+ break
333
+ except NameError:
334
+ pass # Probably standard Python interpreter
335
+
336
+ _is_google_colab = "google.colab" in sys.modules
337
+
298
338
 
299
339
  def is_torch_available():
300
340
  return _torch_available
@@ -392,6 +432,26 @@ def is_torchvision_available():
392
432
  return _torchvision_available
393
433
 
394
434
 
435
+ def is_matplotlib_available():
436
+ return _matplotlib_available
437
+
438
+
439
+ def is_safetensors_available():
440
+ return _safetensors_available
441
+
442
+
443
+ def is_bitsandbytes_available():
444
+ return _bitsandbytes_available
445
+
446
+
447
+ def is_notebook():
448
+ return _is_notebook
449
+
450
+
451
+ def is_google_colab():
452
+ return _is_google_colab
453
+
454
+
395
455
  # docstyle-ignore
396
456
  FLAX_IMPORT_ERROR = """
397
457
  {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -499,6 +559,20 @@ INVISIBLE_WATERMARK_IMPORT_ERROR = """
499
559
  {0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0`
500
560
  """
501
561
 
562
+ # docstyle-ignore
563
+ PEFT_IMPORT_ERROR = """
564
+ {0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft`
565
+ """
566
+
567
+ # docstyle-ignore
568
+ SAFETENSORS_IMPORT_ERROR = """
569
+ {0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors`
570
+ """
571
+
572
+ # docstyle-ignore
573
+ BITSANDBYTES_IMPORT_ERROR = """
574
+ {0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes`
575
+ """
502
576
 
503
577
  BACKENDS_MAPPING = OrderedDict(
504
578
  [
@@ -520,6 +594,9 @@ BACKENDS_MAPPING = OrderedDict(
520
594
  ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
521
595
  ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)),
522
596
  ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)),
597
+ ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
598
+ ("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
599
+ ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
523
600
  ]
524
601
  )
525
602
 
@@ -628,6 +705,20 @@ def is_accelerate_version(operation: str, version: str):
628
705
  return compare_versions(parse(_accelerate_version), operation, version)
629
706
 
630
707
 
708
+ def is_peft_version(operation: str, version: str):
709
+ """
710
+ Args:
711
+ Compares the current PEFT version to a given reference with an operation.
712
+ operation (`str`):
713
+ A string representation of an operator, such as `">"` or `"<="`
714
+ version (`str`):
715
+ A version string
716
+ """
717
+ if not _peft_version:
718
+ return False
719
+ return compare_versions(parse(_peft_version), operation, version)
720
+
721
+
631
722
  def is_k_diffusion_version(operation: str, version: str):
632
723
  """
633
724
  Args:
@@ -16,8 +16,8 @@ def load_image(
16
16
  image (`str` or `PIL.Image.Image`):
17
17
  The image to convert to the PIL Image format.
18
18
  convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
19
- A conversion method to apply to the image after loading it.
20
- When set to `None` the image will be converted "RGB".
19
+ A conversion method to apply to the image after loading it. When set to `None` the image will be converted
20
+ "RGB".
21
21
 
22
22
  Returns:
23
23
  `PIL.Image.Image`:
@@ -12,7 +12,7 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
- """ Logging utilities."""
15
+ """Logging utilities."""
16
16
 
17
17
  import logging
18
18
  import os
@@ -14,6 +14,7 @@
14
14
  """
15
15
  PEFT utilities: Utilities related to peft library
16
16
  """
17
+
17
18
  import collections
18
19
  import importlib
19
20
  from typing import Optional
@@ -63,9 +64,11 @@ def recurse_remove_peft_layers(model):
63
64
  module_replaced = False
64
65
 
65
66
  if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
66
- new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
67
- module.weight.device
68
- )
67
+ new_module = torch.nn.Linear(
68
+ module.in_features,
69
+ module.out_features,
70
+ bias=module.bias is not None,
71
+ ).to(module.weight.device)
69
72
  new_module.weight = module.weight
70
73
  if module.bias is not None:
71
74
  new_module.bias = module.bias
@@ -109,6 +112,9 @@ def scale_lora_layers(model, weight):
109
112
  """
110
113
  from peft.tuners.tuners_utils import BaseTunerLayer
111
114
 
115
+ if weight == 1.0:
116
+ return
117
+
112
118
  for module in model.modules():
113
119
  if isinstance(module, BaseTunerLayer):
114
120
  module.scale_layer(weight)
@@ -128,6 +134,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
128
134
  """
129
135
  from peft.tuners.tuners_utils import BaseTunerLayer
130
136
 
137
+ if weight == 1.0:
138
+ return
139
+
131
140
  for module in model.modules():
132
141
  if isinstance(module, BaseTunerLayer):
133
142
  if weight is not None and weight != 0:
@@ -170,6 +179,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
170
179
 
171
180
  # layer names without the Diffusers specific
172
181
  target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
182
+ use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
173
183
 
174
184
  lora_config_kwargs = {
175
185
  "r": r,
@@ -177,6 +187,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
177
187
  "rank_pattern": rank_pattern,
178
188
  "alpha_pattern": alpha_pattern,
179
189
  "target_modules": target_modules,
190
+ "use_dora": use_dora,
180
191
  }
181
192
  return lora_config_kwargs
182
193
 
@@ -227,16 +238,32 @@ def delete_adapter_layers(model, adapter_name):
227
238
  def set_weights_and_activate_adapters(model, adapter_names, weights):
228
239
  from peft.tuners.tuners_utils import BaseTunerLayer
229
240
 
241
+ def get_module_weight(weight_for_adapter, module_name):
242
+ if not isinstance(weight_for_adapter, dict):
243
+ # If weight_for_adapter is a single number, always return it.
244
+ return weight_for_adapter
245
+
246
+ for layer_name, weight_ in weight_for_adapter.items():
247
+ if layer_name in module_name:
248
+ return weight_
249
+
250
+ parts = module_name.split(".")
251
+ # e.g. key = "down_blocks.1.attentions.0"
252
+ key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
253
+ block_weight = weight_for_adapter.get(key, 1.0)
254
+
255
+ return block_weight
256
+
230
257
  # iterate over each adapter, make it active and set the corresponding scaling weight
231
258
  for adapter_name, weight in zip(adapter_names, weights):
232
- for module in model.modules():
259
+ for module_name, module in model.named_modules():
233
260
  if isinstance(module, BaseTunerLayer):
234
261
  # For backward compatbility with previous PEFT versions
235
262
  if hasattr(module, "set_adapter"):
236
263
  module.set_adapter(adapter_name)
237
264
  else:
238
265
  module.active_adapter = adapter_name
239
- module.set_scale(adapter_name, weight)
266
+ module.set_scale(adapter_name, get_module_weight(weight, module_name))
240
267
 
241
268
  # set multiple active adapters
242
269
  for module in model.modules():
@@ -14,6 +14,7 @@
14
14
  """
15
15
  State dict utilities: utility methods for converting state dicts easily
16
16
  """
17
+
17
18
  import enum
18
19
 
19
20
  from .logging import get_logger
@@ -46,6 +47,7 @@ UNET_TO_DIFFUSERS = {
46
47
  ".to_v_lora.up": ".to_v.lora_B",
47
48
  ".lora.up": ".lora_B",
48
49
  ".lora.down": ".lora_A",
50
+ ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
49
51
  }
50
52
 
51
53
 
@@ -103,6 +105,10 @@ DIFFUSERS_OLD_TO_DIFFUSERS = {
103
105
  ".to_v_lora.down": ".v_proj.lora_linear_layer.down",
104
106
  ".to_out_lora.up": ".out_proj.lora_linear_layer.up",
105
107
  ".to_out_lora.down": ".out_proj.lora_linear_layer.down",
108
+ ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector",
109
+ ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector",
110
+ ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector",
111
+ ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector",
106
112
  }
107
113
 
108
114
  PEFT_TO_KOHYA_SS = {
@@ -247,8 +253,8 @@ def convert_unet_state_dict_to_peft(state_dict):
247
253
 
248
254
  def convert_all_state_dict_to_peft(state_dict):
249
255
  r"""
250
- Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer`
251
- for a valid `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
256
+ Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
257
+ `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
252
258
  """
253
259
  try:
254
260
  peft_dict = convert_state_dict_to_peft(state_dict)
@@ -314,6 +320,9 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
314
320
  kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
315
321
  elif "unet" in kohya_key:
316
322
  kohya_key = kohya_key.replace("unet", "lora_unet")
323
+ elif "lora_magnitude_vector" in kohya_key:
324
+ kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")
325
+
317
326
  kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
318
327
  kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
319
328
  kohya_ss_state_dict[kohya_key] = weight