diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor
23
23
  from ...image_processor import VaeImageProcessor
24
24
  from ...models import AutoencoderKL, UNet2DConditionModel
25
25
  from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
26
- from ...utils import deprecate, logging
26
+ from ...utils import deprecate, is_torch_xla_available, logging
27
27
  from ...utils.torch_utils import randn_tensor
28
28
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
29
29
  from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -31,6 +31,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
31
31
  from .image_encoder import PaintByExampleImageEncoder
32
32
 
33
33
 
34
+ if is_torch_xla_available():
35
+ import torch_xla.core.xla_model as xm
36
+
37
+ XLA_AVAILABLE = True
38
+ else:
39
+ XLA_AVAILABLE = False
40
+
34
41
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
42
 
36
43
 
@@ -209,7 +216,7 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
209
216
  safety_checker=safety_checker,
210
217
  feature_extractor=feature_extractor,
211
218
  )
212
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
219
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
213
220
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
214
221
  self.register_to_config(requires_safety_checker=requires_safety_checker)
215
222
 
@@ -568,7 +575,7 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
568
575
  f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
569
576
  f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
570
577
  f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
571
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
578
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
572
579
  " `pipeline.unet` or your `mask_image` or `image` input."
573
580
  )
574
581
 
@@ -604,6 +611,9 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
604
611
  step_idx = i // getattr(self.scheduler, "order", 1)
605
612
  callback(step_idx, t, latents)
606
613
 
614
+ if XLA_AVAILABLE:
615
+ xm.mark_step()
616
+
607
617
  self.maybe_free_model_hooks()
608
618
 
609
619
  if not output_type == "latent":
