diffusers 0.33.0__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +17 -12
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. diffusers-0.33.0.dist-info/RECORD +0 -608
  475. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  476. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
  477. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,6 @@ import html
16
16
  import inspect
17
17
  from typing import Any, Callable, Dict, List, Optional, Union
18
18
 
19
- import ftfy
20
19
  import regex as re
21
20
  import torch
22
21
  from PIL import Image
@@ -26,7 +25,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
26
25
  from ...loaders import WanLoraLoaderMixin
27
26
  from ...models import AutoencoderKLWan, WanTransformer3DModel
28
27
  from ...schedulers import FlowMatchEulerDiscreteScheduler
29
- from ...utils import is_torch_xla_available, logging, replace_example_docstring
28
+ from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
30
29
  from ...utils.torch_utils import randn_tensor
31
30
  from ...video_processor import VideoProcessor
32
31
  from ..pipeline_utils import DiffusionPipeline
@@ -42,6 +41,9 @@ else:
42
41
 
43
42
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
43
 
44
+ if is_ftfy_available():
45
+ import ftfy
46
+
45
47
 
46
48
  EXAMPLE_DOC_STRING = """
47
49
  Examples:
@@ -417,12 +419,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
417
419
  )
418
420
 
419
421
  if latents is None:
420
- if isinstance(generator, list):
421
- init_latents = [
422
- retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
423
- ]
424
- else:
425
- init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
422
+ init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video]
426
423
 
427
424
  init_latents = torch.cat(init_latents, dim=0).to(dtype)
428
425
 
@@ -439,7 +436,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
439
436
  if hasattr(self.scheduler, "add_noise"):
440
437
  latents = self.scheduler.add_noise(init_latents, noise, timestep)
441
438
  else:
442
- latents = self.scheduelr.scale_noise(init_latents, timestep, noise)
439
+ latents = self.scheduler.scale_noise(init_latents, timestep, noise)
443
440
  else:
444
441
  latents = latents.to(device)
445
442
 
@@ -511,7 +508,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
511
508
 
512
509
  Args:
513
510
  prompt (`str` or `List[str]`, *optional*):
514
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
511
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
515
512
  instead.
516
513
  height (`int`, defaults to `480`):
517
514
  The height in pixels of the generated image.
@@ -523,11 +520,13 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
523
520
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
524
521
  expense of slower inference.
525
522
  guidance_scale (`float`, defaults to `5.0`):
526
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
527
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
528
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
529
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
530
- usually at the expense of lower image quality.
523
+ Guidance scale as defined in [Classifier-Free Diffusion
524
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
525
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
526
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
527
+ the text `prompt`, usually at the expense of lower image quality.
528
+ strength (`float`, defaults to `0.8`):
529
+ Higher strength leads to more differences between original image and generated video.
531
530
  num_videos_per_prompt (`int`, *optional*, defaults to 1):
532
531
  The number of images to generate per prompt.
533
532
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -540,7 +539,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
540
539
  prompt_embeds (`torch.Tensor`, *optional*):
541
540
  Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
542
541
  provided, text embeddings are generated from the `prompt` input argument.
543
- output_type (`str`, *optional*, defaults to `"pil"`):
542
+ output_type (`str`, *optional*, defaults to `"np"`):
544
543
  The output format of the generated image. Choose between `PIL.Image` or `np.array`.
545
544
  return_dict (`bool`, *optional*, defaults to `True`):
546
545
  Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
@@ -557,8 +556,9 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
557
556
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
558
557
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
559
558
  `._callback_tensor_inputs` attribute of your pipeline class.
560
- autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
561
- The dtype to use for the torch.amp.autocast.
559
+ max_sequence_length (`int`, defaults to `512`):
560
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
561
+ truncated. If the prompt is shorter, it will be padded to this length.
562
562
 
563
563
  Examples:
564
564
 
@@ -1,5 +1,5 @@
1
1
  # Copyright (c) 2022 Dominic Rampas MIT License
2
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
1
1
  # Copyright (c) 2023 Dominic Rampas MIT License
2
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
1
1
  # Copyright (c) 2023 Dominic Rampas MIT License
2
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
21
21
  from ...schedulers import DDPMWuerstchenScheduler
22
22
  from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
23
23
  from ...utils.torch_utils import randn_tensor
24
- from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
+ from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
25
25
  from .modeling_paella_vq_model import PaellaVQModel
26
26
  from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
27
27
 
@@ -56,7 +56,7 @@ EXAMPLE_DOC_STRING = """
56
56
  """
57
57
 
58
58
 
59
- class WuerstchenDecoderPipeline(DiffusionPipeline):
59
+ class WuerstchenDecoderPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
60
60
  """
