diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/peft.py CHANGED
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
  import inspect
16
+ import json
16
17
  import os
17
18
  from functools import partial
18
19
  from pathlib import Path
@@ -28,13 +29,13 @@ from ..utils import (
28
29
  convert_unet_state_dict_to_peft,
29
30
  delete_adapter_layers,
30
31
  get_adapter_name,
31
- get_peft_kwargs,
32
32
  is_peft_available,
33
33
  is_peft_version,
34
34
  logging,
35
35
  set_adapter_layers,
36
36
  set_weights_and_activate_adapters,
37
37
  )
38
+ from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
38
39
  from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
39
40
  from .unet_loader_utils import _maybe_expand_lora_scales
40
41
 
@@ -52,32 +53,18 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
52
53
  "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
53
54
  "LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
54
55
  "SanaTransformer2DModel": lambda model_cls, weights: weights,
56
+ "AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
55
57
  "Lumina2Transformer2DModel": lambda model_cls, weights: weights,
56
58
  "WanTransformer3DModel": lambda model_cls, weights: weights,
57
59
  "CogView4Transformer2DModel": lambda model_cls, weights: weights,
60
+ "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
61
+ "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
62
+ "WanVACETransformer3DModel": lambda model_cls, weights: weights,
63
+ "ChromaTransformer2DModel": lambda model_cls, weights: weights,
64
+ "QwenImageTransformer2DModel": lambda model_cls, weights: weights,
58
65
  }
59
66
 
60
67
 
61
- def _maybe_raise_error_for_ambiguity(config):
62
- rank_pattern = config["rank_pattern"].copy()
63
- target_modules = config["target_modules"]
64
-
65
- for key in list(rank_pattern.keys()):
66
- # try to detect ambiguity
67
- # `target_modules` can also be a str, in which case this loop would loop
68
- # over the chars of the str. The technically correct way to match LoRA keys
69
- # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
70
- # But this cuts it for now.
71
- exact_matches = [mod for mod in target_modules if mod == key]
72
- substring_matches = [mod for mod in target_modules if key in mod and mod != key]
73
-
74
- if exact_matches and substring_matches:
75
- if is_peft_version("<", "0.14.1"):
76
- raise ValueError(
77
- "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
78
- )
79
-
80
-
81
68
  class PeftAdapterMixin:
82
69
  """
83
70
  A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -99,17 +86,6 @@ class PeftAdapterMixin:
99
86
  @classmethod
100
87
  # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
101
88
  def _optionally_disable_offloading(cls, _pipeline):
102
- """
103
- Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
104
-
105
- Args:
106
- _pipeline (`DiffusionPipeline`):
107
- The pipeline to disable offloading for.
108
-
109
- Returns:
110
- tuple:
111
- A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
112
- """
113
89
  return _func_optionally_disable_offloading(_pipeline=_pipeline)
114
90
 
115
91
  def load_lora_adapter(
@@ -181,10 +157,15 @@ class PeftAdapterMixin:
181
157
  Note that hotswapping adapters of the text encoder is not yet supported. There are some further
182
158
  limitations to this technique, which are documented here:
183
159
  https://huggingface.co/docs/peft/main/en/package_reference/hotswap
160
+ metadata:
161
+ LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
162
+ initialize `LoraConfig`.
184
163
  """
185
- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
164
+ from peft import inject_adapter_in_model, set_peft_model_state_dict
186
165
  from peft.tuners.tuners_utils import BaseTunerLayer
187
166
 
167
+ from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
168
+
188
169
  cache_dir = kwargs.pop("cache_dir", None)
189
170
  force_download = kwargs.pop("force_download", False)
190
171
  proxies = kwargs.pop("proxies", None)
@@ -198,6 +179,7 @@ class PeftAdapterMixin:
198
179
  network_alphas = kwargs.pop("network_alphas", None)
199
180
  _pipeline = kwargs.pop("_pipeline", None)
200
181
  low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
