diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright 2024 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -83,8 +83,8 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
             model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
-            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
-            however, no such scaling factor was used, hence the value of 1.0 as the default.
+            Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. For this
+            Autoencoder, however, no such scaling factor was used, hence the value of 1.0 as the default.
         force_upcast (`bool`, *optional*, default to `False`):
             If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
             can be fine-tuned / trained to a lower range without losing too much precision, in which case
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -255,7 +255,7 @@ class Decoder(nn.Module):
                 num_layers=self.layers_per_block + 1,
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
-                prev_output_channel=None,
+                prev_output_channel=prev_output_channel,
                 add_upsample=not is_final_block,
                 resnet_eps=1e-6,
                 resnet_act_fn=act_fn,
@@ -744,6 +744,17 @@ class DiagonalGaussianDistribution(object):
         return self.mean


+class IdentityDistribution(object):
+    def __init__(self, parameters: torch.Tensor):
+        self.parameters = parameters
+
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+        return self.parameters
+
+    def mode(self) -> torch.Tensor:
+        return self.parameters
+
+
 class EncoderTiny(nn.Module):
     r"""
     The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
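The new `IdentityDistribution` mirrors the sampling interface of `DiagonalGaussianDistribution` but is a pure pass-through: `sample()` ignores its generator and, like `mode()`, returns the raw parameters tensor unchanged. A minimal sketch of that contract (the import path and tensor shape are illustrative assumptions, not taken from this diff):

```python
# Sketch only: exercising the pass-through contract of IdentityDistribution.
import torch

from diffusers.models.autoencoders.vae import IdentityDistribution  # assumed path

params = torch.randn(1, 16, 32, 32)  # illustrative "latent parameters" tensor
dist = IdentityDistribution(params)

# No noise is drawn and no generator is consumed; both calls return the input as-is.
assert dist.sample() is params
assert dist.mode() is params
```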
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -66,7 +66,7 @@ class VQModel(ModelMixin, ConfigMixin):
             model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
-            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+            Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
         norm_type (`str`, *optional*, defaults to `"group"`):
             Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
     """
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from contextlib import contextmanager
+
 from ..utils.logging import get_logger


@@ -25,6 +27,7 @@ class CacheMixin:
     Supported caching techniques:
         - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
         - [FasterCache](https://huggingface.co/papers/2410.19355)
+        - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
     """

     _cache_config = None
@@ -62,8 +65,10 @@ class CacheMixin:

         from ..hooks import (
             FasterCacheConfig,
+            FirstBlockCacheConfig,
             PyramidAttentionBroadcastConfig,
             apply_faster_cache,
+            apply_first_block_cache,
             apply_pyramid_attention_broadcast,
         )

@@ -72,31 +77,36 @@ class CacheMixin:
                 f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
             )

-        if isinstance(config, PyramidAttentionBroadcastConfig):
-            apply_pyramid_attention_broadcast(self, config)
-        elif isinstance(config, FasterCacheConfig):
+        if isinstance(config, FasterCacheConfig):
             apply_faster_cache(self, config)
+        elif isinstance(config, FirstBlockCacheConfig):
+            apply_first_block_cache(self, config)
+        elif isinstance(config, PyramidAttentionBroadcastConfig):
+            apply_pyramid_attention_broadcast(self, config)
         else:
             raise ValueError(f"Cache config {type(config)} is not supported.")

         self._cache_config = config

     def disable_cache(self) -> None:
-        from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
+        from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
         from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
+        from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
         from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK

         if self._cache_config is None:
             logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
             return

-        if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
-            registry = HookRegistry.check_if_exists_or_initialize(self)
-            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
-        elif isinstance(self._cache_config, FasterCacheConfig):
-            registry = HookRegistry.check_if_exists_or_initialize(self)
+        registry = HookRegistry.check_if_exists_or_initialize(self)
+        if isinstance(self._cache_config, FasterCacheConfig):
             registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
             registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
+        elif isinstance(self._cache_config, FirstBlockCacheConfig):
+            registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
+            registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
+        elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
+            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
         else:
             raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")

@@ -106,3 +116,15 @@ class CacheMixin:
         from ..hooks import HookRegistry

         HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
+
+    @contextmanager
+    def cache_context(self, name: str):
+        r"""Context manager that provides additional methods for cache management."""
+        from ..hooks import HookRegistry
+
+        registry = HookRegistry.check_if_exists_or_initialize(self)
+        registry._set_context(name)
+
+        yield
+
+        registry._set_context(None)
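Taken together, the `cache_utils.py` hunks register FirstBlockCache as a third caching technique next to FasterCache and Pyramid Attention Broadcast, route `enable_cache`/`disable_cache` through a shared hook registry, and add a `cache_context` manager that names the current invocation for stateful hooks. A hedged sketch of how the public pieces compose on a `CacheMixin` denoiser (the pipeline class, model id, and threshold below are illustrative assumptions):

```python
# Sketch only: enabling the new FirstBlockCache on a CacheMixin model.
import torch

from diffusers import FluxPipeline
from diffusers.hooks import FirstBlockCacheConfig

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# enable_cache() dispatches on the config type, so a FirstBlockCacheConfig
# routes to apply_first_block_cache(); the threshold value is an assumption.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
image = pipe("a photo of a cat", num_inference_steps=28).images[0]

# disable_cache() now resolves the registry once and removes the FBC
# leader/block hooks (or the FasterCache/PAB hooks) by config type.
pipe.transformer.disable_cache()
```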
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -9,6 +9,7 @@ if is_torch_available():
         HunyuanDiT2DControlNetModel,
         HunyuanDiT2DMultiControlNetModel,
     )
+    from .controlnet_sana import SanaControlNetModel
     from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
     from .controlnet_sparsectrl import (
         SparseControlNetConditioningEmbedding,
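With this export in place, the new Sana ControlNet (added in `controlnet_sana.py`, file 88 above) becomes importable alongside the other ControlNet variants; in 0.35.0 it is also re-exported at the package root. A hedged sketch; the checkpoint id is a hypothetical placeholder, not a repository named in this diff:

```python
# Sketch only: loading the newly exported SanaControlNetModel.
import torch

from diffusers import SanaControlNetModel

controlnet = SanaControlNetModel.from_pretrained(
    "your-org/sana-controlnet",  # hypothetical checkpoint id
    torch_dtype=torch.bfloat16,
)
```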
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -63,8 +63,8 @@ class ControlNetOutput(BaseOutput):

 class ControlNetConditioningEmbedding(nn.Module):
     """
-    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
-    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+    Quoting from https://huggingface.co/papers/2302.05543: "Stable Diffusion uses a pre-processing method similar to
+    VQ-GAN [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
     training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
     convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
     (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,12 +20,12 @@ import torch.nn as nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...models.attention_processor import AttentionProcessor
-from ...models.modeling_utils import ModelMixin
 from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import AttentionProcessor
 from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
 from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock


@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 )
             block_samples = block_samples + (hidden_states,)

-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
         single_block_samples = ()
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
+                    encoder_hidden_states,
                     temb,
                     image_rotary_emb,
                 )

             else:
-                hidden_states = block(
+                encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
-            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
+            single_block_samples = single_block_samples + (hidden_states,)

         # controlnet block
         controlnet_block_samples = ()
@@ -430,7 +430,7 @@ class FluxMultiControlNetModel(ModelMixin):
     ) -> Union[FluxControlNetOutput, Tuple]:
         # ControlNet-Union with multiple conditions
         # only load one ControlNet for saving memories
-        if len(self.nets) == 1 and self.nets[0].union:
+        if len(self.nets) == 1:
             controlnet = self.nets[0]

             for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
@@ -454,17 +454,18 @@ class FluxMultiControlNetModel(ModelMixin):
                     control_block_samples = block_samples
                     control_single_block_samples = single_block_samples
                 else:
-                    control_block_samples = [
-                        control_block_sample + block_sample
-                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
-                    ]
-
-                    control_single_block_samples = [
-                        control_single_block_sample + block_sample
-                        for control_single_block_sample, block_sample in zip(
-                            control_single_block_samples, single_block_samples
-                        )
-                    ]
+                    if block_samples is not None and control_block_samples is not None:
+                        control_block_samples = [
+                            control_block_sample + block_sample
+                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+                        ]
+                    if single_block_samples is not None and control_single_block_samples is not None:
+                        control_single_block_samples = [
+                            control_single_block_sample + block_sample
+                            for control_single_block_sample, block_sample in zip(
+                                control_single_block_samples, single_block_samples
+                            )
+                        ]

         # Regular Multi-ControlNets
         # load all ControlNets into memories
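Two behavioral changes land in `FluxMultiControlNetModel.forward`: the single-net fast path no longer requires the net to be a ControlNet-Union (the `self.nets[0].union` check is dropped), and residual accumulation now skips `None` block samples, since a ControlNet may return nothing for one of the two streams. A rough sketch of the case the relaxed check admits (the repo id is an illustrative assumption):

```python
# Sketch only: a single, non-union Flux ControlNet wrapped in the multi-model.
import torch

from diffusers import FluxControlNetModel, FluxMultiControlNetModel

controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
)
# Previously this wrapper only took the len(nets) == 1 path for union nets;
# now any single net is dispatched there, with None-safe sample merging.
multi = FluxMultiControlNetModel([controlnet])
```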
@@ -1,4 +1,4 @@
-# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
+# Copyright 2025 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -103,7 +103,7 @@ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
                     activation_fn=activation_fn,
                     ff_inner_dim=int(self.inner_dim * mlp_ratio),
                     cross_attention_dim=cross_attention_dim,
-                    qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+                    qk_norm=True,  # See https://huggingface.co/papers/2302.05442 for details.
                     skip=False,  # always False as it is the first half of the model
                 )
                 for layer in range(transformer_num_layers // 2 - 1)
@@ -0,0 +1,290 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import AttentionProcessor
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle, RMSNorm
+from ..transformers.sana_transformer import SanaTransformerBlock
+from .controlnet import zero_module
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class SanaControlNetOutput(BaseOutput):
+    controlnet_block_samples: Tuple[torch.Tensor]
+
+
+class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 32,
+        out_channels: Optional[int] = 32,
+        num_attention_heads: int = 70,
+        attention_head_dim: int = 32,
+        num_layers: int = 7,
+        num_cross_attention_heads: Optional[int] = 20,
+        cross_attention_head_dim: Optional[int] = 112,
+        cross_attention_dim: Optional[int] = 2240,
+        caption_channels: int = 2304,
+        mlp_ratio: float = 2.5,
+        dropout: float = 0.0,
+        attention_bias: bool = False,
+        sample_size: int = 32,
+        patch_size: int = 1,
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-6,
+        interpolation_scale: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+        inner_dim = num_attention_heads * attention_head_dim
+
+        # 1. Patch Embedding
+        self.patch_embed = PatchEmbed(
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            interpolation_scale=interpolation_scale,
+            pos_embed_type="sincos" if interpolation_scale is not None else None,
+        )
+
+        # 2. Additional condition embeddings
+        self.time_embed = AdaLayerNormSingle(inner_dim)
+
+        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+        self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
+
+        # 3. Transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                SanaTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    num_cross_attention_heads=num_cross_attention_heads,
+                    cross_attention_head_dim=cross_attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    attention_bias=attention_bias,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                    mlp_ratio=mlp_ratio,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # controlnet_blocks
+        self.controlnet_blocks = nn.ModuleList([])
+
+        self.input_block = zero_module(nn.Linear(inner_dim, inner_dim))
+        for _ in range(len(self.transformer_blocks)):
+            controlnet_block = nn.Linear(inner_dim, inner_dim)
+            controlnet_block = zero_module(controlnet_block)
+            self.controlnet_blocks.append(controlnet_block)
+
+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        controlnet_cond: torch.Tensor,
+        conditioning_scale: float = 1.0,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Input
+        batch_size, num_channels, height, width = hidden_states.shape
+        p = self.config.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+
+        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond.to(hidden_states.dtype)))
+
+        timestep, embedded_timestep = self.time_embed(
+            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+        )
+
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        encoder_hidden_states = self.caption_norm(encoder_hidden_states)
+
+        # 2. Transformer blocks
+        block_res_samples = ()
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for block in self.transformer_blocks:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    post_patch_height,
+                    post_patch_width,
+                )
+                block_res_samples = block_res_samples + (hidden_states,)
+        else:
+            for block in self.transformer_blocks:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    post_patch_height,
+                    post_patch_width,
+                )
+                block_res_samples = block_res_samples + (hidden_states,)
+
+        # 3. ControlNet blocks
+        controlnet_block_res_samples = ()
+        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+            block_res_sample = controlnet_block(block_res_sample)
+            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
+
+        if not return_dict:
+            return (controlnet_block_res_samples,)
+
+        return SanaControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
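
A quick smoke-test sketch for the new `SanaControlNetModel` above. The config is scaled down from the defaults so it runs quickly, and the inputs are random tensors, so every shape and value here is illustrative only:

```python
# Illustrative smoke test for SanaControlNetModel; config values are scaled
# down from the defaults and the inputs are random, not real latents.
import torch

# Assumes the class is exported at the diffusers top level in 0.35; otherwise
# import it from the controlnet module shown in this diff.
from diffusers import SanaControlNetModel

model = SanaControlNetModel(
    in_channels=32,
    num_attention_heads=2,        # default is 70
    attention_head_dim=32,
    num_layers=2,                 # default is 7
    num_cross_attention_heads=2,  # default is 20
    cross_attention_head_dim=32,
    cross_attention_dim=64,       # must equal num_attention_heads * attention_head_dim
    caption_channels=16,
    sample_size=8,
)

batch, height, width = 1, 8, 8
hidden_states = torch.randn(batch, 32, height, width)    # noisy latents
controlnet_cond = torch.randn(batch, 32, height, width)  # control signal in latent space
encoder_hidden_states = torch.randn(batch, 4, 16)        # text embeddings
timestep = torch.tensor([500], dtype=torch.long)

out = model(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    timestep=timestep,
    controlnet_cond=controlnet_cond,
    conditioning_scale=0.5,
)
# One scaled residual per transformer block, each shaped
# (batch, num_patches, inner_dim).
print(len(out.controlnet_block_samples))  # 2
```

Each returned residual has shape `(batch, num_patches, inner_dim)` and is meant to be added to the hidden states of the matching block in the Sana transformer during denoising.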
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -96,7 +96,7 @@ class SparseControlNetConditioningEmbedding(nn.Module):
 class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     """
     A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
-    Models](https://arxiv.org/abs/2311.16933).
+    Models](https://huggingface.co/papers/2311.16933).
 
     Args:
         in_channels (`int`, defaults to 4):