diffusers 0.33.1__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in those public registries.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +13 -10
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +38 -18
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/METADATA +70 -55
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/WHEEL +1 -1
  475. diffusers-0.33.1.dist-info/RECORD +0 -608
  476. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  477. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
diffusers/models/controlnets/controlnet_flux.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,12 +20,12 @@ import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...models.attention_processor import AttentionProcessor
-from ...models.modeling_utils import ModelMixin
 from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import AttentionProcessor
 from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
 from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
 
 
@@ -430,7 +430,7 @@ class FluxMultiControlNetModel(ModelMixin):
     ) -> Union[FluxControlNetOutput, Tuple]:
         # ControlNet-Union with multiple conditions
         # only load one ControlNet for saving memories
-        if len(self.nets) == 1 and self.nets[0].union:
+        if len(self.nets) == 1:
             controlnet = self.nets[0]
 
             for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
@@ -454,17 +454,18 @@ class FluxMultiControlNetModel(ModelMixin):
                     control_block_samples = block_samples
                     control_single_block_samples = single_block_samples
                 else:
-                    control_block_samples = [
-                        control_block_sample + block_sample
-                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
-                    ]
-
-                    control_single_block_samples = [
-                        control_single_block_sample + block_sample
-                        for control_single_block_sample, block_sample in zip(
-                            control_single_block_samples, single_block_samples
-                        )
-                    ]
+                    if block_samples is not None and control_block_samples is not None:
+                        control_block_samples = [
+                            control_block_sample + block_sample
+                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+                        ]
+                    if single_block_samples is not None and control_single_block_samples is not None:
+                        control_single_block_samples = [
+                            control_single_block_sample + block_sample
+                            for control_single_block_sample, block_sample in zip(
+                                control_single_block_samples, single_block_samples
+                            )
+                        ]
 
             # Regular Multi-ControlNets
             # load all ControlNets into memories
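Note on the `FluxMultiControlNetModel` hunks above: dropping the `self.nets[0].union` check lets a single wrapped ControlNet, union or not, take the single-net path, and the new `None` guards skip residual merging when a net returns no block samples. A minimal usage sketch under assumptions — the checkpoint names and `control_mode` ids are illustrative, not taken from this diff:

import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline, FluxMultiControlNetModel
from diffusers.utils import load_image

# Placeholder conditioning images; substitute real ones.
canny_image = load_image("canny.png")
depth_image = load_image("depth.png")

controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Union", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=FluxMultiControlNetModel([controlnet]),  # one net, several conditions
    torch_dtype=torch.bfloat16,
)
image = pipe(
    "a photo of a cat",
    control_image=[canny_image, depth_image],  # one image per condition
    control_mode=[0, 2],  # union mode ids; illustrative values
    controlnet_conditioning_scale=[0.6, 0.6],
).images[0]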
diffusers/models/controlnets/controlnet_hunyuan.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
+# Copyright 2025 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -103,7 +103,7 @@ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
                     activation_fn=activation_fn,
                     ff_inner_dim=int(self.inner_dim * mlp_ratio),
                     cross_attention_dim=cross_attention_dim,
-                    qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+                    qk_norm=True,  # See https://huggingface.co/papers/2302.05442 for details.
                     skip=False,  # always False as it is the first half of the model
                 )
                 for layer in range(transformer_num_layers // 2 - 1)
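Side note on the `qk_norm=True` flag touched above: query-key normalization applies a norm to Q and K before the attention product to keep the logits bounded. A minimal illustrative sketch of the idea, not the diffusers implementation:

import torch
from torch import nn
import torch.nn.functional as F


class QKNormAttention(nn.Module):
    # Illustrative only: LayerNorm on queries and keys before the dot product,
    # per the linked paper (ViT-22B, https://huggingface.co/papers/2302.05442).
    def __init__(self, head_dim: int):
        super().__init__()
        self.q_norm = nn.LayerNorm(head_dim)
        self.k_norm = nn.LayerNorm(head_dim)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # q, k, v: [batch, heads, seq_len, head_dim]
        return F.scaled_dot_product_attention(self.q_norm(q), self.k_norm(k), v)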
diffusers/models/controlnets/controlnet_sana.py ADDED
@@ -0,0 +1,290 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import AttentionProcessor
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle, RMSNorm
+from ..transformers.sana_transformer import SanaTransformerBlock
+from .controlnet import zero_module
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class SanaControlNetOutput(BaseOutput):
+    controlnet_block_samples: Tuple[torch.Tensor]
+
+
+class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 32,
+        out_channels: Optional[int] = 32,
+        num_attention_heads: int = 70,
+        attention_head_dim: int = 32,
+        num_layers: int = 7,
+        num_cross_attention_heads: Optional[int] = 20,
+        cross_attention_head_dim: Optional[int] = 112,
+        cross_attention_dim: Optional[int] = 2240,
+        caption_channels: int = 2304,
+        mlp_ratio: float = 2.5,
+        dropout: float = 0.0,
+        attention_bias: bool = False,
+        sample_size: int = 32,
+        patch_size: int = 1,
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-6,
+        interpolation_scale: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+        inner_dim = num_attention_heads * attention_head_dim
+
+        # 1. Patch Embedding
+        self.patch_embed = PatchEmbed(
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            interpolation_scale=interpolation_scale,
+            pos_embed_type="sincos" if interpolation_scale is not None else None,
+        )
+
+        # 2. Additional condition embeddings
+        self.time_embed = AdaLayerNormSingle(inner_dim)
+
+        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+        self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
+
+        # 3. Transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                SanaTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    num_cross_attention_heads=num_cross_attention_heads,
+                    cross_attention_head_dim=cross_attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    attention_bias=attention_bias,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                    mlp_ratio=mlp_ratio,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # controlnet_blocks
+        self.controlnet_blocks = nn.ModuleList([])
+
+        self.input_block = zero_module(nn.Linear(inner_dim, inner_dim))
+        for _ in range(len(self.transformer_blocks)):
+            controlnet_block = nn.Linear(inner_dim, inner_dim)
+            controlnet_block = zero_module(controlnet_block)
+            self.controlnet_blocks.append(controlnet_block)
+
+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        controlnet_cond: torch.Tensor,
+        conditioning_scale: float = 1.0,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0,     discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Input
+        batch_size, num_channels, height, width = hidden_states.shape
+        p = self.config.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+
+        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond.to(hidden_states.dtype)))
+
+        timestep, embedded_timestep = self.time_embed(
+            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+        )
+
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        encoder_hidden_states = self.caption_norm(encoder_hidden_states)
+
+        # 2. Transformer blocks
+        block_res_samples = ()
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for block in self.transformer_blocks:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    post_patch_height,
+                    post_patch_width,
+                )
+                block_res_samples = block_res_samples + (hidden_states,)
+        else:
+            for block in self.transformer_blocks:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    post_patch_height,
+                    post_patch_width,
+                )
+                block_res_samples = block_res_samples + (hidden_states,)
+
+        # 3. ControlNet blocks
+        controlnet_block_res_samples = ()
+        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+            block_res_sample = controlnet_block(block_res_sample)
+            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
+
+        if not return_dict:
+            return (controlnet_block_res_samples,)
+
+        return SanaControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
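A quick shape-check sketch for the new `SanaControlNetModel`, using a deliberately tiny config so it runs on CPU (the released defaults are much larger); the tensor shapes follow my reading of `forward` above:

import torch
from diffusers.models.controlnets.controlnet_sana import SanaControlNetModel

# Tiny config: inner_dim = 2 * 16 = 32 must match cross_attention_dim and
# num_cross_attention_heads * cross_attention_head_dim.
controlnet = SanaControlNetModel(
    in_channels=32,
    num_attention_heads=2,
    attention_head_dim=16,
    num_layers=2,
    num_cross_attention_heads=2,
    cross_attention_head_dim=16,
    cross_attention_dim=32,
    caption_channels=16,
    sample_size=8,
)

hidden_states = torch.randn(1, 32, 8, 8)       # latents [B, C, H, W]
controlnet_cond = torch.randn(1, 32, 8, 8)     # control signal in latent space
encoder_hidden_states = torch.randn(1, 4, 16)  # caption embeddings [B, T, caption_channels]
timestep = torch.tensor([500])

out = controlnet(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    timestep=timestep,
    controlnet_cond=controlnet_cond,
)
# One residual per transformer block, each [B, H*W, inner_dim].
print(len(out.controlnet_block_samples), out.controlnet_block_samples[0].shape)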
diffusers/models/controlnets/controlnet_sd3.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/models/controlnets/controlnet_sparsectrl.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -96,7 +96,7 @@ class SparseControlNetConditioningEmbedding(nn.Module):
 class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     """
     A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
-    Models](https://arxiv.org/abs/2311.16933).
+    Models](https://huggingface.co/papers/2311.16933).
 
     Args:
         in_channels (`int`, defaults to 4):
diffusers/models/controlnets/controlnet_union.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/models/controlnets/controlnet_xs.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -734,17 +734,17 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             unet (`UNet2DConditionModel`):
                 The UNet model we want to control.
             controlnet (`ControlNetXSAdapter`):
-                The ConntrolNet-XS adapter with which the UNet will be fused. If none is given, a new ConntrolNet-XS
+                The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
                 adapter will be created.
             size_ratio (float, *optional*, defaults to `None`):
-                Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
+                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
             ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
-                Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
+                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
                 where this parameter is called `block_out_channels`.
             time_embedding_mix (`float`, *optional*, defaults to None):
-                Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
+                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
             ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
-                Passed to the `init` of the new controlent if no controlent was given.
+                Passed to the `init` of the new controlnet if no controlnet was given.
         """
         if controlnet is None:
             controlnet = ControlNetXSAdapter.from_unet(
@@ -942,7 +942,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
-        r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+        r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
 
         The suffixes after the scaling factors represent the stage blocks where they are being applied.
 
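For reference, `enable_freeu` (whose docstring link changed above) is exposed on the UNet-based pipelines as well; a minimal sketch using the SD 1.x factors suggested by the FreeU paper — the checkpoint id is an assumption:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# s1/s2 damp skip-connection features, b1/b2 amplify backbone features
# in the affected decoder stages (paper-suggested values for SD 1.x).
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)
image = pipe("an astronaut riding a horse").images[0]
pipe.disable_freeu()  # restore default behavior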
diffusers/models/controlnets/multicontrolnet.py CHANGED
@@ -4,9 +4,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from torch import nn
 
-from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
-from ...models.modeling_utils import ModelMixin
 from ...utils import logging
+from ..controlnets.controlnet import ControlNetModel, ControlNetOutput
+from ..modeling_utils import ModelMixin
 
 
 logger = logging.get_logger(__name__)
@@ -130,9 +130,8 @@ class MultiControlNetModel(ModelMixin):
                 A path to a *directory* containing model weights saved using
                 [`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
                 `./my_model_directory/controlnet`.
-            torch_dtype (`str` or `torch.dtype`, *optional*):
-                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
-                will be automatically derived from the model's weights.
+            torch_dtype (`torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype.
             output_loading_info(`bool`, *optional*, defaults to `False`):
                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
             device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
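For context, `MultiControlNetModel` is the wrapper a pipeline builds when it is handed a list of ControlNets; each net contributes residuals scaled by its own conditioning scale. A hedged usage sketch — the checkpoint names are assumptions here, and the conditioning images are placeholders:

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

canny_image = load_image("canny.png")  # placeholder paths
pose_image = load_image("pose.png")

canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)

# Passing a list wraps the nets in MultiControlNetModel internally.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    controlnet=[canny, pose],
    torch_dtype=torch.float16,
)
image = pipe(
    "a dancer on a stage",
    image=[canny_image, pose_image],           # one condition image per net
    controlnet_conditioning_scale=[0.5, 1.0],  # per-net scales
).images[0]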
diffusers/models/controlnets/multicontrolnet_union.py CHANGED
@@ -4,10 +4,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from torch import nn
 
-from ...models.controlnets.controlnet import ControlNetOutput
-from ...models.controlnets.controlnet_union import ControlNetUnionModel
-from ...models.modeling_utils import ModelMixin
 from ...utils import logging
+from ..controlnets.controlnet import ControlNetOutput
+from ..controlnets.controlnet_union import ControlNetUnionModel
+from ..modeling_utils import ModelMixin
 
 
 logger = logging.get_logger(__name__)
@@ -143,9 +143,8 @@ class MultiControlNetUnionModel(ModelMixin):
                 A path to a *directory* containing model weights saved using
                 [`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g.,
                 `./my_model_directory/controlnet`.
-            torch_dtype (`str` or `torch.dtype`, *optional*):
-                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
-                will be automatically derived from the model's weights.
+            torch_dtype (`torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype.
             output_loading_info(`bool`, *optional*, defaults to `False`):
                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
             device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
diffusers/models/downsampling.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -286,7 +286,7 @@ class KDownsample2D(nn.Module):
 
 
 class CogVideoXDownsample3D(nn.Module):
-    # Todo: Wait for paper relase.
+    # Todo: Wait for paper release.
     r"""
     A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
 
diffusers/models/embeddings.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ def get_timestep_embedding(
     downscale_freq_shift: float = 1,
     scale: float = 1,
     max_period: int = 10000,
-):
+) -> torch.Tensor:
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
 
@@ -97,7 +97,7 @@ def get_3d_sincos_pos_embed(
         The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
         spatial dimensions (height and width).
     temporal_size (`int`):
-        The temporal dimension of postional embeddings (number of frames).
+        The temporal dimension of positional embeddings (number of frames).
     spatial_interpolation_scale (`float`, defaults to 1.0):
         Scale factor for spatial grid interpolation.
     temporal_interpolation_scale (`float`, defaults to 1.0):
@@ -169,7 +169,7 @@ def _get_3d_sincos_pos_embed_np(
         The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
         spatial dimensions (height and width).
     temporal_size (`int`):
-        The temporal dimension of postional embeddings (number of frames).
+        The temporal dimension of positional embeddings (number of frames).
     spatial_interpolation_scale (`float`, defaults to 1.0):
         Scale factor for spatial grid interpolation.
     temporal_interpolation_scale (`float`, defaults to 1.0):
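The `get_timestep_embedding` hunk above only adds a return annotation; a minimal call sketch showing the resulting shape (argument names taken from the signature above):

import torch
from diffusers.models.embeddings import get_timestep_embedding

timesteps = torch.tensor([0, 250, 999])
emb = get_timestep_embedding(
    timesteps,
    embedding_dim=320,
    flip_sin_to_cos=True,
    downscale_freq_shift=0.0,
)
assert emb.shape == (3, 320)  # [num_timesteps, embedding_dim]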
@@ -1149,9 +1149,7 @@ def get_1d_rotary_pos_embed(
 
     theta = theta * ntk_factor
     freqs = (
-        1.0
-        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
-        / linear_factor
+        1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
     )  # [D/2]
     freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     is_npu = freqs.device.type == "npu"
@@ -1201,11 +1199,11 @@ def apply_rotary_emb(
 
     if use_real_unbind_dim == -1:
         # Used for flux, cogvideox, hunyuan-dit
-        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
         x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
     elif use_real_unbind_dim == -2:
-        # Used for Stable Audio, OmniGen and CogView4
-        x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
+        # Used for Stable Audio, OmniGen, CogView4 and Cosmos
+        x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
         x_rotated = torch.cat([-x_imag, x_real], dim=-1)
     else:
         raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
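The corrected comments above say `apply_rotary_emb` receives `[B, H, S, D]` inputs. A small sketch exercising the `use_real_unbind_dim=-1` (flux/cogvideox) path together with `get_1d_rotary_pos_embed`:

import torch
from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed

B, H, S, D = 2, 4, 16, 64  # batch, heads, sequence length, head dim

# cos/sin tables for S positions, expanded across the head dim (use_real=True).
freqs_cis = get_1d_rotary_pos_embed(D, S, use_real=True, repeat_interleave_real=True)

q = torch.randn(B, H, S, D)
q_rot = apply_rotary_emb(q, freqs_cis, use_real=True, use_real_unbind_dim=-1)
assert q_rot.shape == (B, H, S, D)  # the rotation preserves the shape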
@@ -1327,7 +1325,7 @@ class Timesteps(nn.Module):
         self.downscale_freq_shift = downscale_freq_shift
         self.scale = scale
 
-    def forward(self, timesteps):
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
         t_emb = get_timestep_embedding(
             timesteps,
             self.num_channels,
@@ -1401,7 +1399,7 @@ class ImagePositionalEmbeddings(nn.Module):
     Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
     height and width of the latent space.
 
-    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
+    For more details, see figure 10 of the dall-e paper: https://huggingface.co/papers/2102.12092
 
     For VQ-diffusion:
 
diffusers/models/embeddings_flax.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -89,7 +89,7 @@ class FlaxTimestepEmbedding(nn.Module):
 
 class FlaxTimesteps(nn.Module):
     r"""
-    Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
+    Wrapper Module for sinusoidal Time step Embeddings as described in https://huggingface.co/papers/2006.11239
 
     Args:
         dim (`int`, *optional*, defaults to `32`):
diffusers/models/lora.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ if is_transformers_available():
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def text_encoder_attn_modules(text_encoder):
+def text_encoder_attn_modules(text_encoder: nn.Module):
     attn_modules = []
 
     if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
@@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder):
     return attn_modules
 
 
-def text_encoder_mlp_modules(text_encoder):
+def text_encoder_mlp_modules(text_encoder: nn.Module):
     mlp_modules = []
 
     if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
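The two annotations above are type-only; behavior is unchanged. A quick sketch of the helper on a CLIP text encoder — in the released source each returned entry is a `(name, module)` pair, one per encoder layer's self-attention:

from transformers import CLIPTextModel
from diffusers.models.lora import text_encoder_attn_modules

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
for name, module in text_encoder_attn_modules(text_encoder):
    print(name, type(module).__name__)  # e.g. text_model.encoder.layers.0.self_attn CLIPAttention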