182
+ metadata = kwargs.pop("metadata", None)
201
183
  allow_pickle = False
202
184
 
203
185
  if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
@@ -205,12 +187,8 @@ class PeftAdapterMixin:
205
187
  "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
206
188
  )
207
189
 
208
- user_agent = {
209
- "file_type": "attn_procs_weights",
210
- "framework": "pytorch",
211
- }
212
-
213
- state_dict = _fetch_state_dict(
190
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
191
+ state_dict, metadata = _fetch_state_dict(
214
192
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
215
193
  weight_name=weight_name,
216
194
  use_safetensors=use_safetensors,
@@ -223,12 +201,17 @@ class PeftAdapterMixin:
223
201
  subfolder=subfolder,
224
202
  user_agent=user_agent,
225
203
  allow_pickle=allow_pickle,
204
+ metadata=metadata,
226
205
  )
227
206
  if network_alphas is not None and prefix is None:
228
207
  raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
208
+ if network_alphas and metadata:
209
+ raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")
229
210
 
230
211
  if prefix is not None:
231
- state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
212
+ state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
213
+ if metadata is not None:
214
+ metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
232
215
 
233
216
  if len(state_dict) > 0:
234
217
  if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
@@ -248,7 +231,7 @@ class PeftAdapterMixin:
248
231
 
249
232
  rank = {}
250
233
  for key, val in state_dict.items():
251
- # Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
234
+ # Cannot figure out rank from lora layers that don't have at least 2 dimensions.
252
235
  # Bias layers in LoRA only have a single dimension
253
236
  if "lora_B" in key and val.ndim > 1:
254
237
  # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
@@ -259,44 +242,33 @@ class PeftAdapterMixin:
259
242
 
260
243
  if network_alphas is not None and len(network_alphas) >= 1:
261
244
  alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
262
- network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
263
-
264
- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
265
- _maybe_raise_error_for_ambiguity(lora_config_kwargs)
245
+ network_alphas = {
246
+ k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
247
+ }
266
248
 
267
- if "use_dora" in lora_config_kwargs:
268
- if lora_config_kwargs["use_dora"]:
269
- if is_peft_version("<", "0.9.0"):
270
- raise ValueError(
271
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
272
- )
273
- else:
274
- if is_peft_version("<", "0.9.0"):
275
- lora_config_kwargs.pop("use_dora")
276
-
277
- if "lora_bias" in lora_config_kwargs:
278
- if lora_config_kwargs["lora_bias"]:
279
- if is_peft_version("<=", "0.13.2"):
280
- raise ValueError(
281
- "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
282
- )
283
- else:
284
- if is_peft_version("<=", "0.13.2"):
285
- lora_config_kwargs.pop("lora_bias")
286
-
287
- lora_config = LoraConfig(**lora_config_kwargs)
288
249
  # adapter_name
289
250
  if adapter_name is None:
290
251
  adapter_name = get_adapter_name(self)
291
252
 
253
+ # create LoraConfig
254
+ lora_config = _create_lora_config(
255
+ state_dict,
256
+ network_alphas,
257
+ metadata,
258
+ rank,
259
+ model_state_dict=self.state_dict(),
260
+ adapter_name=adapter_name,
261
+ )
262
+
292
263
  # <Unsafe code
293
264
  # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
294
265
  # Now we remove any existing hooks to `_pipeline`.
295
266
 
296
267
  # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
297
- # otherwise loading LoRA weights will lead to an error
298
- is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
299
-
268
+ # otherwise loading LoRA weights will lead to an error.
269
+ is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
270
+ _pipeline
271
+ )
300
272
  peft_kwargs = {}
301
273
  if is_peft_version(">=", "0.13.1"):
302
274
  peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -328,7 +300,7 @@ class PeftAdapterMixin:
328
300
  new_sd[k] = v
329
301
  return new_sd
330
302
 
