diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/models/controlnets/controlnet_union.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -752,7 +752,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
  condition = self.controlnet_cond_embedding(cond)
  feat_seq = torch.mean(condition, dim=(2, 3))
  feat_seq = feat_seq + self.task_embedding[control_idx]
- if from_multi:
+ if from_multi or len(control_type_idx) == 1:
  inputs.append(feat_seq.unsqueeze(1))
  condition_list.append(condition)
  else:
@@ -772,7 +772,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
  for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
  alpha = self.spatial_ch_projs(x[:, idx])
  alpha = alpha.unsqueeze(-1).unsqueeze(-1)
- if from_multi:
+ if from_multi or len(control_type_idx) == 1:
  controlnet_cond_fuser += condition + alpha
  else:
  controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
  # 6. scaling
  if guess_mode and not self.config.global_pool_conditions:
  scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
- if from_multi:
+ if from_multi or len(control_type_idx) == 1:
  scales = scales * conditioning_scale[0]
  down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
  mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
- elif from_multi:
+ elif from_multi or len(control_type_idx) == 1:
  down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
  mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]

diffusers/models/controlnets/controlnet_xs.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -734,17 +734,17 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
  unet (`UNet2DConditionModel`):
  The UNet model we want to control.
  controlnet (`ControlNetXSAdapter`):
- The ConntrolNet-XS adapter with which the UNet will be fused. If none is given, a new ConntrolNet-XS
+ The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
  adapter will be created.
  size_ratio (float, *optional*, defaults to `None`):
- Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
+ Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
  ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
- Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
+ Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
  where this parameter is called `block_out_channels`.
  time_embedding_mix (`float`, *optional*, defaults to None):
- Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
+ Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
  ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
- Passed to the `init` of the new controlent if no controlent was given.
+ Passed to the `init` of the new controlnet if no controlnet was given.
  """
  if controlnet is None:
  controlnet = ControlNetXSAdapter.from_unet(
@@ -942,7 +942,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):

  # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
  def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
- r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+ r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.

  The suffixes after the scaling factors represent the stage blocks where they are being applied.

diffusers/models/controlnets/multicontrolnet.py CHANGED
@@ -4,9 +4,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  import torch
  from torch import nn

- from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
- from ...models.modeling_utils import ModelMixin
  from ...utils import logging
+ from ..controlnets.controlnet import ControlNetModel, ControlNetOutput
+ from ..modeling_utils import ModelMixin


  logger = logging.get_logger(__name__)
@@ -130,9 +130,8 @@ class MultiControlNetModel(ModelMixin):
  A path to a *directory* containing model weights saved using
  [`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
  `./my_model_directory/controlnet`.
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
- will be automatically derived from the model's weights.
+ torch_dtype (`torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype.
  output_loading_info(`bool`, *optional*, defaults to `False`):
  Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
  device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):

diffusers/models/controlnets/multicontrolnet_union.py CHANGED
@@ -4,10 +4,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  import torch
  from torch import nn

- from ...models.controlnets.controlnet import ControlNetOutput
- from ...models.controlnets.controlnet_union import ControlNetUnionModel
- from ...models.modeling_utils import ModelMixin
  from ...utils import logging
+ from ..controlnets.controlnet import ControlNetOutput
+ from ..controlnets.controlnet_union import ControlNetUnionModel
+ from ..modeling_utils import ModelMixin


  logger = logging.get_logger(__name__)
@@ -143,9 +143,8 @@ class MultiControlNetUnionModel(ModelMixin):
  A path to a *directory* containing model weights saved using
  [`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g.,
  `./my_model_directory/controlnet`.
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
- will be automatically derived from the model's weights.
+ torch_dtype (`torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype.
  output_loading_info(`bool`, *optional*, defaults to `False`):
  Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
  device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):

diffusers/models/downsampling.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -286,7 +286,7 @@ class KDownsample2D(nn.Module):


  class CogVideoXDownsample3D(nn.Module):
- # Todo: Wait for paper relase.
+ # Todo: Wait for paper release.
  r"""
  A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI

diffusers/models/embeddings.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ def get_timestep_embedding(
  downscale_freq_shift: float = 1,
  scale: float = 1,
  max_period: int = 10000,
- ):
+ ) -> torch.Tensor:
  """
  This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

@@ -97,7 +97,7 @@ def get_3d_sincos_pos_embed(
  The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
  spatial dimensions (height and width).
  temporal_size (`int`):
- The temporal dimension of postional embeddings (number of frames).
+ The temporal dimension of positional embeddings (number of frames).
  spatial_interpolation_scale (`float`, defaults to 1.0):
  Scale factor for spatial grid interpolation.
  temporal_interpolation_scale (`float`, defaults to 1.0):
@@ -169,7 +169,7 @@ def _get_3d_sincos_pos_embed_np(
  The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
  spatial dimensions (height and width).
  temporal_size (`int`):
- The temporal dimension of postional embeddings (number of frames).
+ The temporal dimension of positional embeddings (number of frames).
  spatial_interpolation_scale (`float`, defaults to 1.0):
  Scale factor for spatial grid interpolation.
  temporal_interpolation_scale (`float`, defaults to 1.0):
@@ -319,7 +319,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
  return emb


- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
  """
  This function generates 1D positional embeddings from a grid.

@@ -352,6 +352,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
  emb_cos = torch.cos(out) # (M, D/2)

  emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, embed_dim // 2 :], emb[:, : embed_dim // 2]], dim=1)
+
  return emb
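Note on the new `flip_sin_to_cos` flag: as the added lines show, it only swaps the two halves of the returned embedding so the layout becomes [cos | sin] instead of [sin | cos]. A standalone sketch of that behaviour (the frequency construction below follows the usual sin-cos recipe and is an assumption; only the final flip is taken verbatim from the hunk above):

    import torch

    def sincos_1d(embed_dim: int, pos: torch.Tensor, flip_sin_to_cos: bool = False) -> torch.Tensor:
        # assumed standard construction: D/2 frequencies 1 / 10000**(2i/D)
        omega = torch.arange(embed_dim // 2, dtype=torch.float64) / (embed_dim / 2.0)
        omega = 1.0 / 10000**omega                                 # (D/2,)
        out = torch.outer(pos.reshape(-1).double(), omega)         # (M, D/2)
        emb = torch.cat([torch.sin(out), torch.cos(out)], dim=1)   # (M, D), laid out as [sin | cos]
        if flip_sin_to_cos:
            # the added branch: swap the halves so the layout becomes [cos | sin]
            emb = torch.cat([emb[:, embed_dim // 2 :], emb[:, : embed_dim // 2]], dim=1)
        return emb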
@@ -1149,9 +1154,7 @@ def get_1d_rotary_pos_embed(

  theta = theta * ntk_factor
  freqs = (
- 1.0
- / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
- / linear_factor
+ 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
  ) # [D/2]
  freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
  is_npu = freqs.device.type == "npu"
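The single-line rewrite above is behaviour-preserving for the even head dimensions used with rotary embeddings: `torch.arange(0, dim, 2)` already yields exactly `dim // 2` elements, so the dropped `[: (dim // 2)]` slice selected everything. A quick check:

    import torch

    dim = 64                          # rotary head dims are even in practice
    idx = torch.arange(0, dim, 2)
    assert idx.numel() == dim // 2    # the old slice was a no-op for even dim
    assert torch.equal(idx, idx[: dim // 2])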
@@ -1178,6 +1181,7 @@ def apply_rotary_emb(
  freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
  use_real: bool = True,
  use_real_unbind_dim: int = -1,
+ sequence_dim: int = 2,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
  """
  Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1195,17 +1199,24 @@
  """
  if use_real:
  cos, sin = freqs_cis # [S, D]
- cos = cos[None, None]
- sin = sin[None, None]
+ if sequence_dim == 2:
+ cos = cos[None, None, :, :]
+ sin = sin[None, None, :, :]
+ elif sequence_dim == 1:
+ cos = cos[None, :, None, :]
+ sin = sin[None, :, None, :]
+ else:
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
+
  cos, sin = cos.to(x.device), sin.to(x.device)

  if use_real_unbind_dim == -1:
  # Used for flux, cogvideox, hunyuan-dit
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
  x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
  elif use_real_unbind_dim == -2:
- # Used for Stable Audio, OmniGen and CogView4
- x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
  x_rotated = torch.cat([-x_imag, x_real], dim=-1)
  else:
  raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
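The new `sequence_dim` argument only changes how the cos/sin tables (shaped `[S, D]`) are broadcast against the input: `sequence_dim=2` matches a `[batch, heads, seq, head_dim]` layout (the previous hard-coded `cos[None, None]` behaviour), while `sequence_dim=1` matches `[batch, seq, heads, head_dim]`. A shape-only sketch mirroring the branch above (the helper name is illustrative, not a diffusers API):

    import torch

    def broadcast_freqs(cos: torch.Tensor, sin: torch.Tensor, sequence_dim: int = 2):
        # cos/sin come in as [S, D]
        if sequence_dim == 2:    # x is [B, H, S, D]
            return cos[None, None, :, :], sin[None, None, :, :]
        elif sequence_dim == 1:  # x is [B, S, H, D]
            return cos[None, :, None, :], sin[None, :, None, :]
        raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

    cos, sin = torch.randn(16, 64), torch.randn(16, 64)
    c2, _ = broadcast_freqs(cos, sin, sequence_dim=2)
    c1, _ = broadcast_freqs(cos, sin, sequence_dim=1)
    assert (torch.randn(2, 8, 16, 64) * c2).shape == (2, 8, 16, 64)   # [B, H, S, D]
    assert (torch.randn(2, 16, 8, 64) * c1).shape == (2, 16, 8, 64)   # [B, S, H, D]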
@@ -1240,37 +1251,6 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
  return x


- class FluxPosEmbed(nn.Module):
- # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
- def __init__(self, theta: int, axes_dim: List[int]):
- super().__init__()
- self.theta = theta
- self.axes_dim = axes_dim
-
- def forward(self, ids: torch.Tensor) -> torch.Tensor:
- n_axes = ids.shape[-1]
- cos_out = []
- sin_out = []
- pos = ids.float()
- is_mps = ids.device.type == "mps"
- is_npu = ids.device.type == "npu"
- freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
- for i in range(n_axes):
- cos, sin = get_1d_rotary_pos_embed(
- self.axes_dim[i],
- pos[:, i],
- theta=self.theta,
- repeat_interleave_real=True,
- use_real=True,
- freqs_dtype=freqs_dtype,
- )
- cos_out.append(cos)
- sin_out.append(sin)
- freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
- freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
- return freqs_cos, freqs_sin
-
-
  class TimestepEmbedding(nn.Module):
  def __init__(
  self,
@@ -1327,7 +1307,7 @@ class Timesteps(nn.Module):
  self.downscale_freq_shift = downscale_freq_shift
  self.scale = scale

- def forward(self, timesteps):
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
  t_emb = get_timestep_embedding(
  timesteps,
  self.num_channels,
@@ -1401,7 +1381,7 @@ class ImagePositionalEmbeddings(nn.Module):
  Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
  height and width of the latent space.

- For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
+ For more details, see figure 10 of the dall-e paper: https://huggingface.co/papers/2102.12092

  For VQ-diffusion:

@@ -2621,3 +2601,13 @@ class MultiIPAdapterImageProjection(nn.Module):
  projected_image_embeds.append(image_embed)

  return projected_image_embeds
+
+
+ class FluxPosEmbed(nn.Module):
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
+ deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxPosEmbed
+
+ return FluxPosEmbed(*args, **kwargs)
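For downstream code the practical effect of the shim above is an import-path migration: the old import keeps working (it routes to the new class), but constructing the class from the old location emits a deprecation warning slated for removal in 1.0.0. Based only on what the deprecation message states:

    # deprecated location (still works, warns on construction):
    from diffusers.models.embeddings import FluxPosEmbed

    # new location:
    from diffusers.models.transformers.transformer_flux import FluxPosEmbed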
diffusers/models/embeddings_flax.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -89,7 +89,7 @@ class FlaxTimestepEmbedding(nn.Module):

  class FlaxTimesteps(nn.Module):
  r"""
- Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
+ Wrapper Module for sinusoidal Time step Embeddings as described in https://huggingface.co/papers/2006.11239

  Args:
  dim (`int`, *optional*, defaults to `32`):

diffusers/models/lora.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ if is_transformers_available():
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name


- def text_encoder_attn_modules(text_encoder):
+ def text_encoder_attn_modules(text_encoder: nn.Module):
  attn_modules = []

  if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
@@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder):
  return attn_modules


- def text_encoder_mlp_modules(text_encoder):
+ def text_encoder_mlp_modules(text_encoder: nn.Module):
  mlp_modules = []

  if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):

@@ -14,11 +14,13 @@
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
16
 
17
+ import functools
17
18
  import importlib
18
19
  import inspect
19
20
  import os
20
21
  from array import array
21
- from collections import OrderedDict
22
+ from collections import OrderedDict, defaultdict
23
+ from concurrent.futures import ThreadPoolExecutor, as_completed
22
24
  from pathlib import Path
23
25
  from typing import Dict, List, Optional, Union
24
26
  from zipfile import is_zipfile
@@ -30,6 +32,7 @@ from huggingface_hub.utils import EntryNotFoundError
30
32
 
31
33
  from ..quantizers import DiffusersQuantizer
32
34
  from ..utils import (
35
+ DEFAULT_HF_PARALLEL_LOADING_WORKERS,
33
36
  GGUF_FILE_EXTENSION,
34
37
  SAFE_WEIGHTS_INDEX_NAME,
35
38
  SAFETENSORS_FILE_EXTENSION,
@@ -38,6 +41,7 @@ from ..utils import (
38
41
  _get_model_file,
39
42
  deprecate,
40
43
  is_accelerate_available,
44
+ is_accelerate_version,
41
45
  is_gguf_available,
42
46
  is_torch_available,
43
47
  is_torch_version,
@@ -252,6 +256,10 @@ def load_model_dict_into_meta(
252
256
  param = param.to(dtype)
253
257
  set_module_kwargs["dtype"] = dtype
254
258
 
259
+ if is_accelerate_version(">", "1.8.1"):
260
+ set_module_kwargs["non_blocking"] = True
261
+ set_module_kwargs["clear_cache"] = False
262
+
255
263
  # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
256
264
  # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
257
265
  # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -304,6 +312,161 @@ def load_model_dict_into_meta(
  return offload_index, state_dict_index
 
 
+ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+     """
+     Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+     checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
+     parameters.
+
+     """
+     if model_to_load.device.type == "meta":
+         return False
+
+     if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+         return False
+
+     # Some models explicitly do not support param buffer assignment
+     if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+         logger.debug(
+             f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+         )
+         return False
+
+     # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+     first_key = next(iter(model_to_load.state_dict().keys()))
+     if start_prefix + first_key in state_dict:
+         return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+     return False
+
+
+ def _load_shard_file(
+     shard_file,
+     model,
+     model_state_dict,
+     device_map=None,
+     dtype=None,
+     hf_quantizer=None,
+     keep_in_fp32_modules=None,
+     dduf_entries=None,
+     loaded_keys=None,
+     unexpected_keys=None,
+     offload_index=None,
+     offload_folder=None,
+     state_dict_index=None,
+     state_dict_folder=None,
+     ignore_mismatched_sizes=False,
+     low_cpu_mem_usage=False,
+ ):
+     state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+     mismatched_keys = _find_mismatched_keys(
+         state_dict,
+         model_state_dict,
+         loaded_keys,
+         ignore_mismatched_sizes,
+     )
+     error_msgs = []
+     if low_cpu_mem_usage:
+         offload_index, state_dict_index = load_model_dict_into_meta(
+             model,
+             state_dict,
+             device_map=device_map,
+             dtype=dtype,
+             hf_quantizer=hf_quantizer,
+             keep_in_fp32_modules=keep_in_fp32_modules,
+             unexpected_keys=unexpected_keys,
+             offload_folder=offload_folder,
+             offload_index=offload_index,
+             state_dict_index=state_dict_index,
+             state_dict_folder=state_dict_folder,
+         )
+     else:
+         assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+         error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+     return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+ def _load_shard_files_with_threadpool(
+     shard_files,
+     model,
+     model_state_dict,
+     device_map=None,
+     dtype=None,
+     hf_quantizer=None,
+     keep_in_fp32_modules=None,
+     dduf_entries=None,
+     loaded_keys=None,
+     unexpected_keys=None,
+     offload_index=None,
+     offload_folder=None,
+     state_dict_index=None,
+     state_dict_folder=None,
+     ignore_mismatched_sizes=False,
+     low_cpu_mem_usage=False,
+ ):
+     # Do not spawn anymore workers than you need
+     num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
+
+     logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+     error_msgs = []
+     mismatched_keys = []
+
+     load_one = functools.partial(
+         _load_shard_file,
+         model=model,
+         model_state_dict=model_state_dict,
+         device_map=device_map,
+         dtype=dtype,
+         hf_quantizer=hf_quantizer,
+         keep_in_fp32_modules=keep_in_fp32_modules,
+         dduf_entries=dduf_entries,
+         loaded_keys=loaded_keys,
+         unexpected_keys=unexpected_keys,
+         offload_index=offload_index,
+         offload_folder=offload_folder,
+         state_dict_index=state_dict_index,
+         state_dict_folder=state_dict_folder,
+         ignore_mismatched_sizes=ignore_mismatched_sizes,
+         low_cpu_mem_usage=low_cpu_mem_usage,
+     )
+
+     with ThreadPoolExecutor(max_workers=num_workers) as executor:
+         with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+             futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
+             for future in as_completed(futures):
+                 result = future.result()
+                 offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
+                 error_msgs += _error_msgs
+                 mismatched_keys += _mismatched_keys
+                 pbar.update(1)
+
+     return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+ def _find_mismatched_keys(
+     state_dict,
+     model_state_dict,
+     loaded_keys,
+     ignore_mismatched_sizes,
+ ):
+     mismatched_keys = []
+     if ignore_mismatched_sizes:
+         for checkpoint_key in loaded_keys:
+             model_key = checkpoint_key
+             # If the checkpoint is sharded, we may not have the key here.
+             if checkpoint_key not in state_dict:
+                 continue
+
+             if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+                 mismatched_keys.append(
+                     (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                 )
+                 del state_dict[checkpoint_key]
+     return mismatched_keys
+
+
  def _load_state_dict_into_model(
  model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
  ) -> List[str]:
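`_load_shard_files_with_threadpool` fans the per-shard work out over a thread pool and merges the results as futures complete. The same pattern, reduced to a self-contained sketch that merely loads safetensors shards into one state dict (file names and the worker count are illustrative; assumes `safetensors` is installed):

    import functools
    from concurrent.futures import ThreadPoolExecutor, as_completed

    from safetensors.torch import load_file


    def load_shards_parallel(shard_files, max_workers=8):
        # Never spawn more workers than there are shards to load.
        num_workers = min(len(shard_files), max_workers)
        load_one = functools.partial(load_file, device="cpu")
        merged_state_dict = {}
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(load_one, shard) for shard in shard_files]
            for future in as_completed(futures):
                # Each shard holds a disjoint set of parameter names, so a
                # plain dict update is enough to merge them.
                merged_state_dict.update(future.result())
        return merged_state_dict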
@@ -520,3 +683,72 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
  parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
 
  return parsed_parameters
+
+
+ def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
+     mismatched_keys = []
+     if not ignore_mismatched_sizes:
+         return mismatched_keys
+     for checkpoint_key in loaded_keys:
+         model_key = checkpoint_key
+         # If the checkpoint is sharded, we may not have the key here.
+         if checkpoint_key not in state_dict:
+             continue
+
+         if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+             mismatched_keys.append(
+                 (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+             )
+             del state_dict[checkpoint_key]
+     return mismatched_keys
+
+
+ def _expand_device_map(device_map, param_names):
+     """
+     Expand a device map to return the correspondence parameter name to device.
+     """
+     new_device_map = {}
+     for module, device in device_map.items():
+         new_device_map.update(
+             {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+         )
+     return new_device_map
+
+
+ # Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
+ def _caching_allocator_warmup(
+     model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+ ) -> None:
+     """
+     This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
+     device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
+     which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
+     very large margin.
+     """
+     factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
+
+     # Keep only accelerator devices
+     accelerator_device_map = {
+         param: torch.device(device)
+         for param, device in expanded_device_map.items()
+         if str(device) not in ["cpu", "disk"]
+     }
+     if not accelerator_device_map:
+         return
+
+     elements_per_device = defaultdict(int)
+     for param_name, device in accelerator_device_map.items():
+         try:
+             p = model.get_parameter(param_name)
+         except AttributeError:
+             try:
+                 p = model.get_buffer(param_name)
+             except AttributeError:
+                 raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
+         # TODO: account for TP when needed.
+         elements_per_device[device] += p.numel()
+
+     # This will kick off the caching allocator to avoid having to Malloc afterwards
+     for device, elem_count in elements_per_device.items():
+         warmup_elems = max(1, elem_count // factor)
+         _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
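`_caching_allocator_warmup` pre-reserves roughly the memory the incoming weights will need, so the subsequent per-tensor copies are served by PyTorch's caching allocator instead of fresh cudaMalloc calls. A simplified single-device sketch of the same idea (the halving factor mirrors the default above; the in-tree helper additionally groups sizes per device_map entry and consults the quantizer):

    import torch


    def warmup_caching_allocator(model, device, dtype=torch.float16, factor=2):
        # Total number of elements the loaded weights will occupy on `device`.
        total_elems = sum(p.numel() for p in model.parameters())
        warmup_elems = max(1, total_elems // factor)
        # One large allocation primes the caching allocator; the tensor is
        # discarded immediately, but the reserved block stays in the cache.
        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)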
diffusers/models/modeling_flax_utils.py CHANGED
@@ -369,8 +369,7 @@ class FlaxModelMixin(PushToHubMixin):
  raise EnvironmentError(
  f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
  "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
- "token having permission to this repo with `token` or log in with `huggingface-cli "
- "login`."
+ "token having permission to this repo with `token` or log in with `hf auth login`."
  )
  except RevisionNotFoundError:
  raise EnvironmentError(
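The error message now points at the `hf auth login` CLI command instead of the older `huggingface-cli login` spelling. The same authentication can also be done programmatically with `huggingface_hub` before loading a gated or private checkpoint (the repo id and token below are placeholders):

    from huggingface_hub import login

    login()  # prompts for a token; roughly equivalent to running `hf auth login`
    # Alternatively, pass the token directly to the loader; FlaxModelMixin
    # subclasses accept `token=...` in `from_pretrained`, e.g.
    # FlaxUNet2DConditionModel.from_pretrained("user/private-repo", token="hf_xxx")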