61
61
  Pipeline for generating images from the Wuerstchen model.
62
62
 
@@ -247,11 +247,11 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
247
247
  Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
248
248
  timesteps are used. Must be in descending order.
249
249
  guidance_scale (`float`, *optional*, defaults to 0.0):
250
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
251
- `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
252
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
253
- `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely
254
- linked to the text `prompt`, usually at the expense of lower image quality.
250
+ Guidance scale as defined in [Classifier-Free Diffusion
251
+ Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of
252
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by
253
+ setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are
254
+ closely linked to the text `prompt`, usually at the expense of lower image quality.
255
255
  negative_prompt (`str` or `List[str]`, *optional*):
256
256
  The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
257
257
  if `decoder_guidance_scale` is less than `1`).
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
18
18
 
19
19
  from ...schedulers import DDPMWuerstchenScheduler
20
20
  from ...utils import deprecate, replace_example_docstring
21
- from ..pipeline_utils import DiffusionPipeline
21
+ from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
22
22
  from .modeling_paella_vq_model import PaellaVQModel
23
23
  from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
24
24
  from .modeling_wuerstchen_prior import WuerstchenPrior
@@ -40,7 +40,7 @@ TEXT2IMAGE_EXAMPLE_DOC_STRING = """
40
40
  """
41
41
 
42
42
 
43
- class WuerstchenCombinedPipeline(DiffusionPipeline):
43
+ class WuerstchenCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
44
44
  """
45
45
  Combined Pipeline for text-to-image generation using Wuerstchen
46
46
 
@@ -68,6 +68,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
68
68
  The scheduler to be used for prior pipeline.
69
69
  """
70
70
 
71
+ _last_supported_version = "0.33.1"
71
72
  _load_connected_pipes = True
72
73
 