331
- # To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
303
+ # To handle scenarios where we cannot successfully set state dict. If it's unsuccessful,
332
304
  # we should also delete the `peft_config` associated to the `adapter_name`.
333
305
  try:
334
306
  if hotswap:
@@ -342,13 +314,15 @@ class PeftAdapterMixin:
342
314
  config=lora_config,
343
315
  )
344
316
  except Exception as e:
345
- logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
317
+ logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
346
318
  raise
347
319
  # the hotswap function raises if there are incompatible keys, so if we reach this point we can set
348
320
  # it to None
349
321
  incompatible_keys = None
350
322
  else:
351
- inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
323
+ inject_adapter_in_model(
324
+ lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
325
+ )
352
326
  incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
353
327
 
354
328
  if self._prepare_lora_hotswap_kwargs is not None:
@@ -377,46 +351,28 @@ class PeftAdapterMixin:
377
351
  module.delete_adapter(adapter_name)
378
352
 
379
353
  self.peft_config.pop(adapter_name)
380
- logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
354
+ logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
381
355
  raise
382
356
 
383
- warn_msg = ""
384
- if incompatible_keys is not None:
385
- # Check only for unexpected keys.
386
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
387
- if unexpected_keys:
388
- lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
389
- if lora_unexpected_keys:
390
- warn_msg = (
391
- f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
392
- f" {', '.join(lora_unexpected_keys)}. "
393
- )
394
-
395
- # Filter missing keys specific to the current adapter.
396
- missing_keys = getattr(incompatible_keys, "missing_keys", None)
397
- if missing_keys:
398
- lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
399
- if lora_missing_keys:
400
- warn_msg += (
401
- f"Loading adapter weights from state_dict led to missing keys in the model:"
402
- f" {', '.join(lora_missing_keys)}."
403
- )
404
-
405
- if warn_msg:
406
- logger.warning(warn_msg)
357
+ _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)
407
358
 
408
359
  # Offload back.
409
360
  if is_model_cpu_offload:
410
361
  _pipeline.enable_model_cpu_offload()
411
362
  elif is_sequential_cpu_offload:
412
363
  _pipeline.enable_sequential_cpu_offload()
364
+ elif is_group_offload:
365
+ for component in _pipeline.components.values():
366
+ if isinstance(component, torch.nn.Module):
367
+ _maybe_remove_and_reapply_group_offloading(component)
413
368
  # Unsafe code />
414
369
 
415
370
  if prefix is not None and not state_dict:
371
+ model_class_name = self.__class__.__name__
416
372
  logger.warning(
417
- f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
373
+ f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
418
374
  "This is safe to ignore if LoRA state dict didn't originally have any "
419
- f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
375
+ f"{model_class_name} related params. You can also try specifying `prefix=None` "
420
376
  "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
421
377
  "https://github.com/huggingface/diffusers/issues/new"
422
378
  )
@@ -439,17 +395,13 @@ class PeftAdapterMixin:
439
395
  underlying model has multiple adapters loaded.
440
396
  upcast_before_saving (`bool`, defaults to `False`):
441
397
  Whether to cast the underlying model to `torch.float32` before serialization.
442
- save_function (`Callable`):
443
- The function to use to save the state dictionary. Useful during distributed training when you need to
444
- replace `torch.save` with another method. Can be configured with the environment variable
445
- `DIFFUSERS_SAVE_MODE`.
446
398
  safe_serialization (`bool`, *optional*, defaults to `True`):
447
399
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
448
400
  weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