@@ -37,6 +37,7 @@ from ...schedulers import (
37
37
  from ...utils import (
38
38
  USE_PEFT_BACKEND,
39
39
  BaseOutput,
40
+ is_torch_xla_available,
40
41
  logging,
41
42
  replace_example_docstring,
42
43
  scale_lora_layers,
@@ -48,8 +49,16 @@ from ..free_init_utils import FreeInitMixin
48
49
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
49
50
 
50
51
 
52
+ if is_torch_xla_available():
53
+ import torch_xla.core.xla_model as xm
54
+
55
+ XLA_AVAILABLE = True
56
+ else:
57
+ XLA_AVAILABLE = False
58
+
51
59
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
60
 
61
+
53
62
  EXAMPLE_DOC_STRING = """
54
63
  Examples:
55
64
  ```py
@@ -195,7 +204,7 @@ class PIAPipeline(
195
204
  feature_extractor=feature_extractor,
196
205
  image_encoder=image_encoder,
197
206
  )
198
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
207
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
199
208
  self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
200
209
 
201
210
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
@@ -928,6 +937,9 @@ class PIAPipeline(
928
937
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
929
938
  progress_bar.update()
930
939
 
940
+ if XLA_AVAILABLE:
941
+ xm.mark_step()
942
+
931
943
  # 9. Post processing
932
944
  if output_type == "latent":
933
945
  video = latents
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2024 The HuggingFace Inc. team.
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
3
  # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -237,15 +237,15 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
237
237
  If you get the error message below, you need to finetune the weights for your downstream task:
238
238
 
239
239
  ```
240
- Some weights of FlaxUNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
240
+ Some weights of FlaxUNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
241
241
  ```
242
242
 
243
243
  Parameters:
244
244
  pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
245
245
  Can be either:
246
246
 
247
- - A string, the *repo id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained pipeline
248
- hosted on the Hub.
247
+ - A string, the *repo id* (for example `stable-diffusion-v1-5/stable-diffusion-v1-5`) of a
248
+ pretrained pipeline hosted on the Hub.
249
249
  - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
250
250
  using [`~FlaxDiffusionPipeline.save_pretrained`].
251
251
  dtype (`str` or `jnp.dtype`, *optional*):
@@ -293,7 +293,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
293
293
  >>> # Requires to be logged in to Hugging Face hub,
294
294
  >>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens)
295
295
  >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained(
296
- ... "runwayml/stable-diffusion-v1-5",
296
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
297
297
  ... variant="bf16",
298
298
  ... dtype=jnp.bfloat16,
299
299
  ... )
@@ -301,7 +301,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
301
301
  >>> # Download pipeline, but use a different scheduler
302
302
  >>> from diffusers import FlaxDPMSolverMultistepScheduler
303
303
 
304
- >>> model_id = "runwayml/stable-diffusion-v1-5"
304
+ >>> model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
305
305
  >>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
306
306
  ... model_id,
307
307
  ... subfolder="scheduler",
@@ -559,7 +559,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
559
559
  ... )
560
560
 
561
561
  >>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
562
- ... "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16
562
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16
563
563
  ... )
564
564
  >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
565
565
  ```
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2024 The HuggingFace Inc. team.
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -12,19 +12,19 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
-
16
-
17
15
  import importlib
18
16
  import os
19
17
  import re
20
18
  import warnings
21
19
  from pathlib import Path
22
- from typing import Any, Dict, List, Optional, Union
20
+ from typing import Any, Callable, Dict, List, Optional, Union
23
21
 
22
+ import requests
24
23
  import torch
25
- from huggingface_hub import ModelCard, model_info
26
- from huggingface_hub.utils import validate_hf_hub_args
24
+ from huggingface_hub import DDUFEntry, ModelCard, model_info, snapshot_download
25
+ from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
27
26
  from packaging import version
27
+ from requests.exceptions import HTTPError
28
28
 
29
29
  from .. import __version__
30
30
  from ..utils import (
@@ -38,14 +38,16 @@ from ..utils import (
38
38
  is_accelerate_available,
39
39
  is_peft_available,
40
40
  is_transformers_available,
41
+ is_transformers_version,
41
42
  logging,
42
43
  )
43
44
  from ..utils.torch_utils import is_compiled_module
45
+ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transformers_model_from_dduf
44
46
 
45
47
 
46
48
  if is_transformers_available():
47
49
  import transformers
48
- from transformers import PreTrainedModel
50
+ from transformers import PreTrainedModel, PreTrainedTokenizerBase
49
51
  from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
50
52
  from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
51
53
  from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
@@ -102,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
102
104
  extension is replaced with ".safetensors"
103
105
  """
104
106
  passed_components = passed_components or []
105
- if folder_names is not None:
107
+ if folder_names:
106
108
  filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
107
109
 
108
110
  # extract all components of the pipeline and their associated files
@@ -139,7 +141,25 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
139
141
  return True
140
142
 
141
143
 
142
- def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
144
+ def filter_model_files(filenames):
145
+ """Filter model repo files for just files/folders that contain model weights"""
146
+ weight_names = [
147
+ WEIGHTS_NAME,
148
+ SAFETENSORS_WEIGHTS_NAME,
149
+ FLAX_WEIGHTS_NAME,
150
+ ONNX_WEIGHTS_NAME,
151
+ ONNX_EXTERNAL_WEIGHTS_NAME,
152
+ ]
153
+
154
+ if is_transformers_available():
155
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
156
+
157
+ allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
158
+
159
+ return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
160
+
161
+
162
+ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
143
163
  weight_names = [
144
164
  WEIGHTS_NAME,
145
165
  SAFETENSORS_WEIGHTS_NAME,
@@ -167,6 +187,10 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
167
187
  variant_index_re = re.compile(
168
188
  rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
169
189
  )
190
+ legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
191
+ legacy_variant_index_re = re.compile(
192
+ rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$"
193
+ )
170
194
 
171
195
  # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
172
196
  non_variant_file_re = re.compile(
@@ -175,54 +199,68 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
175
199
  # `text_encoder/pytorch_model.bin.index.json`
176
200
  non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
177
201
 
178
- if variant is not None:
179
- variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
180
- variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
181
- variant_filenames = variant_weights | variant_indexes
182
- else:
183
- variant_filenames = set()
202
+ def filter_for_compatible_extensions(filenames, ignore_patterns=None):
203
+ if not ignore_patterns:
204
+ return filenames
205
+
206
+ # ignore patterns uses glob style patterns e.g *.safetensors but we're only
207
+ # interested in the extension name
208
+ return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
209
+
210
+ def filter_with_regex(filenames, pattern_re):
211
+ return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
212
+
213
+ # Group files by component
214
+ components = {}
215
+ for filename in filenames:
216
+ if not len(filename.split("/")) == 2:
217
+ components.setdefault("", []).append(filename)
218
+ continue
184
219
 
185
- non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
186
- non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
187
- non_variant_filenames = non_variant_weights | non_variant_indexes
220
+ component, _ = filename.split("/")
221
+ components.setdefault(component, []).append(filename)
188
222
 
189
- # all variant filenames will be used by default
190
- usable_filenames = set(variant_filenames)
223
+ usable_filenames = set()
224
+ variant_filenames = set()
225
+ for component, component_filenames in components.items():
226
+ component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns)
227
+
228
+ component_variants = set()
229
+ component_legacy_variants = set()
230
+ component_non_variants = set()
231
+ if variant is not None:
232
+ component_variants = filter_with_regex(component_filenames, variant_file_re)
233
+ component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)
234
+
235
+ component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
236
+ component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re)
237
+
238
+ if component_variants or component_legacy_variants:
239
+ variant_filenames.update(
240
+ component_variants | component_variant_index_files
241
+ if component_variants
242
+ else component_legacy_variants | component_legacy_variant_index_files
243
+ )
191
244
 
192
- def convert_to_variant(filename):
193
- if "index" in filename:
194
- variant_filename = filename.replace("index", f"index.{variant}")
195
- elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
196
- variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
197
245
  else:
198
- variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
199
- return variant_filename
246
+ component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
247
+ component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re)
200
248
 