73
74
  def __init__(
@@ -112,7 +113,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
112
113
  def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
113
114
  self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
114
115
 
115
- def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
116
+ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
116
117
  r"""
117
118
  Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
118
119
  to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -122,7 +123,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
122
123
  self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
123
124
  self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
124
125
 
125
- def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
126
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
126
127
  r"""
127
128
  Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
128
129
  Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
@@ -190,11 +191,11 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
190
191
  width (`int`, *optional*, defaults to 512):
191
192
  The width in pixels of the generated image.
192
193
  prior_guidance_scale (`float`, *optional*, defaults to 4.0):
193
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
194
- `prior_guidance_scale` is defined as `w` of equation 2. of [Imagen
195
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
196
- `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
197
- to the text `prompt`, usually at the expense of lower image quality.
194
+ Guidance scale as defined in [Classifier-Free Diffusion
195
+ Guidance](https://huggingface.co/papers/2207.12598). `prior_guidance_scale` is defined as `w` of
196
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by
197
+ setting `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are
198
+ closely linked to the text `prompt`, usually at the expense of lower image quality.
198
199
  prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60):
199
200
  The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
200
201
  expense of slower inference. For more specific timestep spacing, you can pass customized
@@ -210,11 +211,11 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
210
211
  Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced
211
212
  `num_inference_steps` timesteps are used. Must be in descending order.
212
213
  decoder_guidance_scale (`float`, *optional*, defaults to 0.0):
213
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
214
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
215
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
216
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
217
- usually at the expense of lower image quality.
214
+ Guidance scale as defined in [Classifier-Free Diffusion
215
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
216
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
217
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
218
+ the text `prompt`, usually at the expense of lower image quality.
218
219
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
219
220
  One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
220
221
  to make generation deterministic.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -325,11 +325,11 @@ class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin)
325
325
  Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
326
326
  timesteps are used. Must be in descending order.
327
327
  guidance_scale (`float`, *optional*, defaults to 8.0):
328
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
329
- `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
330
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
331
- `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely
332
- linked to the text `prompt`, usually at the expense of lower image quality.
328
+ Guidance scale as defined in [Classifier-Free Diffusion
329
+ Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of
330
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by
331
+ setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are
332
+ closely linked to the text `prompt`, usually at the expense of lower image quality.
333
333
  negative_prompt (`str` or `List[str]`, *optional*):
334
334
  The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
335
335
  if `decoder_guidance_scale` is less than `1`).
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,5 +12,183 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import inspect
16
+ from typing import Dict, List, Optional, Union
17
+
18
+ from ..utils import is_transformers_available, logging
15
19
  from .auto import DiffusersAutoQuantizer
16
20
  from .base import DiffusersQuantizer
21
+ from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
22
+
23
+
24
+ try:
25
+ from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
26
+ except ImportError:
27
+
28
+ class TransformersQuantConfigMixin:
29
+ pass
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class PipelineQuantizationConfig:
36
+ """
37
+ Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
38
+
39
+ Args:
40
+ quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
41
+ is available to both `diffusers` and `transformers`.
42
+ quant_kwargs (`dict`): Params to initialize the quantization backend class.
43
+ components_to_quantize (`list`): Components of a pipeline to be quantized.
44
+ quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
45
+ components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`,
46
+ and `components_to_quantize`.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ quant_backend: str = None,
52
+ quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
53
+ components_to_quantize: Optional[List[str]] = None,
54
+ quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
55
+ ):
56
+ self.quant_backend = quant_backend
57
+ # Initialize kwargs to be {} to set to the defaults.
58
+ self.quant_kwargs = quant_kwargs or {}
59
+ self.components_to_quantize = components_to_quantize
60
+ self.quant_mapping = quant_mapping
61
+
62
+ self.post_init()
63
+
64
+ def post_init(self):
65
+ quant_mapping = self.quant_mapping
66
+ self.is_granular = True if quant_mapping is not None else False
67
+
68
+ self._validate_init_args()
69
+
70
+ def _validate_init_args(self):
71
+ if self.quant_backend and self.quant_mapping:
72
+ raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
73
+
74
+ if not self.quant_mapping and not self.quant_backend:
75
+ raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
76
+
77
+ if not self.quant_kwargs and not self.quant_mapping:
78
+ raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
79
+
80
+ if self.quant_backend is not None:
81
+ self._validate_init_kwargs_in_backends()
82
+
83
+ if self.quant_mapping is not None:
84
+ self._validate_quant_mapping_args()
85
+
86
+ def _validate_init_kwargs_in_backends(self):
87
+ quant_backend = self.quant_backend
88
+
89
+ self._check_backend_availability(quant_backend)
90
+
91
+ quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
92
+
93
+ if quant_config_mapping_transformers is not None:
94
+ init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
95
+ init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
96
+ else:
97
+ init_kwargs_transformers = None
98
+
99
+ init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
100
+ init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
101
+
102
+ if init_kwargs_transformers != init_kwargs_diffusers:
103
+ raise ValueError(
104
+ "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
105
+ f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
106
+ "this mapping would look like."
107
+ )
108
+
109
+ def _validate_quant_mapping_args(self):
110
+ quant_mapping = self.quant_mapping
111
+ transformers_map, diffusers_map = self._get_quant_config_list()
112
+
113
+ available_transformers = list(transformers_map.values()) if transformers_map else None
114
+ available_diffusers = list(diffusers_map.values())
115
+
116
+ for module_name, config in quant_mapping.items():
117
+ if any(isinstance(config, cfg) for cfg in available_diffusers):
118
+ continue
119
+
120
+ if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
121
+ continue
122
+
123
+ if available_transformers:
124
+ raise ValueError(
125
+ f"Provided config for module_name={module_name} could not be found. "
126
+ f"Available diffusers configs: {available_diffusers}; "
127
+ f"Available transformers configs: {available_transformers}."
128
+ )
129
+ else:
130
+ raise ValueError(
131
+ f"Provided config for module_name={module_name} could not be found. "
132
+ f"Available diffusers configs: {available_diffusers}."
133
+ )
134
+
135
+ def _check_backend_availability(self, quant_backend: str):
136
+ quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
137
+
138
+ available_backends_transformers = (
139
+ list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
140
+ )
141
+ available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
142
+
143
+ if (
144
+ available_backends_transformers and quant_backend not in available_backends_transformers
145
+ ) or quant_backend not in quant_config_mapping_diffusers:
146
+ error_message = f"Provided quant_backend={quant_backend} was not found."
147
+ if available_backends_transformers:
148
+ error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
149
+ error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
150
+ raise ValueError(error_message)
151
+
152
+ def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
153
+ quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
154
+
155
+ quant_mapping = self.quant_mapping
156
+ components_to_quantize = self.components_to_quantize
157
+
158
+ # Granular case
159
+ if self.is_granular and module_name in quant_mapping:
160
+ logger.debug(f"Initializing quantization config class for {module_name}.")
161
+ config = quant_mapping[module_name]
162
+ return config
163
+
164
+ # Global config case
165
+ else:
166
+ should_quantize = False
167
+ # Only quantize the modules requested for.
168
+ if components_to_quantize and module_name in components_to_quantize:
169
+ should_quantize = True
170
+ # No specification for `components_to_quantize` means all modules should be quantized.
171
+ elif not self.is_granular and not components_to_quantize:
172
+ should_quantize = True
173
+
174
+ if should_quantize:
175
+ logger.debug(f"Initializing quantization config class for {module_name}.")
176
+ mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
177
+ quant_config_cls = mapping_to_use[self.quant_backend]
178
+ quant_kwargs = self.quant_kwargs
179
+ return quant_config_cls(**quant_kwargs)
180
+
181
+ # Fallback: no applicable configuration found.
182
+ return None
183
+
184
+ def _get_quant_config_list(self):
185
+ if is_transformers_available():
186
+ from transformers.quantizers.auto import (
187
+ AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
188
+ )
189
+ else:
190
+ quant_config_mapping_transformers = None
191
+
192
+ from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
193
+
194
+ return quant_config_mapping_transformers, quant_config_mapping_diffusers
@@ -199,7 +199,7 @@ class DiffusersQuantizer(ABC):
199
199
 
200
200
  def dequantize(self, model):
201
201
  """
202
- Potentially dequantize the model to retrive the original model, with some loss in accuracy / performance. Note
202
+ Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance. Note
203
203
  not all quantization schemes support this.
204
204
  """
205
205
  model = self._dequantize(model)
@@ -227,3 +227,8 @@ class DiffusersQuantizer(ABC):
227
227
  @property
228
228
  @abstractmethod
229
229
  def is_trainable(self): ...
230
+
231
+ @property
232
+ def is_compileable(self) -> bool:
233
+ """Flag indicating whether the quantized model can be compiled"""
234
+ return False
@@ -564,6 +564,10 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
564
564
  # Because we're mandating `bitsandbytes` 0.43.3.
565
565
  return True
566
566
 
567
+ @property
568
+ def is_compileable(self) -> bool:
569
+ return True
570
+
567
571
  def _dequantize(self, model):
568
572
  from .utils import dequantize_and_replace
569
573
 
@@ -49,7 +49,7 @@ def _replace_with_bnb_linear(
49
49
  """
50
50
  Private method that wraps the recursion for module replacement.
51
51
 
52
- Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
52
+ Returns the converted model and a boolean that indicates if the conversion has been successful or not.
53
53
  """
54
54
  for name, module in model.named_children():
55
55
  if current_key_name is None:
@@ -121,8 +121,9 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
121
121
 
122
122
  References:
123
123
  * `bnb.nn.Linear8bit`: [LLM.int8(): 8-bit Matrix Multiplication for Transformers at
124
- Scale](https://arxiv.org/abs/2208.07339)
125
- * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
124
+ Scale](https://huggingface.co/papers/2208.07339)
125
+ * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized
126
+ LLMs](https://huggingface.co/papers/2305.14314)
126
127
 
127
128
  Parameters:
128
129
  model (`torch.nn.Module`):
@@ -171,9 +172,11 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torc
171
172
 
172
173
  if cls_name == "Params4bit":
173
174
  output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
174
- logger.warning_once(
175
- f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
176
- )
175
+ msg = f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
176
+ if dtype:
177
+ msg = f"The model is going to be first dequantized in {output_tensor.dtype} and type-casted to {dtype}"
178
+ output_tensor = output_tensor.to(dtype)
179
+ logger.warning_once(msg)
177
180
  return output_tensor
178
181
 
179
182
  if state.SCB is None:
@@ -221,7 +224,7 @@ def _dequantize_and_replace(
221
224
  performance drop compared to the original model before quantization - use it only for specific usecases such as
222
225
  QLoRA adapters merging.
223
226
 
224
- Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
227
+ Returns the converted model and a boolean that indicates if the conversion has been successful or not.
225
228
  """
226
229
  quant_method = quantization_config.quantization_method()
227
230
 
@@ -49,7 +49,7 @@ class GGUFQuantizer(DiffusersQuantizer):
49
49
  def validate_environment(self, *args, **kwargs):
50
50
  if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
51
51
  raise ImportError(
52
- "Loading GGUF Parameters requires `accelerate` installed in your enviroment: `pip install 'accelerate>=0.26.0'`"
52
+ "Loading GGUF Parameters requires `accelerate` installed in your environment: `pip install 'accelerate>=0.26.0'`"
53
53
  )
54
54
  if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
55
55
  raise ImportError(
@@ -82,7 +82,7 @@ class GGUFQuantizer(DiffusersQuantizer):
82
82
  inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
83
83
  if inferred_shape != current_param_shape:
84
84
  raise ValueError(
85
- f"{param_name} has an expected quantized shape of: {inferred_shape}, but receieved shape: {loaded_param_shape}"
85
+ f"{param_name} has an expected quantized shape of: {inferred_shape}, but received shape: {loaded_param_shape}"
86
86
  )
87
87
 
88
88
  return True
@@ -146,13 +146,22 @@ class GGUFQuantizer(DiffusersQuantizer):
146
146
  def is_trainable(self) -> bool:
147
147
  return False
148
148
 
149
+ @property
150
+ def is_compileable(self) -> bool:
151
+ return True
152
+
149
153
  def _dequantize(self, model):
150
154
  is_model_on_cpu = model.device.type == "cpu"
151
155
  if is_model_on_cpu:
152
156
  logger.info(
153
- "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
157
+ "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to accelerator. After dequantization, will move the model back to CPU again to preserve the previous device."
158
+ )
159
+ device = (
160
+ torch.accelerator.current_accelerator()
161
+ if hasattr(torch, "accelerator")
162
+ else torch.cuda.current_device()
154
163
  )
155
- model.to(torch.cuda.current_device())
164
+ model.to(device)
156
165
 
157
166
  model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
158
167
  if is_model_on_cpu:
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team and City96. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team and City96. All rights reserved.
2
2
  # #
3
3
  # # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # # you may not use this file except in compliance with the License.
@@ -408,6 +408,18 @@ class GGUFParameter(torch.nn.Parameter):
408
408
  def as_tensor(self):
409
409
  return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
410
410
 
411
+ @staticmethod
412
+ def _extract_quant_type(args):
413
+ # When converting from original format checkpoints we often use splits, cats etc on tensors
414
+ # this method ensures that the returned tensor type from those operations remains GGUFParameter
415
+ # so that we preserve quant_type information
416
+ for arg in args:
417
+ if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
418
+ return arg[0].quant_type
419
+ if isinstance(arg, GGUFParameter):
420
+ return arg.quant_type
421
+ return None
422
+
411
423
  @classmethod
412
424
  def __torch_function__(cls, func, types, args=(), kwargs=None):
413
425
  if kwargs is None:
@@ -415,22 +427,13 @@ class GGUFParameter(torch.nn.Parameter):
415
427
 
416
428
  result = super().__torch_function__(func, types, args, kwargs)
417
429
 
418
- # When converting from original format checkpoints we often use splits, cats etc on tensors
419
- # this method ensures that the returned tensor type from those operations remains GGUFParameter
420
- # so that we preserve quant_type information
421
- quant_type = None
422
- for arg in args:
423
- if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
424
- quant_type = arg[0].quant_type
425
- break
426
- if isinstance(arg, GGUFParameter):
427
- quant_type = arg.quant_type
428
- break
429
430
  if isinstance(result, torch.Tensor):
431
+ quant_type = cls._extract_quant_type(args)
430
432
  return cls(result, quant_type=quant_type)
431
433
  # Handle tuples and lists
432
- elif isinstance(result, (tuple, list)):
434
+ elif type(result) in (list, tuple):
433
435
  # Preserve the original type (tuple or list)
436
+ quant_type = cls._extract_quant_type(args)
434
437
  wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
435
438
  return type(result)(wrapped)
436
439
  else: