diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +595 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
  384. diffusers-0.33.1.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
@@ -139,7 +139,9 @@ def get_3d_sincos_pos_embed(
 
     # 3. Concat
     pos_embed_spatial = pos_embed_spatial[None, :, :]
-    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(
+        temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
+    )  # [T, H*W, D // 4 * 3]
 
     pos_embed_temporal = pos_embed_temporal[:, None, :]
     pos_embed_temporal = pos_embed_temporal.repeat_interleave(
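The hunks in this region appear to come from diffusers/models/embeddings.py (listed above with +18 -7). A minimal sketch of what the `output_size` argument changes, assuming plain PyTorch; the tensor sizes below are placeholders rather than values taken from the library:

    import torch

    pos_embed_spatial = torch.randn(1, 16, 96)  # stand-in for [1, H*W, D // 4 * 3]
    temporal_size = 4

    # Passing output_size tells repeat_interleave the result length up front, so PyTorch
    # does not have to infer it from `repeats`; the values produced are identical.
    out = pos_embed_spatial.repeat_interleave(
        temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
    )
    assert out.shape == (temporal_size, 16, 96)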
@@ -334,7 +336,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
             " `from_numpy` is no longer required."
             " Pass `output_type='pt' to use the new version now."
         )
-        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
         return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
@@ -1152,10 +1154,13 @@ def get_1d_rotary_pos_embed(
         / linear_factor
     )  # [D/2]
     freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
+    is_npu = freqs.device.type == "npu"
+    if is_npu:
+        freqs = freqs.float()
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
         # stable audio, allegro
@@ -1199,7 +1204,7 @@ def apply_rotary_emb(
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
         elif use_real_unbind_dim == -2:
-            # Used for Stable Audio
+            # Used for Stable Audio, OmniGen and CogView4
             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
         else:
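An illustrative sketch (not library code) of the two `use_real_unbind_dim` layouts referenced in this hunk: -1 treats the last dimension as interleaved real/imaginary pairs, while -2 (used by Stable Audio, OmniGen and CogView4 per the updated comment) splits it into a real half followed by an imaginary half. Shapes follow the comments in the diff:

    import torch

    x = torch.randn(2, 8, 4, 6)  # [B, S, H, D]

    # use_real_unbind_dim == -1: last dim is (d0_re, d0_im, d1_re, d1_im, ...)
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)    # each [B, S, H, D//2]

    # use_real_unbind_dim == -2: first half real, second half imaginary
    x_real2, x_imag2 = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # each [B, S, H, D//2]

    assert x_real.shape == x_real2.shape == (2, 8, 4, 3)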
@@ -1248,7 +1253,8 @@ class FluxPosEmbed(nn.Module):
         sin_out = []
         pos = ids.float()
         is_mps = ids.device.type == "mps"
-        freqs_dtype = torch.float32 if is_mps else torch.float64
+        is_npu = ids.device.type == "npu"
+        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
@@ -1786,7 +1792,7 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
     def forward(self, timestep, caption_feat, caption_mask):
         # timestep embedding:
         time_freq = self.time_proj(timestep)
-        time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
+        time_embed = self.timestep_embedder(time_freq.to(dtype=caption_feat.dtype))
 
         # caption condition embedding:
         caption_mask_float = caption_mask.float().unsqueeze(-1)
@@ -2582,6 +2588,11 @@ class MultiIPAdapterImageProjection(nn.Module):
         super().__init__()
         self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
 
+    @property
+    def num_ip_adapters(self) -> int:
+        """Number of IP-Adapters loaded."""
+        return len(self.image_projection_layers)
+
     def forward(self, image_embeds: List[torch.Tensor]):
         projected_image_embeds = []
 
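A hedged usage sketch of the new `num_ip_adapters` property, assuming the class still lives in diffusers.models.embeddings (as the embeddings.py entry in the file list suggests). The projection layers here are plain Identity stand-ins, used only to show the counting behaviour:

    import torch.nn as nn
    from diffusers.models.embeddings import MultiIPAdapterImageProjection

    proj = MultiIPAdapterImageProjection([nn.Identity(), nn.Identity()])
    print(proj.num_ip_adapters)  # 2, one entry per loaded IP-Adapter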
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,12 +20,15 @@ import os
 from array import array
 from collections import OrderedDict
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
+from zipfile import is_zipfile
 
 import safetensors
 import torch
+from huggingface_hub import DDUFEntry
 from huggingface_hub.utils import EntryNotFoundError
 
+from ..quantizers import DiffusersQuantizer
 from ..utils import (
     GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
@@ -54,7 +57,7 @@ _CLASS_REMAPPING_DICT = {
 
 if is_accelerate_available():
     from accelerate import infer_auto_device_map
-    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
+    from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
 
 
 # Adapted from `transformers` (see modeling_utils.py)
@@ -131,27 +134,61 @@ def _fetch_remapped_cls_from_config(config, old_class):
     return old_class
 
 
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
+    """
+    Find the device of param_name from the device_map.
+    """
+    if device_map is None:
+        return "cpu"
+    else:
+        module_name = param_name
+        # find next higher level module that is defined in device_map:
+        # bert.lm_head.weight -> bert.lm_head -> bert -> ''
+        while len(module_name) > 0 and module_name not in device_map:
+            module_name = ".".join(module_name.split(".")[:-1])
+        if module_name == "" and "" not in device_map:
+            raise ValueError(f"{param_name} doesn't have any device set.")
+        return device_map[module_name]
+
+
+def load_state_dict(
+    checkpoint_file: Union[str, os.PathLike],
+    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
+    disable_mmap: bool = False,
+    map_location: Union[str, torch.device] = "cpu",
+):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """
-    # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
-    # when refactoring the _merge_sharded_checkpoints() method later.
+    # TODO: maybe refactor a bit this part where we pass a dict here
     if isinstance(checkpoint_file, dict):
         return checkpoint_file
     try:
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
         if file_extension == SAFETENSORS_FILE_EXTENSION:
-            return safetensors.torch.load_file(checkpoint_file, device="cpu")
+            if dduf_entries:
+                # tensors are loaded on cpu
+                with dduf_entries[checkpoint_file].as_mmap() as mm:
+                    return safetensors.torch.load(mm)
+            if disable_mmap:
+                return safetensors.torch.load(open(checkpoint_file, "rb").read())
+            else:
+                return safetensors.torch.load_file(checkpoint_file, device=map_location)
         elif file_extension == GGUF_FILE_EXTENSION:
             return load_gguf_checkpoint(checkpoint_file)
         else:
+            extra_args = {}
             weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
-            return torch.load(
-                checkpoint_file,
-                map_location="cpu",
-                **weights_only_kwarg,
-            )
+            # mmap can only be used with files serialized with zipfile-based format.
+            if (
+                isinstance(checkpoint_file, str)
+                and map_location != "meta"
+                and is_torch_version(">=", "2.1.0")
+                and is_zipfile(checkpoint_file)
+                and not disable_mmap
+            ):
+                extra_args = {"mmap": True}
+            return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
     except Exception as e:
         try:
             with open(checkpoint_file) as f:
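A hedged usage sketch of the extended `load_state_dict` signature above, assuming it is imported from diffusers.models.model_loading_utils (the functions in these hunks match that module in the file list). The checkpoint path is a placeholder:

    from diffusers.models.model_loading_utils import load_state_dict

    state_dict = load_state_dict(
        "unet/diffusion_pytorch_model.safetensors",  # placeholder path
        disable_mmap=True,   # read the whole file into memory instead of memory-mapping it
        map_location="cpu",  # device the tensors are materialized on
    )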
@@ -168,29 +205,31 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
                     ) from e
         except (UnicodeDecodeError, ValueError):
             raise OSError(
-                f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
+                f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
             )
 
 
 def load_model_dict_into_meta(
     model,
     state_dict: OrderedDict,
-    device: Optional[Union[str, torch.device]] = None,
     dtype: Optional[Union[str, torch.dtype]] = None,
     model_name_or_path: Optional[str] = None,
-    hf_quantizer=None,
-    keep_in_fp32_modules=None,
+    hf_quantizer: Optional[DiffusersQuantizer] = None,
+    keep_in_fp32_modules: Optional[List] = None,
+    device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
+    unexpected_keys: Optional[List[str]] = None,
+    offload_folder: Optional[Union[str, os.PathLike]] = None,
+    offload_index: Optional[Dict] = None,
+    state_dict_index: Optional[Dict] = None,
+    state_dict_folder: Optional[Union[str, os.PathLike]] = None,
 ) -> List[str]:
-    if device is not None and not isinstance(device, (str, torch.device)):
-        raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
-    if hf_quantizer is None:
-        device = device or torch.device("cpu")
-        dtype = dtype or torch.float32
-    is_quantized = hf_quantizer is not None
+    """
+    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
+    params on a `meta` device. It replaces the model params with the data from the `state_dict`
+    """
 
-    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+    is_quantized = hf_quantizer is not None
     empty_state_dict = model.state_dict()
-    unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
 
     for param_name, param in state_dict.items():
         if param_name not in empty_state_dict:
@@ -200,21 +239,38 @@ def load_model_dict_into_meta(
         # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
         # in int/uint/bool and not cast them.
         # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
-        if torch.is_floating_point(param):
-            if (
-                keep_in_fp32_modules is not None
-                and any(
-                    module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
-                )
-                and dtype == torch.float16
+        if dtype is not None and torch.is_floating_point(param):
+            if keep_in_fp32_modules is not None and any(
+                module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
             ):
                 param = param.to(torch.float32)
-                if accepts_dtype:
-                    set_module_kwargs["dtype"] = torch.float32
+                set_module_kwargs["dtype"] = torch.float32
+            # For quantizers have save weights using torch.float8_e4m3fn
+            elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
+                pass
             else:
                 param = param.to(dtype)
-                if accepts_dtype:
-                    set_module_kwargs["dtype"] = dtype
+                set_module_kwargs["dtype"] = dtype
+
+        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
+        # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
+        # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
+        old_param = model
+        splits = param_name.split(".")
+        for split in splits:
+            old_param = getattr(old_param, split)
+
+        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+            old_param = None
+
+        if old_param is not None:
+            if dtype is None:
+                param = param.to(old_param.dtype)
+
+            if old_param.is_contiguous():
+                param = param.contiguous()
+
+        param_device = _determine_param_device(param_name, device_map)
 
         # bnb params are flattened.
         # gguf quants have a different shape based on the type of quantization applied
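An illustration (not library code) of how the `_determine_param_device` helper added above resolves a parameter's device: the dotted parameter name is walked upward until it hits an entry in the device_map. The map and parameter names below are made up:

    device_map = {"": "cpu", "transformer_blocks.0": 0, "proj_out": "disk"}

    def resolve(param_name, device_map):
        name = param_name
        while len(name) > 0 and name not in device_map:
            name = ".".join(name.split(".")[:-1])
        return device_map[name]

    print(resolve("transformer_blocks.0.attn.to_q.weight", device_map))  # 0 (GPU index)
    print(resolve("proj_out.weight", device_map))             # "disk" -> handled via offload_weight
    print(resolve("time_embed.linear_1.weight", device_map))  # "cpu" (falls back to the "" entry)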
@@ -222,7 +278,9 @@ def load_model_dict_into_meta(
         if (
             is_quantized
             and hf_quantizer.pre_quantized
-            and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+            and hf_quantizer.check_if_quantized_param(
+                model, param, param_name, state_dict, param_device=param_device
+            )
         ):
             hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
         else:
@@ -230,21 +288,25 @@ def load_model_dict_into_meta(
                 raise ValueError(
                     f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
                 )
-
-            if is_quantized and (
-                hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+        if param_device == "disk":
+            offload_index = offload_weight(param, param_name, offload_folder, offload_index)
+        elif param_device == "cpu" and state_dict_index is not None:
+            state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
+        elif is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
         ):
-            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+            hf_quantizer.create_quantized_param(
+                model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
+            )
         else:
-            if accepts_dtype:
-                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
-            else:
-                set_module_tensor_to_device(model, param_name, device, value=param)
+            set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
 
-    return unexpected_keys
+    return offload_index, state_dict_index
 
 
-def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
+def _load_state_dict_into_model(
+    model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
+) -> List[str]:
     # Convert old format to new format if needed from a PyTorch state_dict
     # copy state_dict so _load_from_state_dict can modify it
     state_dict = state_dict.copy()
@@ -252,15 +314,19 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
 
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
-    def load(module: torch.nn.Module, prefix: str = ""):
-        args = (state_dict, prefix, {}, True, [], [], error_msgs)
+    def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False):
+        local_metadata = {}
+        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
+        if assign_to_params_buffers and not is_torch_version(">=", "2.1"):
+            logger.info("You need to have torch>=2.1 in order to load the model with assign_to_params_buffers=True")
+        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
         module._load_from_state_dict(*args)
 
         for name, child in module._modules.items():
             if child is not None:
-                load(child, prefix + name + ".")
+                load(child, prefix + name + ".", assign_to_params_buffers)
 
-    load(model_to_load)
+    load(model_to_load, assign_to_params_buffers=assign_to_params_buffers)
 
     return error_msgs
 
@@ -279,6 +345,7 @@ def _fetch_index_file(
     revision,
     user_agent,
     commit_hash,
+    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
 ):
     if is_local:
         index_file = Path(
@@ -304,43 +371,16 @@ def _fetch_index_file(
                 subfolder=None,
                 user_agent=user_agent,
                 commit_hash=commit_hash,
+                dduf_entries=dduf_entries,
             )
-            index_file = Path(index_file)
+            if not dduf_entries:
+                index_file = Path(index_file)
         except (EntryNotFoundError, EnvironmentError):
             index_file = None
 
     return index_file
 
 
-# Adapted from
-# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
-def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
-    weight_map = sharded_metadata.get("weight_map", None)
-    if weight_map is None:
-        raise KeyError("'weight_map' key not found in the shard index file.")
-
-    # Collect all unique safetensors files from weight_map
-    files_to_load = set(weight_map.values())
-    is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
-    merged_state_dict = {}
-
-    # Load tensors from each unique file
-    for file_name in files_to_load:
-        part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
-        if not os.path.exists(part_file_path):
-            raise FileNotFoundError(f"Part file {file_name} not found.")
-
-        if is_safetensors:
-            with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
-                for tensor_key in f.keys():
-                    if tensor_key in weight_map:
-                        merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
-        else:
-            merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
-
-    return merged_state_dict
-
-
 def _fetch_index_file_legacy(
     is_local,
     pretrained_model_name_or_path,
@@ -355,6 +395,7 @@ def _fetch_index_file_legacy(
     revision,
     user_agent,
     commit_hash,
+    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
 ):
     if is_local:
         index_file = Path(
@@ -395,6 +436,7 @@ def _fetch_index_file_legacy(
                 subfolder=None,
                 user_agent=user_agent,
                 commit_hash=commit_hash,
+                dduf_entries=dduf_entries,
             )
             index_file = Path(index_file)
             deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.