201
- def find_component(filename):
202
- if not len(filename.split("/")) == 2:
203
- return
204
- component = filename.split("/")[0]
205
- return component
206
-
207
- def has_sharded_variant(component, variant, variant_filenames):
208
- # If component exists check for sharded variant index filename
209
- # If component doesn't exist check main dir for sharded variant index filename
210
- component = component + "/" if component else ""
211
- variant_index_re = re.compile(
212
- rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
213
- )
214
- return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
249
+ usable_filenames.update(component_non_variants | component_variant_index_files)
215
250
 
216
- for filename in non_variant_filenames:
217
- if convert_to_variant(filename) in variant_filenames:
218
- continue
251
+ usable_filenames.update(variant_filenames)
219
252
 
220
- component = find_component(filename)
221
- # If a sharded variant exists skip adding to allowed patterns
222
- if has_sharded_variant(component, variant, variant_filenames):
223
- continue
253
+ if len(variant_filenames) == 0 and variant is not None:
254
+ error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. "
255
+ raise ValueError(error_message)
224
256
 
225
- usable_filenames.add(filename)
257
+ if len(variant_filenames) > 0 and usable_filenames != variant_filenames:
258
+ logger.warning(
259
+ f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
260
+ f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n"
261
+ f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not "
262
+ f"expected, please check your folder structure."
263
+ )
226
264
 
227
265
  return usable_filenames, variant_filenames
228
266
 
@@ -285,9 +323,7 @@ def maybe_raise_or_warn(
285
323
  model_cls = unwrapped_sub_model.__class__
286
324
 
287
325
  if not issubclass(model_cls, expected_class_obj):
288
- raise ValueError(
289
- f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
290
- )
326
+ raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
291
327
  else:
292
328
  logger.warning(
293
329
  f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
@@ -554,6 +590,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
554
590
  loaded_sub_model = passed_class_obj[name]
555
591
 
556
592
  else:
593
+ sub_model_dtype = (
594
+ torch_dtype.get(name, torch_dtype.get("default", torch.float32))
595
+ if isinstance(torch_dtype, dict)
596
+ else torch_dtype
597
+ )
557
598
  loaded_sub_model = _load_empty_model(
558
599
  library_name=library_name,
559
600
  class_name=class_name,
@@ -562,7 +603,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
562
603
  is_pipeline_module=is_pipeline_module,
563
604
  pipeline_class=pipeline_class,
564
605
  name=name,
565
- torch_dtype=torch_dtype,
606
+ torch_dtype=sub_model_dtype,
566
607
  cached_folder=kwargs.get("cached_folder", None),
567
608
  force_download=kwargs.get("force_download", None),
568
609
  proxies=kwargs.get("proxies", None),
@@ -578,7 +619,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
578
619
  # Obtain a sorted dictionary for mapping the model-level components
579
620
  # to their sizes.
580
621
  module_sizes = {
581
- module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
622
+ module_name: compute_module_sizes(
623
+ module,
624
+ dtype=torch_dtype.get(module_name, torch_dtype.get("default", torch.float32))
625
+ if isinstance(torch_dtype, dict)
626
+ else torch_dtype,
627
+ )[""]
582
628
  for module_name, module in init_empty_modules.items()
583
629
  if isinstance(module, torch.nn.Module)
584
630
  }
@@ -627,6 +673,8 @@ def load_sub_model(
627
673
  low_cpu_mem_usage: bool,
628
674
  cached_folder: Union[str, os.PathLike],
629
675
  use_safetensors: bool,
676
+ dduf_entries: Optional[Dict[str, DDUFEntry]],
677
+ provider_options: Any,
630
678
  ):
631
679
  """Helper method to load the module `name` from `library_name` and `class_name`"""
632
680
 
@@ -663,7 +711,7 @@ def load_sub_model(
663
711
  f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
664
712
  )
665
713
 
666
- load_method = getattr(class_obj, load_method_name)
714
+ load_method = _get_load_method(class_obj, load_method_name, is_dduf=dduf_entries is not None)
667
715
 
668
716
  # add kwargs to loading method
669
717
  diffusers_module = importlib.import_module(__name__.split(".")[0])
@@ -673,6 +721,7 @@ def load_sub_model(
673
721
  if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
674
722
  loading_kwargs["provider"] = provider
675
723
  loading_kwargs["sess_options"] = sess_options
724
+ loading_kwargs["provider_options"] = provider_options
676
725
 
677
726
  is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
678
727
 
@@ -721,7 +770,10 @@ def load_sub_model(
721
770
  loading_kwargs["low_cpu_mem_usage"] = False
722
771
 
723
772
  # check if the module is in a subdirectory
724
- if os.path.isdir(os.path.join(cached_folder, name)):
773
+ if dduf_entries:
774
+ loading_kwargs["dduf_entries"] = dduf_entries
775
+ loaded_sub_model = load_method(name, **loading_kwargs)
776
+ elif os.path.isdir(os.path.join(cached_folder, name)):
725
777
  loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
726
778
  else:
727
779
  # else load from the root directory
@@ -746,6 +798,22 @@ def load_sub_model(
746
798
  return loaded_sub_model
747
799
 
748
800
 
801
+ def _get_load_method(class_obj: object, load_method_name: str, is_dduf: bool) -> Callable:
802
+ """
803
+ Return the method to load the sub model.
804
+
805
+ In practice, this method will return the `"from_pretrained"` (or `load_method_name`) method of the class object
806
+ except if loading from a DDUF checkpoint. In that case, transformers models and tokenizers have a specific loading
807
+ method that we need to use.
808
+ """
809
+ if is_dduf:
810
+ if issubclass(class_obj, PreTrainedTokenizerBase):
811
+ return lambda *args, **kwargs: _load_tokenizer_from_dduf(class_obj, *args, **kwargs)
812
+ if issubclass(class_obj, PreTrainedModel):
813
+ return lambda *args, **kwargs: _load_transformers_model_from_dduf(class_obj, *args, **kwargs)
814
+ return getattr(class_obj, load_method_name)
815
+
816
+
749
817
  def _fetch_class_library_tuple(module):
750
818
  # import it here to avoid circular import
751
819
  diffusers_module = importlib.import_module(__name__.split(".")[0])
@@ -813,9 +881,9 @@ def _maybe_raise_warning_for_inpainting(pipeline_class, pretrained_model_name_or
813
881
  "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
814
882
  f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
815
883
  " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
816
- " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
884
+ " checkpoint: https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting instead or adapting your"
817
885
  f" checkpoint {pretrained_model_name_or_path} to the format of"
818
- " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
886
+ " https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting. Note that we do not actively maintain"
819
887
  " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
820
888
  )
821
889
  deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
@@ -898,10 +966,6 @@ def _get_custom_components_and_folders(
898
966
  f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
899
967
  )
900
968
 
901
- if len(variant_filenames) == 0 and variant is not None:
902
- error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
903
- raise ValueError(error_message)
904
-
905
969
  return custom_components, folder_names
906
970
 
907
971
 
@@ -909,7 +973,6 @@ def _get_ignore_patterns(
909
973
  passed_components,
910
974
  model_folder_names: List[str],
911
975
  model_filenames: List[str],
912
- variant_filenames: List[str],
913
976
  use_safetensors: bool,
914
977
  from_flax: bool,
915
978
  allow_pickle: bool,
@@ -940,16 +1003,6 @@ def _get_ignore_patterns(
940
1003
  if not use_onnx:
941
1004
  ignore_patterns += ["*.onnx", "*.pb"]
942
1005
 
943
- safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
944
- safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
945
- if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
946
- logger.warning(
947
- f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
948
- f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
949
- f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
950
- f"expected, please check your folder structure."
951
- )
952
-
953
1006
  else:
954
1007
  ignore_patterns = ["*.safetensors", "*.msgpack"]
955
1008
 
@@ -957,14 +1010,71 @@ def _get_ignore_patterns(
957
1010
  if not use_onnx:
958
1011
  ignore_patterns += ["*.onnx", "*.pb"]
959
1012
 
960
- bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
961
- bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
962
- if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
963
- logger.warning(
964
- f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
965
- f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
966
- f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
967
- f"your folder structure."
968
- )
969
-
970
1013
  return ignore_patterns
1014
+
1015
+
1016
+ def _download_dduf_file(
1017
+ pretrained_model_name: str,
1018
+ dduf_file: str,
1019
+ pipeline_class_name: str,
1020
+ cache_dir: str,
1021
+ proxies: str,
1022
+ local_files_only: bool,
1023
+ token: str,
1024
+ revision: str,
1025
+ ):
1026
+ model_info_call_error = None
1027
+ if not local_files_only:
1028
+ try:
1029
+ info = model_info(pretrained_model_name, token=token, revision=revision)
1030
+ except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
1031
+ logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
1032
+ local_files_only = True
1033
+ model_info_call_error = e # save error to reraise it if model is not cached locally
1034
+
1035
+ if (
1036
+ not local_files_only
1037
+ and dduf_file is not None
1038
+ and dduf_file not in (sibling.rfilename for sibling in info.siblings)
1039
+ ):
1040
+ raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.")
1041
+
1042
+ try:
1043
+ user_agent = {"pipeline_class": pipeline_class_name, "dduf": True}
1044
+ cached_folder = snapshot_download(
1045
+ pretrained_model_name,
1046
+ cache_dir=cache_dir,
1047
+ proxies=proxies,
1048
+ local_files_only=local_files_only,
1049
+ token=token,
1050
+ revision=revision,
1051
+ allow_patterns=[dduf_file],
1052
+ user_agent=user_agent,
1053
+ )
1054
+ return cached_folder
1055
+ except FileNotFoundError:
1056
+ # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache.
1057
+ # This can happen in two cases:
1058
+ # 1. If the user passed `local_files_only=True` => we raise the error directly
1059
+ # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error
1060
+ if model_info_call_error is None:
1061
+ # 1. user passed `local_files_only=True`
1062
+ raise
1063
+ else:
1064
+ # 2. we forced `local_files_only=True` when `model_info` failed
1065
+ raise EnvironmentError(
1066
+ f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred"
1067
+ " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
1068
+ " above."
1069
+ ) from model_info_call_error
1070
+
1071
+
1072
+ def _maybe_raise_error_for_incorrect_transformers(config_dict):
1073
+ has_transformers_component = False
1074
+ for k in config_dict:
1075
+ if isinstance(config_dict[k], list):
1076
+ has_transformers_component = config_dict[k][0] == "transformers"
1077
+ if has_transformers_component:
1078
+ break
1079
+ if has_transformers_component and not is_transformers_version(">", "4.47.1"):
1080
+ raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")