449
401
  """
450
402
  from peft.utils import get_peft_model_state_dict
451
403
 
452
- from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
404
+ from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
453
405
 
454
406
  if adapter_name is None:
455
407
  adapter_name = get_adapter_name(self)
@@ -457,6 +409,8 @@ class PeftAdapterMixin:
457
409
  if adapter_name not in getattr(self, "peft_config", {}):
458
410
  raise ValueError(f"Adapter name {adapter_name} not found in the model.")
459
411
 
412
+ lora_adapter_metadata = self.peft_config[adapter_name].to_dict()
413
+
460
414
  lora_layers_to_save = get_peft_model_state_dict(
461
415
  self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
462
416
  )
@@ -466,7 +420,15 @@ class PeftAdapterMixin:
466
420
  if safe_serialization:
467
421
 
468
422
  def save_function(weights, filename):
469
- return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
423
+ # Inject framework format.
424
+ metadata = {"format": "pt"}
425
+ if lora_adapter_metadata is not None:
426
+ for key, value in lora_adapter_metadata.items():
427
+ if isinstance(value, set):
428
+ lora_adapter_metadata[key] = list(value)
429
+ metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
430
+
431
+ return safetensors.torch.save_file(weights, filename, metadata=metadata)
470
432
 
471
433
  else:
472
434
  save_function = torch.save
@@ -479,7 +441,6 @@ class PeftAdapterMixin:
479
441
  else:
480
442
  weight_name = LORA_WEIGHT_NAME
481
443
 
482
- # TODO: we could consider saving the `peft_config` as well.
483
444
  save_path = Path(save_directory, weight_name).as_posix()
484
445
  save_function(lora_layers_to_save, save_path)
485
446
  logger.info(f"Model weights saved in {save_path}")
@@ -490,7 +451,7 @@ class PeftAdapterMixin:
490
451
  weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
491
452
  ):
492
453
  """
493
- Set the currently active adapters for use in the UNet.
454
+ Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).
494
455
 
495
456
  Args:
496
457
  adapter_names (`List[str]` or `str`):
@@ -512,7 +473,7 @@ class PeftAdapterMixin:
512
473
  "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
513
474
  )
514
475
  pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
515
- pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
476
+ pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
516
477
  ```
517
478
  """
518
479
  if not USE_PEFT_BACKEND:
@@ -710,7 +671,7 @@ class PeftAdapterMixin:
710
671
  if self.lora_scale != 1.0:
711
672
  module.scale_layer(self.lora_scale)
712
673
 
713
- # For BC with prevous PEFT versions, we need to check the signature
674
+ # For BC with previous PEFT versions, we need to check the signature
714
675
  # of the `merge` method to see if it supports the `adapter_names` argument.
715
676
  supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
716
677
  if "adapter_names" in supported_merge_kwargs:
@@ -738,11 +699,16 @@ class PeftAdapterMixin:
738
699
  if not USE_PEFT_BACKEND:
739
700
  raise ValueError("PEFT backend is required for `unload_lora()`.")
740
701
 
702
+ from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
741
703
  from ..utils import recurse_remove_peft_layers
742
704
 
743
705
  recurse_remove_peft_layers(self)
744
706
  if hasattr(self, "peft_config"):
745
707
  del self.peft_config
708
+ if hasattr(self, "_hf_peft_config_loaded"):
709
+ self._hf_peft_config_loaded = None
710
+
711
+ _maybe_remove_and_reapply_group_offloading(self)
746
712
 
747
713
  def disable_lora(self):
748
714
  """
@@ -760,7 +726,7 @@ class PeftAdapterMixin:
760
726
  pipeline.load_lora_weights(
761
727
  "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
762
728
  )
763
- pipeline.disable_lora()
729
+ pipeline.unet.disable_lora()
764
730
  ```
765
731
  """
766
732
  if not USE_PEFT_BACKEND:
@@ -783,7 +749,7 @@ class PeftAdapterMixin:
783
749
  pipeline.load_lora_weights(
784
750
  "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
785
751
  )
786
- pipeline.enable_lora()
752
+ pipeline.unet.enable_lora()
787
753
  ```
788
754
  """
789
755
  if not USE_PEFT_BACKEND:
@@ -810,7 +776,7 @@ class PeftAdapterMixin:
810
776
  pipeline.load_lora_weights(
811
777
  "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
812
778
  )
