diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +595 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
  384. diffusers-0.33.1.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
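
The diff body below is from diffusers/models/modeling_utils.py (entry 59 in the list above). Among the changes are two new ModelMixin memory-saving entry points, `enable_layerwise_casting` and `enable_group_offload`. A minimal usage sketch, adapted from the docstring examples added in this diff (the CogVideoX checkpoint is the one used in those examples):

```python
import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Option 1: store weights in float8 and upcast them on the fly during the forward pass
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

# Option 2: keep groups of layers on CPU and stream them onto the GPU only when needed
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)
```
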
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,16 +20,21 @@ import itertools
 import json
 import os
 import re
+import shutil
+import tempfile
 from collections import OrderedDict
-from functools import partial, wraps
+from contextlib import ExitStack, contextmanager
+from functools import wraps
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Type, Union

 import safetensors
 import torch
-from huggingface_hub import create_repo, split_torch_state_dict_into_shards
+import torch.utils.checkpoint
+from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
+from typing_extensions import Self

 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
@@ -48,6 +53,7 @@ from ..utils import (
     is_accelerate_available,
     is_bitsandbytes_available,
     is_bitsandbytes_version,
+    is_peft_available,
     is_torch_version,
     logging,
 )
@@ -61,16 +67,49 @@ from .model_loading_utils import (
     _fetch_index_file,
     _fetch_index_file_legacy,
     _load_state_dict_into_model,
-    _merge_sharded_checkpoints,
     load_model_dict_into_meta,
     load_state_dict,
 )


+class ContextManagers:
+    """
+    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
+    in the `fastcore` library.
+    """
+
+    def __init__(self, context_managers: List[ContextManager]):
+        self.context_managers = context_managers
+        self.stack = ExitStack()
+
+    def __enter__(self):
+        for context_manager in self.context_managers:
+            self.stack.enter_context(context_manager)
+
+    def __exit__(self, *args, **kwargs):
+        self.stack.__exit__(*args, **kwargs)
+
+
 logger = logging.get_logger(__name__)

 _REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")

+TORCH_INIT_FUNCTIONS = {
+    "uniform_": nn.init.uniform_,
+    "normal_": nn.init.normal_,
+    "trunc_normal_": nn.init.trunc_normal_,
+    "constant_": nn.init.constant_,
+    "xavier_uniform_": nn.init.xavier_uniform_,
+    "xavier_normal_": nn.init.xavier_normal_,
+    "kaiming_uniform_": nn.init.kaiming_uniform_,
+    "kaiming_normal_": nn.init.kaiming_normal_,
+    "uniform": nn.init.uniform,
+    "normal": nn.init.normal,
+    "xavier_uniform": nn.init.xavier_uniform,
+    "xavier_normal": nn.init.xavier_normal,
+    "kaiming_uniform": nn.init.kaiming_uniform,
+    "kaiming_normal": nn.init.kaiming_normal,
+}

 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
@@ -80,10 +119,22 @@ else:

 if is_accelerate_available():
     import accelerate
+    from accelerate import dispatch_model
+    from accelerate.utils import load_offloaded_weights, save_offload_index


 def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
+    from ..hooks.group_offloading import _get_group_onload_device
+
+    try:
+        # Try to get the onload device from the group offloading hook
+        return _get_group_onload_device(parameter)
+    except ValueError:
+        pass
+
     try:
+        # If the onload device is not available due to no group offloading hooks, try to get the device
+        # from the first parameter or buffer
         parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
         return next(parameters_and_buffers).device
     except StopIteration:
@@ -102,9 +153,24 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     """
     Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
     """
+    # 1. Check if we have attached any dtype modifying hooks (eg. layerwise casting)
+    if isinstance(parameter, nn.Module):
+        for name, submodule in parameter.named_modules():
+            if not hasattr(submodule, "_diffusers_hook"):
+                continue
+            registry = submodule._diffusers_hook
+            hook = registry.get_hook("layerwise_casting")
+            if hook is not None:
+                return hook.compute_dtype
+
+    # 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer
     last_dtype = None
-    for param in parameter.parameters():
+
+    for name, param in parameter.named_parameters():
         last_dtype = param.dtype
+        if parameter._keep_in_fp32_modules and any(m in name for m in parameter._keep_in_fp32_modules):
+            continue
+
         if param.is_floating_point():
             return param.dtype

@@ -134,6 +200,54 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
         return last_tuple[1].dtype


+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
+    parameters.
+
+    """
+    if model_to_load.device.type == "meta":
+        return False
+
+    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+        return False
+
+    # Some models explicitly do not support param buffer assignment
+    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+        logger.debug(
+            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+        )
+        return False
+
+    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+    first_key = next(iter(model_to_load.state_dict().keys()))
+    if start_prefix + first_key in state_dict:
+        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+    return False
+
+
+@contextmanager
+def no_init_weights():
+    """
+    Context manager to globally disable weight initialization to speed up loading large models. To do that, all the
+    torch.nn.init function are all replaced with skip.
+    """
+
+    def _skip_init(*args, **kwargs):
+        pass
+
+    for name, init_func in TORCH_INIT_FUNCTIONS.items():
+        setattr(torch.nn.init, name, _skip_init)
+    try:
+        yield
+    finally:
+        # Restore the original initialization functions
+        for name, init_func in TORCH_INIT_FUNCTIONS.items():
+            setattr(torch.nn.init, name, init_func)
+
+
 class ModelMixin(torch.nn.Module, PushToHubMixin):
     r"""
     Base class for all models.
@@ -150,10 +264,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _keys_to_ignore_on_load_unexpected = None
     _no_split_modules = None
     _keep_in_fp32_modules = None
+    _skip_layerwise_casting_patterns = None
+    _supports_group_offloading = True

     def __init__(self):
         super().__init__()

+        self._gradient_checkpointing_func = None
+
     def __getattr__(self, name: str) -> Any:
         """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
         config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
@@ -179,14 +297,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

-    def enable_gradient_checkpointing(self) -> None:
+    def enable_gradient_checkpointing(self, gradient_checkpointing_func: Optional[Callable] = None) -> None:
         """
         Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
         *checkpoint activations* in other frameworks).
+
+        Args:
+            gradient_checkpointing_func (`Callable`, *optional*):
+                The function to use for gradient checkpointing. If `None`, the default PyTorch checkpointing function
+                is used (`torch.utils.checkpoint.checkpoint`).
         """
         if not self._supports_gradient_checkpointing:
-            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
-        self.apply(partial(self._set_gradient_checkpointing, value=True))
+            raise ValueError(
+                f"{self.__class__.__name__} does not support gradient checkpointing. Please make sure to set the boolean attribute "
+                f"`_supports_gradient_checkpointing` to `True` in the class definition."
+            )
+
+        if gradient_checkpointing_func is None:
+
+            def _gradient_checkpointing_func(module, *args):
+                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                return torch.utils.checkpoint.checkpoint(
+                    module.__call__,
+                    *args,
+                    **ckpt_kwargs,
+                )
+
+            gradient_checkpointing_func = _gradient_checkpointing_func
+
+        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)

     def disable_gradient_checkpointing(self) -> None:
         """
@@ -194,7 +333,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         *checkpoint activations* in other frameworks).
         """
         if self._supports_gradient_checkpointing:
-            self.apply(partial(self._set_gradient_checkpointing, value=False))
+            self._set_gradient_checkpointing(enable=False)

     def set_use_npu_flash_attention(self, valid: bool) -> None:
         r"""
@@ -227,14 +366,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         self.set_use_npu_flash_attention(False)

     def set_use_xla_flash_attention(
-        self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None, **kwargs
     ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_xla_flash_attention method
         # gets the message
         def fn_recursive_set_flash_attention(module: torch.nn.Module):
             if hasattr(module, "set_use_xla_flash_attention"):
-                module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec)
+                module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec, **kwargs)

             for child in module.children():
                 fn_recursive_set_flash_attention(child)
@@ -243,11 +382,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         if isinstance(module, torch.nn.Module):
             fn_recursive_set_flash_attention(module)

-    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
+    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None, **kwargs):
         r"""
         Enable the flash attention pallals kernel for torch_xla.
         """
-        self.set_use_xla_flash_attention(True, partition_spec)
+        self.set_use_xla_flash_attention(True, partition_spec, **kwargs)

     def disable_xla_flash_attention(self):
         r"""
@@ -314,6 +453,152 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         self.set_use_memory_efficient_attention_xformers(False)

+    def enable_layerwise_casting(
+        self,
+        storage_dtype: torch.dtype = torch.float8_e4m3fn,
+        compute_dtype: Optional[torch.dtype] = None,
+        skip_modules_pattern: Optional[Tuple[str, ...]] = None,
+        skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        r"""
+        Activates layerwise casting for the current model.
+
+        Layerwise casting is a technique that casts the model weights to a lower precision dtype for storage but
+        upcasts them on-the-fly to a higher precision dtype for computation. This process can significantly reduce the
+        memory footprint from model weights, but may lead to some quality degradation in the outputs. Most degradations
+        are negligible, mostly stemming from weight casting in normalization and modulation layers.
+
+        By default, most models in diffusers set the `_skip_layerwise_casting_patterns` attribute to ignore patch
+        embedding, positional embedding and normalization layers. This is because these layers are most likely
+        precision-critical for quality. If you wish to change this behavior, you can set the
+        `_skip_layerwise_casting_patterns` attribute to `None`, or call
+        [`~hooks.layerwise_casting.apply_layerwise_casting`] with custom arguments.
+
+        Example:
+            Using [`~models.ModelMixin.enable_layerwise_casting`]:
+
+            ```python
+            >>> from diffusers import CogVideoXTransformer3DModel
+
+            >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+            ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+            ... )
+
+            >>> # Enable layerwise casting via the model, which ignores certain modules by default
+            >>> transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+            ```
+
+        Args:
+            storage_dtype (`torch.dtype`):
+                The dtype to which the model should be cast for storage.
+            compute_dtype (`torch.dtype`):
+                The dtype to which the model weights should be cast during the forward pass.
+            skip_modules_pattern (`Tuple[str, ...]`, *optional*):
+                A list of patterns to match the names of the modules to skip during the layerwise casting process. If
+                set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT
+                layers.
+            skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
+                A list of module classes to skip during the layerwise casting process.
+            non_blocking (`bool`, *optional*, defaults to `False`):
+                If `True`, the weight casting operations are non-blocking.
+        """
+        from ..hooks import apply_layerwise_casting
+
+        user_provided_patterns = True
+        if skip_modules_pattern is None:
+            from ..hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
+
+            skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
+            user_provided_patterns = False
+        if self._keep_in_fp32_modules is not None:
+            skip_modules_pattern += tuple(self._keep_in_fp32_modules)
+        if self._skip_layerwise_casting_patterns is not None:
+            skip_modules_pattern += tuple(self._skip_layerwise_casting_patterns)
+        skip_modules_pattern = tuple(set(skip_modules_pattern))
+
+        if is_peft_available() and not user_provided_patterns:
+            # By default, we want to skip all peft layers because they have a very low memory footprint.
+            # If users want to apply layerwise casting on peft layers as well, they can utilize the
+            # `~diffusers.hooks.layerwise_casting.apply_layerwise_casting` function which provides
+            # them with more flexibility and control.
+
+            from peft.tuners.loha.layer import LoHaLayer
+            from peft.tuners.lokr.layer import LoKrLayer
+            from peft.tuners.lora.layer import LoraLayer
+
+            for layer in (LoHaLayer, LoKrLayer, LoraLayer):
+                skip_modules_pattern += tuple(layer.adapter_layer_names)
+
+        if compute_dtype is None:
+            logger.info("`compute_dtype` not provided when enabling layerwise casting. Using dtype of the model.")
+            compute_dtype = self.dtype
+
+        apply_layerwise_casting(
+            self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
+        )
+
+    def enable_group_offload(
+        self,
+        onload_device: torch.device,
+        offload_device: torch.device = torch.device("cpu"),
+        offload_type: str = "block_level",
+        num_blocks_per_group: Optional[int] = None,
+        non_blocking: bool = False,
+        use_stream: bool = False,
+        record_stream: bool = False,
+        low_cpu_mem_usage=False,
+    ) -> None:
+        r"""
+        Activates group offloading for the current model.
+
+        See [`~hooks.group_offloading.apply_group_offloading`] for more information.
+
+        Example:
+
+            ```python
+            >>> from diffusers import CogVideoXTransformer3DModel
+
+            >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+            ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+            ... )
+
+            >>> transformer.enable_group_offload(
+            ...     onload_device=torch.device("cuda"),
+            ...     offload_device=torch.device("cpu"),
+            ...     offload_type="leaf_level",
+            ...     use_stream=True,
+            ... )
+            ```
+        """
+        from ..hooks import apply_group_offloading
+
+        if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
+            msg = (
+                "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
+                "forward pass is executed with tiling enabled. Please make sure to either:\n"
+                "1. Run a forward pass with small input shapes.\n"
+                "2. Or, run a forward pass with tiling disabled (can still use small dummy inputs)."
+            )
+            logger.warning(msg)
+        if not self._supports_group_offloading:
+            raise ValueError(
+                f"{self.__class__.__name__} does not support group offloading. Please make sure to set the boolean attribute "
+                f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
+                f"open an issue at https://github.com/huggingface/diffusers/issues."
+            )
+        apply_group_offloading(
+            self,
+            onload_device,
+            offload_device,
+            offload_type,
+            num_blocks_per_group,
+            non_blocking,
+            use_stream,
+            record_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
@@ -426,7 +711,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     os.remove(full_filename)

         for filename, tensors in state_dict_split.filename_to_tensors.items():
-            shard = {tensor: state_dict[tensor] for tensor in tensors}
+            shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
             filepath = os.path.join(save_directory, filename)
             if safe_serialization:
                 # At some point we will need to deal better with save_function (used for TPU and other distributed
@@ -483,7 +768,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

     @classmethod
     @validate_hf_hub_args
-    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self:
         r"""
         Instantiate a pretrained PyTorch model from a pretrained model configuration.

@@ -559,6 +844,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
+            disable_mmap ('bool', *optional*, defaults to 'False'):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.

         <Tip>

@@ -599,11 +887,19 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
         offload_folder = kwargs.pop("offload_folder", None)
-        offload_state_dict = kwargs.pop("offload_state_dict", False)
+        offload_state_dict = kwargs.pop("offload_state_dict", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
+
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            torch_dtype = torch.float32
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )

         allow_pickle = False
         if use_safetensors is None:
@@ -674,14 +970,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
             raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")

-        # Load config if we don't provide a configuration
-        config_path = pretrained_model_name_or_path
-
         user_agent = {
             "diffusers": __version__,
             "file_type": "model",
             "framework": "pytorch",
         }
+        unused_kwargs = {}
+
+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path

         # load config
         config, unused_kwargs, commit_hash = cls.load_config(
@@ -696,6 +993,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             revision=revision,
             subfolder=subfolder,
             user_agent=user_agent,
+            dduf_entries=dduf_entries,
             **kwargs,
         )
         # no in-place modification of the original config.
@@ -718,13 +1016,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             hf_quantizer = None

         if hf_quantizer is not None:
-            if device_map is not None:
-                raise NotImplementedError(
-                    "Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
-                )
-
             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+            device_map = hf_quantizer.update_device_map(device_map)

             # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
             user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
@@ -737,9 +1031,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")

         # Check if `_keep_in_fp32_modules` is not None
-        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
-            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        use_keep_in_fp32_modules = cls._keep_in_fp32_modules is not None and (
+            hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
         )
+
         if use_keep_in_fp32_modules:
             keep_in_fp32_modules = cls._keep_in_fp32_modules
             if not isinstance(keep_in_fp32_modules, list):
@@ -752,10 +1047,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
         else:
             keep_in_fp32_modules = []
-        #######################################

-        # Determine if we're loading from a directory of sharded checkpoints.
         is_sharded = False
+        resolved_model_file = None
+
+        # Determine if we're loading from a directory of sharded checkpoints.
+        sharded_metadata = None
         index_file = None
         is_local = os.path.isdir(pretrained_model_name_or_path)
         index_file_kwargs = {
@@ -772,22 +1069,22 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            "revision": revision,
            "user_agent": user_agent,
            "commit_hash": commit_hash,
+            "dduf_entries": dduf_entries,
         }
         index_file = _fetch_index_file(**index_file_kwargs)
         # In case the index file was not found we still have to consider the legacy format.
         # this becomes applicable when the variant is not None.
         if variant is not None and (index_file is None or not os.path.exists(index_file)):
             index_file = _fetch_index_file_legacy(**index_file_kwargs)
-        if index_file is not None and index_file.is_file():
+        if index_file is not None and (dduf_entries or index_file.is_file()):
             is_sharded = True

         if is_sharded and from_flax:
             raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")

         # load model
-        model_file = None
         if from_flax:
-            model_file = _get_model_file(
+            resolved_model_file = _get_model_file(
                 pretrained_model_name_or_path,
                 weights_name=FLAX_WEIGHTS_NAME,
                 cache_dir=cache_dir,
@@ -805,10 +1102,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             # Convert the weights
             from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

-            model = load_flax_checkpoint_in_pytorch_model(model, model_file)
+            model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
         else:
+            # in the case it is sharded, we have already the index
             if is_sharded:
-                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
+                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
                     pretrained_model_name_or_path,
                     index_file,
                     cache_dir=cache_dir,
@@ -818,16 +1116,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     user_agent=user_agent,
                     revision=revision,
                     subfolder=subfolder or "",
+                    dduf_entries=dduf_entries,
                 )
-                # TODO: https://github.com/huggingface/diffusers/issues/10013
-                if hf_quantizer is not None:
-                    model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
-                    logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
-                    is_sharded = False
-
-            elif use_safetensors and not is_sharded:
+            elif use_safetensors:
                 try:
-                    model_file = _get_model_file(
+                    resolved_model_file = _get_model_file(
                         pretrained_model_name_or_path,
                         weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                         cache_dir=cache_dir,
@@ -839,6 +1132,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         subfolder=subfolder,
                         user_agent=user_agent,
                         commit_hash=commit_hash,
+                        dduf_entries=dduf_entries,
                     )

                 except IOError as e:
@@ -849,8 +1143,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
                     )

-            if model_file is None and not is_sharded:
-                model_file = _get_model_file(
+            if resolved_model_file is None and not is_sharded:
+                resolved_model_file = _get_model_file(
                     pretrained_model_name_or_path,
                     weights_name=_add_variant(WEIGHTS_NAME, variant),
                     cache_dir=cache_dir,
@@ -862,156 +1156,107 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  subfolder=subfolder,
  user_agent=user_agent,
  commit_hash=commit_hash,
+ dduf_entries=dduf_entries,
  )

- if low_cpu_mem_usage:
- # Instantiate model with empty weights
- with accelerate.init_empty_weights():
- model = cls.from_config(config, **unused_kwargs)
+ if not isinstance(resolved_model_file, list):
+ resolved_model_file = [resolved_model_file]

- if hf_quantizer is not None:
- hf_quantizer.preprocess_model(
- model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
- )
+ # set dtype to instantiate the model under:
+ # 1. If torch_dtype is not None, we use that dtype
+ # 2. If torch_dtype is float8, we don't use _set_default_torch_dtype and we downcast after loading the model
+ dtype_orig = None
+ if torch_dtype is not None and not torch_dtype == getattr(torch, "float8_e4m3fn", None):
+ if not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ dtype_orig = cls._set_default_torch_dtype(torch_dtype)

- # if device_map is None, load the state dict and move the params from meta device to the cpu
- if device_map is None and not is_sharded:
- # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
- # It would error out during the `validate_environment()` call above in the absence of cuda.
- if hf_quantizer is None:
- param_device = "cpu"
- # TODO (sayakpaul, SunMarc): remove this after model loading refactor
- else:
- param_device = torch.device(torch.cuda.current_device())
- state_dict = load_state_dict(model_file, variant=variant)
- model._convert_deprecated_attention_blocks(state_dict)
-
- # move the params from meta device to cpu
- missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
- if hf_quantizer is not None:
- missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
- if len(missing_keys) > 0:
- raise ValueError(
- f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
- f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
- " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
- " those weights or else make sure your checkpoint file is correct."
- )
+ init_contexts = [no_init_weights()]

- unexpected_keys = load_model_dict_into_meta(
- model,
- state_dict,
- device=param_device,
- dtype=torch_dtype,
- model_name_or_path=pretrained_model_name_or_path,
- hf_quantizer=hf_quantizer,
- keep_in_fp32_modules=keep_in_fp32_modules,
- )
+ if low_cpu_mem_usage:
+ init_contexts.append(accelerate.init_empty_weights())

- if cls._keys_to_ignore_on_load_unexpected is not None:
- for pat in cls._keys_to_ignore_on_load_unexpected:
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+ with ContextManagers(init_contexts):
+ model = cls.from_config(config, **unused_kwargs)

- if len(unexpected_keys) > 0:
- logger.warning(
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
- )
+ if dtype_orig is not None:
+ torch.set_default_dtype(dtype_orig)

- else: # else let accelerate handle loading and dispatching.
- # Load weights and dispatch according to the device_map
- # by default the device_map is None and the weights are loaded on the CPU
- force_hook = True
- device_map = _determine_device_map(
- model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
- )
- if device_map is None and is_sharded:
- # we load the parameters on the cpu
- device_map = {"": "cpu"}
- force_hook = False
- try:
- accelerate.load_checkpoint_and_dispatch(
- model,
- model_file if not is_sharded else index_file,
- device_map,
- max_memory=max_memory,
- offload_folder=offload_folder,
- offload_state_dict=offload_state_dict,
- dtype=torch_dtype,
- force_hooks=force_hook,
- strict=True,
- )
- except AttributeError as e:
- # When using accelerate loading, we do not have the ability to load the state
- # dict and rename the weight names manually. Additionally, accelerate skips
- # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
- # (which look like they should be private variables?), so we can't use the standard hooks
- # to rename parameters on load. We need to mimic the original weight names so the correct
- # attributes are available. After we have loaded the weights, we convert the deprecated
- # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
- # the weights so we don't have to do this again.
-
- if "'Attention' object has no attribute" in str(e):
- logger.warning(
- f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
- " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
- " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
- " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
- " please also re-upload it or open a PR on the original repository."
- )
- model._temp_convert_self_to_deprecated_attention_blocks()
- accelerate.load_checkpoint_and_dispatch(
- model,
- model_file if not is_sharded else index_file,
- device_map,
- max_memory=max_memory,
- offload_folder=offload_folder,
- offload_state_dict=offload_state_dict,
- dtype=torch_dtype,
- force_hooks=force_hook,
- strict=True,
- )
- model._undo_temp_convert_self_to_deprecated_attention_blocks()
- else:
- raise e
-
- loading_info = {
- "missing_keys": [],
- "unexpected_keys": [],
- "mismatched_keys": [],
- "error_msgs": [],
- }
- else:
- model = cls.from_config(config, **unused_kwargs)
+ state_dict = None
+ if not is_sharded:
+ # Time to load the checkpoint
+ state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries)
+ # We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
+ model._fix_state_dict_keys_on_load(state_dict)

- state_dict = load_state_dict(model_file, variant=variant)
- model._convert_deprecated_attention_blocks(state_dict)
+ if is_sharded:
+ loaded_keys = sharded_metadata["all_checkpoint_keys"]
+ else:
+ loaded_keys = list(state_dict.keys())

- model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
- model,
- state_dict,
- model_file,
- pretrained_model_name_or_path,
- ignore_mismatched_sizes=ignore_mismatched_sizes,
- )
+ if hf_quantizer is not None:
+ hf_quantizer.preprocess_model(
+ model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+ )
+
+ # Now that the model is loaded, we can determine the device_map
+ device_map = _determine_device_map(
+ model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+ )
+ if hf_quantizer is not None:
+ hf_quantizer.validate_environment(device_map=device_map)
+
+ (
+ model,
+ missing_keys,
+ unexpected_keys,
+ mismatched_keys,
+ offload_index,
+ error_msgs,
+ ) = cls._load_pretrained_model(
+ model,
+ state_dict,
+ resolved_model_file,
+ pretrained_model_name_or_path,
+ loaded_keys,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ device_map=device_map,
+ offload_folder=offload_folder,
+ offload_state_dict=offload_state_dict,
+ dtype=torch_dtype,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ dduf_entries=dduf_entries,
+ )
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }

- loading_info = {
- "missing_keys": missing_keys,
- "unexpected_keys": unexpected_keys,
- "mismatched_keys": mismatched_keys,
- "error_msgs": error_msgs,
- }
+ # Dispatch model with hooks on all devices if necessary
+ if device_map is not None:
+ device_map_kwargs = {
+ "device_map": device_map,
+ "offload_dir": offload_folder,
+ "offload_index": offload_index,
+ }
+ dispatch_model(model, **device_map_kwargs)

  if hf_quantizer is not None:
  hf_quantizer.postprocess_model(model)
  model.hf_quantizer = hf_quantizer

- if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
- raise ValueError(
- f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
- )
- # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
- # completely lose the effectivity of `use_keep_in_fp32_modules`.
- elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
+ if (
+ torch_dtype is not None
+ and torch_dtype == getattr(torch, "float8_e4m3fn", None)
+ and hf_quantizer is None
+ and not use_keep_in_fp32_modules
+ ):
  model = model.to(torch_dtype)

  if hf_quantizer is not None:
@@ -1023,6 +1268,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

  # Set model in evaluation mode to deactivate DropOut modules by default
  model.eval()
+
  if output_loading_info:
  return model, loading_info
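From the caller's side the refactor above is transparent: the model is created under `no_init_weights()` (optionally on the meta device), non-float8 `torch_dtype` values are applied while parameters are instantiated, float8 storage dtypes are downcast to after loading, and the missing/unexpected/mismatched key lists now come back from `_load_pretrained_model`. A small sketch of that unchanged public surface; the checkpoint id is only an example:

```python
# Sketch of the caller-facing behaviour of the loading path above; any ModelMixin
# checkpoint works the same way, the repo id below is just an example.
import torch
from diffusers import UNet2DConditionModel

unet, loading_info = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=torch.float16,   # non-float8 dtypes are set as the default dtype during init
    low_cpu_mem_usage=True,      # parameters are materialized from the checkpoint, not randomly initialized
    output_loading_info=True,    # surfaces the missing/unexpected/mismatched key lists built above
)
print(loading_info["missing_keys"], loading_info["unexpected_keys"])
```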
@@ -1031,6 +1277,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  # Adapted from `transformers`.
  @wraps(torch.nn.Module.cuda)
  def cuda(self, *args, **kwargs):
+ from ..hooks.group_offloading import _is_group_offload_enabled
+
  # Checks if the model has been loaded in 4-bit or 8-bit with BNB
  if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
  if getattr(self, "is_loaded_in_8bit", False):
@@ -1043,13 +1291,34 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
  f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
  )
+
+ # Checks if group offloading is enabled
+ if _is_group_offload_enabled(self):
+ logger.warning(
+ f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.cuda()` is not supported."
+ )
+ return self
+
  return super().cuda(*args, **kwargs)

  # Adapted from `transformers`.
  @wraps(torch.nn.Module.to)
  def to(self, *args, **kwargs):
+ from ..hooks.group_offloading import _is_group_offload_enabled
+
+ device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
  dtype_present_in_args = "dtype" in kwargs

+ # Try converting arguments to torch.device in case they are passed as strings
+ for arg in args:
+ if not isinstance(arg, str):
+ continue
+ try:
+ torch.device(arg)
+ device_arg_or_kwarg_present = True
+ except RuntimeError:
+ pass
+
  if not dtype_present_in_args:
  for arg in args:
  if isinstance(arg, torch.dtype):
@@ -1074,6 +1343,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
  f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
  )
+
+ if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
+ logger.warning(
+ f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
+ )
+ return self
+
  return super().to(*args, **kwargs)

  # Taken from `transformers`.
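The new guards in `cuda()` and `to()` exist because group offloading manages device placement through hooks, so a manual move would fight the offloader. A sketch of the interaction, assuming the `enable_group_offload()` helper this release adds to `ModelMixin` via `diffusers.hooks.group_offloading`; exact argument names may differ:

```python
# Sketch only: once group offloading is enabled, .to()/.cuda() log the warning added
# above and return the model unchanged instead of moving parameters.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
unet.enable_group_offload(               # assumed API, added elsewhere in this release
    onload_device=torch.device("cuda"),  # device groups are streamed onto for compute
    offload_device=torch.device("cpu"),  # device groups rest on between forward calls
    offload_type="block_level",
    num_blocks_per_group=1,
)
unet.to("cuda")  # no-op apart from the warning; placement stays hook-managed
```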
@@ -1103,54 +1379,127 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  cls,
  model,
  state_dict: OrderedDict,
- resolved_archive_file,
+ resolved_model_file: List[str],
  pretrained_model_name_or_path: Union[str, os.PathLike],
+ loaded_keys: List[str],
  ignore_mismatched_sizes: bool = False,
+ assign_to_params_buffers: bool = False,
+ hf_quantizer: Optional[DiffusersQuantizer] = None,
+ low_cpu_mem_usage: bool = True,
+ dtype: Optional[Union[str, torch.dtype]] = None,
+ keep_in_fp32_modules: Optional[List[str]] = None,
+ device_map: Dict[str, Union[int, str, torch.device]] = None,
+ offload_state_dict: Optional[bool] = None,
+ offload_folder: Optional[Union[str, os.PathLike]] = None,
+ dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
  ):
- # Retrieve missing & unexpected_keys
  model_state_dict = model.state_dict()
- loaded_keys = list(state_dict.keys())
-
  expected_keys = list(model_state_dict.keys())
-
- original_loaded_keys = loaded_keys
-
  missing_keys = list(set(expected_keys) - set(loaded_keys))
+ if hf_quantizer is not None:
+ missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
  unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+ # Some models may have keys that are not in the state by design, removing them before needlessly warning
+ # the user.
+ if cls._keys_to_ignore_on_load_unexpected is not None:
+ for pat in cls._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

- # Make sure we are able to load base models as well as derived models (with heads)
- model_to_load = model
+ mismatched_keys = []

- def _find_mismatched_keys(
- state_dict,
- model_state_dict,
- loaded_keys,
- ignore_mismatched_sizes,
- ):
- mismatched_keys = []
- if ignore_mismatched_sizes:
- for checkpoint_key in loaded_keys:
- model_key = checkpoint_key
-
- if (
- model_key in model_state_dict
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
- ):
- mismatched_keys.append(
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
- )
- del state_dict[checkpoint_key]
- return mismatched_keys
+ assign_to_params_buffers = None
+ error_msgs = []
+
+ # Deal with offload
+ if device_map is not None and "disk" in device_map.values():
+ if offload_folder is None:
+ raise ValueError(
+ "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
+ " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
+ " offers the weights in this format."
+ )
+ if offload_folder is not None:
+ os.makedirs(offload_folder, exist_ok=True)
+ if offload_state_dict is None:
+ offload_state_dict = True
+
+ offload_index = {} if device_map is not None and "disk" in device_map.values() else None
+ if offload_state_dict:
+ state_dict_folder = tempfile.mkdtemp()
+ state_dict_index = {}
+ else:
+ state_dict_folder = None
+ state_dict_index = None

  if state_dict is not None:
- # Whole checkpoint
- mismatched_keys = _find_mismatched_keys(
+ # load_state_dict will manage the case where we pass a dict instead of a file
+ # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
+ resolved_model_file = [state_dict]
+
+ if len(resolved_model_file) > 1:
+ resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+
+ for shard_file in resolved_model_file:
+ state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+
+ def _find_mismatched_keys(
  state_dict,
  model_state_dict,
- original_loaded_keys,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ ):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+ # If the checkpoint is sharded, we may not have the key here.
+ if checkpoint_key not in state_dict:
+ continue
+
+ if (
+ model_key in model_state_dict
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+ ):
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+ mismatched_keys += _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
  ignore_mismatched_sizes,
  )
- error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+ if low_cpu_mem_usage:
+ offload_index, state_dict_index = load_model_dict_into_meta(
+ model,
+ state_dict,
+ device_map=device_map,
+ dtype=dtype,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ unexpected_keys=unexpected_keys,
+ offload_folder=offload_folder,
+ offload_index=offload_index,
+ state_dict_index=state_dict_index,
+ state_dict_folder=state_dict_folder,
+ )
+ else:
+ if assign_to_params_buffers is None:
+ assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+ error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+
+ if offload_index is not None and len(offload_index) > 0:
+ save_offload_index(offload_index, offload_folder)
+ offload_index = None
+
+ if offload_state_dict:
+ load_offloaded_weights(model, state_dict_index, state_dict_folder)
+ shutil.rmtree(state_dict_folder)

  if len(error_msgs) > 0:
  error_msg = "\n\t".join(error_msgs)
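The loop above reads one shard at a time and either streams it onto the meta-initialized model (`low_cpu_mem_usage=True`) or assigns it in place, so the full checkpoint never has to be concatenated in memory. A simplified stand-in for that pattern, written with plain safetensors/PyTorch calls rather than diffusers internals:

```python
# Simplified illustration (not diffusers internals) of the per-shard loading pattern above.
import torch
from safetensors.torch import load_file

def load_sharded_checkpoint(model: torch.nn.Module, shard_files: list[str]) -> list[str]:
    error_msgs: list[str] = []
    for shard_file in shard_files:
        shard = load_file(shard_file)  # only this shard's tensors are resident at once
        # strict=False because a single shard only covers a subset of the model's keys;
        # assign=True (torch >= 2.1) reuses the checkpoint tensors instead of copying.
        incompatible = model.load_state_dict(shard, strict=False, assign=True)
        error_msgs += [f"unexpected key: {k}" for k in incompatible.unexpected_keys]
        del shard
    return error_msgs
```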
@@ -1162,17 +1511,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

  if len(unexpected_keys) > 0:
  logger.warning(
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
- f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
- f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
- " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
- " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
- f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
- " identical (initializing a BertForSequenceClassification model from a"
- " BertForSequenceClassification model)."
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
  )
  else:
  logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+
  if len(missing_keys) > 0:
  logger.warning(
  f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
@@ -1200,7 +1543,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  " able to use it for predictions and inference."
  )

- return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+ return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs

  @classmethod
  def _get_signature_keys(cls, obj):
@@ -1214,7 +1557,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  # Adapted from `transformers` modeling_utils.py
  def _get_no_split_modules(self, device_map: str):
  """
- Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
+ Get the modules of the model that should not be split when using device_map. We iterate through the modules to
  get the underlying `_no_split_modules`.

  Args:
@@ -1241,6 +1584,33 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  modules_to_check += list(module.children())
  return list(_no_split_modules)

+ @classmethod
+ def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
+ """
+ Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
+ under specific dtype.
+
+ Args:
+ dtype (`torch.dtype`):
+ a floating dtype to set to.
+
+ Returns:
+ `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
+ modified. If it wasn't, returns `None`.
+
+ Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
+ `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
+ """
+ if not dtype.is_floating_point:
+ raise ValueError(
+ f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
+ )
+
+ logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
+ dtype_orig = torch.get_default_dtype()
+ torch.set_default_dtype(dtype)
+ return dtype_orig
+
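A plain-PyTorch illustration of the save/restore contract that `_set_default_torch_dtype` documents above:

```python
import torch

dtype_orig = torch.get_default_dtype()   # usually torch.float32
torch.set_default_dtype(torch.float16)
layer = torch.nn.Linear(4, 4)            # parameters are created in float16
torch.set_default_dtype(dtype_orig)      # restore so later allocations are unaffected
assert layer.weight.dtype == torch.float16
```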
  @property
  def device(self) -> torch.device:
  """
@@ -1338,7 +1708,31 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  mem = mem + mem_bufs
  return mem

- def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
+ def _set_gradient_checkpointing(
+ self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
+ ) -> None:
+ is_gradient_checkpointing_set = False
+
+ for name, module in self.named_modules():
+ if hasattr(module, "gradient_checkpointing"):
+ logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
+ module._gradient_checkpointing_func = gradient_checkpointing_func
+ module.gradient_checkpointing = enable
+ is_gradient_checkpointing_set = True
+
+ if not is_gradient_checkpointing_set:
+ raise ValueError(
+ f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to "
+ f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
+ )
+
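`_set_gradient_checkpointing` is the private hook that the public `enable_gradient_checkpointing()` / `disable_gradient_checkpointing()` pair is expected to drive. A short training-style sketch; the checkpoint id is only an example:

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
unet.enable_gradient_checkpointing()  # sets `gradient_checkpointing = True` on supporting submodules
unet.train()

sample = torch.randn(1, 4, 64, 64, requires_grad=True)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)
out = unet(sample, timestep, encoder_hidden_states).sample
out.mean().backward()  # activations inside checkpointed blocks are recomputed here
```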
+ def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None:
+ """
+ This function fix the state dict of the model to take into account some changes that were made in the model
+ architecture:
+ - deprecated attention blocks (happened before we introduced sharded checkpoint,
+ so this is why we apply this method only when loading non sharded checkpoints for now)
+ """
  deprecated_attention_block_paths = []

  def recursive_find_attn_block(name, module):
@@ -1381,56 +1775,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
  if f"{path}.proj_attn.bias" in state_dict:
  state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
-
- def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
- deprecated_attention_block_modules = []
-
- def recursive_find_attn_block(module):
- if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
- deprecated_attention_block_modules.append(module)
-
- for sub_module in module.children():
- recursive_find_attn_block(sub_module)
-
- recursive_find_attn_block(self)
-
- for module in deprecated_attention_block_modules:
- module.query = module.to_q
- module.key = module.to_k
- module.value = module.to_v
- module.proj_attn = module.to_out[0]
-
- # We don't _have_ to delete the old attributes, but it's helpful to ensure
- # that _all_ the weights are loaded into the new attributes and we're not
- # making an incorrect assumption that this model should be converted when
- # it really shouldn't be.
- del module.to_q
- del module.to_k
- del module.to_v
- del module.to_out
-
- def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
- deprecated_attention_block_modules = []
-
- def recursive_find_attn_block(module) -> None:
- if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
- deprecated_attention_block_modules.append(module)
-
- for sub_module in module.children():
- recursive_find_attn_block(sub_module)
-
- recursive_find_attn_block(self)
-
- for module in deprecated_attention_block_modules:
- module.to_q = module.query
- module.to_k = module.key
- module.to_v = module.value
- module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
-
- del module.query
- del module.key
- del module.value
- del module.proj_attn
+ return state_dict
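A toy illustration of the kind of key rewrite `_fix_state_dict_keys_on_load` performs for checkpoints saved with the deprecated attention block naming; the paths below are made up:

```python
# Toy sketch: map deprecated attention parameter names onto the current ones.
old = {
    "mid_block.attentions.0.query.weight": 1,
    "mid_block.attentions.0.key.weight": 2,
    "mid_block.attentions.0.value.weight": 3,
    "mid_block.attentions.0.proj_attn.weight": 4,
}
renames = {"query": "to_q", "key": "to_k", "value": "to_v", "proj_attn": "to_out.0"}

new = {}
for k, v in old.items():
    prefix = k.rsplit(".", 2)[0]               # e.g. "mid_block.attentions.0"
    leaf, param = k.rsplit(".", 2)[1:]         # e.g. "query", "weight"
    new[f"{prefix}.{renames.get(leaf, leaf)}.{param}"] = v

assert "mid_block.attentions.0.to_q.weight" in new
assert "mid_block.attentions.0.to_out.0.weight" in new
```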


  class LegacyModelMixin(ModelMixin):