813
- pipeline.delete_adapters("cinematic")
779
+ pipeline.unet.delete_adapters("cinematic")
814
780
  ```
815
781
  """
816
782
  if not USE_PEFT_BACKEND:
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -453,7 +453,7 @@ class FromSingleFileMixin:
453
453
  logger.warning(
454
454
  "Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
455
455
  "This may lead to errors if the model components are not correctly inferred. \n"
456
- "To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
456
+ "To avoid this warning, please explicitly pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
457
457
  "e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
458
458
  "or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
459
459
  "the necessary config files.\n"
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -21,15 +21,20 @@ import torch
21
21
  from huggingface_hub.utils import validate_hf_hub_args
22
22
  from typing_extensions import Self
23
23
 
24
+ from .. import __version__
24
25
  from ..quantizers import DiffusersAutoQuantizer
25
- from ..utils import deprecate, is_accelerate_available, logging
26
+ from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
27
+ from ..utils.torch_utils import empty_device_cache
26
28
  from .single_file_utils import (
27
29
  SingleFileComponentError,
28
30
  convert_animatediff_checkpoint_to_diffusers,
29
31
  convert_auraflow_transformer_checkpoint_to_diffusers,
30
32
  convert_autoencoder_dc_checkpoint_to_diffusers,
33
+ convert_chroma_transformer_checkpoint_to_diffusers,
31
34
  convert_controlnet_checkpoint,
35
+ convert_cosmos_transformer_checkpoint_to_diffusers,
32
36
  convert_flux_transformer_checkpoint_to_diffusers,
37
+ convert_hidream_transformer_to_diffusers,
33
38
  convert_hunyuan_video_transformer_to_diffusers,
34
39
  convert_ldm_unet_checkpoint,
35
40
  convert_ldm_vae_checkpoint,
@@ -57,8 +62,12 @@ logger = logging.get_logger(__name__)
57
62
  if is_accelerate_available():
58
63
  from accelerate import dispatch_model, init_empty_weights
59
64
 
60
- from ..models.modeling_utils import load_model_dict_into_meta
65
+ from ..models.model_loading_utils import load_model_dict_into_meta
61
66
 
67
+ if is_torch_version(">=", "1.9.0") and is_accelerate_available():
68
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
69
+ else:
70
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
62
71
 
63
72
  SINGLE_FILE_LOADABLE_CLASSES = {
64
73
  "StableCascadeUNet": {
@@ -95,6 +104,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
95
104
  "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
96
105
  "default_subfolder": "transformer",
97
106
  },
107
+ "ChromaTransformer2DModel": {
108
+ "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
109
+ "default_subfolder": "transformer",
110
+ },
98
111
  "LTXVideoTransformer3DModel": {
99
112
  "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
100
113
  "default_subfolder": "transformer",
@@ -128,13 +141,33 @@ SINGLE_FILE_LOADABLE_CLASSES = {
128
141
  "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
129
142
  "default_subfolder": "transformer",
130
143
  },
144
+ "WanVACETransformer3DModel": {
145
+ "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
146
+ "default_subfolder": "transformer",
147
+ },
131
148
  "AutoencoderKLWan": {
132
149
  "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
133
150
  "default_subfolder": "vae",
134
151
  },
152
+ "HiDreamImageTransformer2DModel": {
153
+ "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
154
+ "default_subfolder": "transformer",
155
+ },
156
+ "CosmosTransformer3DModel": {
157
+ "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
158
+ "default_subfolder": "transformer",
159
+ },
160
+ "QwenImageTransformer2DModel": {
161
+ "checkpoint_mapping_fn": lambda x: x,
162
+ "default_subfolder": "transformer",
163
+ },
135
164
  }
136
165
 
137
166
 
167
+ def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
168
+ return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
169
+
170
+
138
171
  def _get_single_file_loadable_mapping_class(cls):
139
172
  diffusers_module = importlib.import_module(__name__.split(".")[0])
140
173
  for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -186,9 +219,8 @@ class FromOriginalModelMixin:
186
219
  original_config (`str`, *optional*):
187
220
  Dict or path to a yaml file containing the configuration for the model in its original format.
188
221
  If a dict is provided, it will be used to initialize the model configuration.
189
- torch_dtype (`str` or `torch.dtype`, *optional*):
190
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
191
- dtype is automatically derived from the model's weights.
222
+ torch_dtype (`torch.dtype`, *optional*):
223
+ Override the default `torch.dtype` and load the model with another dtype.
192
224
  force_download (`bool`, *optional*, defaults to `False`):
193
225
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
194
226
  cached versions if they exist.
@@ -208,6 +240,11 @@ class FromOriginalModelMixin:
208
240
  revision (`str`, *optional*, defaults to `"main"`):
209
241
  The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
210
242
  allowed by Git.
243
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 and
244
+ is_accelerate_available() else `False`): Speed up model loading only loading the pretrained weights and
245
+ not initializing the weights. This also tries to not use more than 1x model size in CPU memory
246
+ (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using
247
+ an older version of PyTorch, setting this argument to `True` will raise an error.
211
248
  disable_mmap ('bool', *optional*, defaults to 'False'):
212
249
  Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
213
250
  is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -257,9 +294,15 @@ class FromOriginalModelMixin:
257
294
  config_revision = kwargs.pop("config_revision", None)
258
295
  torch_dtype = kwargs.pop("torch_dtype", None)
259
296
  quantization_config = kwargs.pop("quantization_config", None)
297
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
260
298
  device = kwargs.pop("device", None)
261
299
  disable_mmap = kwargs.pop("disable_mmap", False)
262
300
 
301
+ user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
302
+ # In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`
303
+ if quantization_config is not None:
304
+ user_agent["quant"] = quantization_config.quant_method.value
305
+
263
306
  if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
264
307
  torch_dtype = torch.float32
265
308
  logger.warning(
@@ -278,6 +321,7 @@ class FromOriginalModelMixin:
278
321
  local_files_only=local_files_only,
279
322
  revision=revision,
280
323
  disable_mmap=disable_mmap,
324
+ user_agent=user_agent,
281
325
  )
282
326
  if quantization_config is not None:
283
327
  hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
@@ -355,19 +399,23 @@ class FromOriginalModelMixin:
355
399
  model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
356
400
  diffusers_model_config.update(model_kwargs)
357
401
 
402
+ ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
403
+ with ctx():
404
+ model = cls.from_config(diffusers_model_config)
405
+
358
406
  checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
359
- diffusers_format_checkpoint = checkpoint_mapping_fn(
360
- config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
361
- )
407
+
408
+ if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
409
+ diffusers_format_checkpoint = checkpoint_mapping_fn(
410
+ config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
411
+ )
412
+ else:
413
+ diffusers_format_checkpoint = checkpoint
414
+
362
415
  if not diffusers_format_checkpoint:
363
416
  raise SingleFileComponentError(
364
417
  f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
365
418
  )
366
-
367
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
368
- with ctx():
369
- model = cls.from_config(diffusers_model_config)
370
-
371
419
  # Check if `_keep_in_fp32_modules` is not None
372
420
  use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
373
421
  (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -389,7 +437,7 @@ class FromOriginalModelMixin:
389
437
  )
390
438
 
391
439
  device_map = None
392
- if is_accelerate_available():
440
+ if low_cpu_mem_usage:
393
441
  param_device = torch.device(device) if device else torch.device("cpu")
394
442
  empty_state_dict = model.state_dict()
395
443
  unexpected_keys = [
@@ -405,6 +453,7 @@ class FromOriginalModelMixin:
405
453
  keep_in_fp32_modules=keep_in_fp32_modules,
406
454
  unexpected_keys=unexpected_keys,
407
455
  )
456
+ empty_device_cache()
408
457
  else:
409
458
  _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
410
459