diffusers 0.33.0-py3-none-any.whl → 0.34.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +17 -12
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. diffusers-0.33.0.dist-info/RECORD +0 -608
  475. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  476. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
  477. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py

@@ -548,6 +548,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         use_stream: bool = False,
         record_stream: bool = False,
         low_cpu_mem_usage=False,
+        offload_to_disk_path: Optional[str] = None,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -588,15 +589,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
         apply_group_offloading(
-            self,
-            onload_device,
-            offload_device,
-            offload_type,
-            num_blocks_per_group,
-            non_blocking,
-            use_stream,
-            record_stream,
+            module=self,
+            onload_device=onload_device,
+            offload_device=offload_device,
+            offload_type=offload_type,
+            num_blocks_per_group=num_blocks_per_group,
+            non_blocking=non_blocking,
+            use_stream=use_stream,
+            record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            offload_to_disk_path=offload_to_disk_path,
         )

     def save_pretrained(
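The new `offload_to_disk_path` argument is threaded through to `apply_group_offloading`, so offloaded parameter groups can be spilled to disk instead of held in CPU RAM. A minimal usage sketch of the `enable_group_offload` entry point this signature belongs to; the checkpoint, devices, and offload directory are illustrative assumptions:

```py
import torch
from diffusers import AutoModel

# Any ModelMixin subclass exposes enable_group_offload; this checkpoint is just an example.
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    # New in 0.34.0: serialize offloaded groups to disk to cap CPU RAM usage.
    offload_to_disk_path="offload_dir",
)
```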
@@ -787,9 +789,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
-            torch_dtype (`str` or `torch.dtype`, *optional*):
-                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
-                dtype is automatically derived from the model's weights.
+            torch_dtype (`torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype.
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
@@ -815,14 +816,43 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                 guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                 information.
-            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+            device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                 A map that specifies where each submodule should go. It doesn't need to be defined for each
                 parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                 same device. Defaults to `None`, meaning that the model will be loaded on CPU.

+                Examples:
+
+                ```py
+                >>> from diffusers import AutoModel
+                >>> import torch
+
+                >>> # This works.
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
+                ... )
+                >>> # This also works (integer accelerator device ID).
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0
+                ... )
+                >>> # Specifying a supported offloading strategy like "auto" also works.
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
+                ... )
+                >>> # Specifying a dictionary as `device_map` also works.
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0",
+                ...     subfolder="unet",
+                ...     device_map={"": torch.device("cuda")},
+                ... )
+                ```
+
                 Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                 more information about each option see [designing a device
-                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+                map](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap). You
+                can also refer to the [Diffusers-specific
+                documentation](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#model-sharding)
+                for more concrete examples.
             max_memory (`Dict`, *optional*):
                 A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
                 each GPU and the available CPU RAM if unset.
@@ -1388,7 +1418,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         low_cpu_mem_usage: bool = True,
         dtype: Optional[Union[str, torch.dtype]] = None,
         keep_in_fp32_modules: Optional[List[str]] = None,
-        device_map: Dict[str, Union[int, str, torch.device]] = None,
+        device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]] = None,
         offload_state_dict: Optional[bool] = None,
         offload_folder: Optional[Union[str, os.PathLike]] = None,
         dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
diffusers/models/normalization.py

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -237,7 +237,7 @@ class AdaLayerNormSingle(nn.Module):
     r"""
     Norm layer adaptive layer norm single (adaLN-single).

-    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+    As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3).

     Parameters:
         embedding_dim (`int`): The size of each embedding vector.
@@ -510,7 +510,7 @@ else:

 class RMSNorm(nn.Module):
     r"""
-    RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.
+    RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.

     Args:
         dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
@@ -600,7 +600,7 @@ class MochiRMSNorm(nn.Module):

 class GlobalResponseNorm(nn.Module):
     r"""
-    Global response normalization as introduced in ConvNeXt-v2 (https://arxiv.org/abs/2301.00808).
+    Global response normalization as introduced in ConvNeXt-v2 (https://huggingface.co/papers/2301.00808).

     Args:
         dim (`int`): Number of dimensions to use for the `gamma` and `beta`.
diffusers/models/resnet.py

@@ -1,5 +1,5 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-# `TemporalConvLayer` Copyright 2024 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+# `TemporalConvLayer` Copyright 2025 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/models/resnet_flax.py

@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/models/transformers/__init__.py

@@ -17,11 +17,15 @@ if is_torch_available():
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
     from .transformer_allegro import AllegroTransformer3DModel
+    from .transformer_chroma import ChromaTransformer2DModel
     from .transformer_cogview3plus import CogView3PlusTransformer2DModel
     from .transformer_cogview4 import CogView4Transformer2DModel
+    from .transformer_cosmos import CosmosTransformer3DModel
     from .transformer_easyanimate import EasyAnimateTransformer3DModel
     from .transformer_flux import FluxTransformer2DModel
+    from .transformer_hidream_image import HiDreamImageTransformer2DModel
     from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
+    from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
     from .transformer_ltx import LTXVideoTransformer3DModel
     from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
@@ -29,3 +33,4 @@ if is_torch_available():
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
     from .transformer_wan import WanTransformer3DModel
+    from .transformer_wan_vace import WanVACETransformer3DModel
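After upgrading, the five new transformer classes registered above become importable; a quick smoke test (the imports sit behind `is_torch_available()`, so a torch install is assumed):

```py
# Smoke test: these imports should succeed on diffusers 0.34.0 with torch installed.
from diffusers.models.transformers import (
    ChromaTransformer2DModel,
    CosmosTransformer3DModel,
    HiDreamImageTransformer2DModel,
    HunyuanVideoFramepackTransformer3DModel,
    WanVACETransformer3DModel,
)
```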
diffusers/models/transformers/auraflow_transformer_2d.py

@@ -1,4 +1,4 @@
-# Copyright 2024 AuraFlow Authors, The HuggingFace Team. All rights reserved.
+# Copyright 2025 AuraFlow Authors, The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,15 +13,15 @@
 # limitations under the License.


-from typing import Dict, Union
+from typing import Any, Dict, Optional, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin
-from ...utils import logging
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_processor import (
     Attention,
@@ -74,15 +74,23 @@ class AuraFlowPatchEmbed(nn.Module):
         # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
         # because original input are in flattened format, we have to flatten this 2d grid as well.
         h_p, w_p = h // self.patch_size, w // self.patch_size
-        original_pe_indexes = torch.arange(self.pos_embed.shape[1])
         h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
-        original_pe_indexes = original_pe_indexes.view(h_max, w_max)
+
+        # Calculate the top-left corner indices for the centered patch grid
         starth = h_max // 2 - h_p // 2
-        endh = starth + h_p
         startw = w_max // 2 - w_p // 2
-        endw = startw + w_p
-        original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
-        return original_pe_indexes.flatten()
+
+        # Generate the row and column indices for the desired patch grid
+        rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
+        cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
+
+        # Create a 2D grid of indices
+        row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
+
+        # Convert the 2D grid indices to flattened 1D indices
+        selected_indices = (row_indices * w_max + col_indices).flatten()
+
+        return selected_indices

     def forward(self, latent):
         batch_size, num_channels, height, width = latent.size()
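The rewrite above computes the flattened positions of a centered `h_p × w_p` window inside the `h_max × w_max` positional-embedding grid without materializing the full index table. A standalone sketch of the same arithmetic on a toy grid (all names here are local to this example):

```py
import torch

# Toy setting: a 4x4 positional-embedding grid (pos_embed_max_size = 16)
# from which we select the centered 2x2 patch grid.
h_max = w_max = 4
h_p = w_p = 2

starth = h_max // 2 - h_p // 2  # 1
startw = w_max // 2 - w_p // 2  # 1

rows = torch.arange(starth, starth + h_p)
cols = torch.arange(startw, startw + w_p)
row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
selected = (row_indices * w_max + col_indices).flatten()
print(selected)  # tensor([ 5,  6,  9, 10]) -- the centered window of the 4x4 grid
```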
@@ -160,14 +168,20 @@ class AuraFlowSingleTransformerBlock(nn.Module):
         self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
         self.ff = AuraFlowFeedForward(dim, dim * 4)

-    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         residual = hidden_states
+        attention_kwargs = attention_kwargs or {}

         # Norm + Projection.
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

         # Attention.
-        attn_output = self.attn(hidden_states=norm_hidden_states)
+        attn_output = self.attn(hidden_states=norm_hidden_states, **attention_kwargs)

         # Process attention outputs for the `hidden_states`.
         hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
@@ -223,10 +237,15 @@ class AuraFlowJointTransformerBlock(nn.Module):
         self.ff_context = AuraFlowFeedForward(dim, dim * 4)

     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         residual = hidden_states
         residual_context = encoder_hidden_states
+        attention_kwargs = attention_kwargs or {}

         # Norm + Projection.
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -236,7 +255,9 @@ class AuraFlowJointTransformerBlock(nn.Module):

         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
@@ -254,7 +275,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states


-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).

@@ -262,17 +283,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         sample_size (`int`): The width of the latent images. This is fixed during training since
             it is used to learn a number of position embeddings.
         patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
         num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
-        num_single_dit_layers (`int`, *optional*, defaults to 4):
+        num_single_dit_layers (`int`, *optional*, defaults to 32):
             The number of layers of Transformer blocks to use. These blocks use concatenated image and text
             representations.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
         joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
         caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        out_channels (`int`, defaults to 16): Number of output channels.
-        pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
+        out_channels (`int`, defaults to 4): Number of output channels.
+        pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
     """

     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
@@ -338,7 +359,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)

-        # https://arxiv.org/abs/2309.16588
+        # https://huggingface.co/papers/2309.16588
         # prevents artifacts in the attention maps
         self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02)

@@ -449,8 +470,24 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
         timestep: torch.LongTensor = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         height, width = hidden_states.shape[-2:]

         # Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -474,7 +511,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):

             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    attention_kwargs=attention_kwargs,
                 )

         # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -491,7 +531,9 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 )

             else:
-                combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb)
+                combined_hidden_states = block(
+                    hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
+                )

         hidden_states = combined_hidden_states[:, encoder_seq_len:]

@@ -512,6 +554,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
             )

+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)

@@ -1,4 +1,4 @@
1
- # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
1
+ # Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # Copyright 2024 ConsisID Authors and The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 ConsisID Authors and The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -30,7 +30,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
30
 
31
31
  class DiTTransformer2DModel(ModelMixin, ConfigMixin):
32
32
  r"""
33
- A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).
33
+ A 2D Transformer model as introduced in DiT (https://huggingface.co/papers/2212.09748).
34
34
 
35
35
  Parameters:
36
36
  num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -308,7 +308,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
308
308
  activation_fn=activation_fn,
309
309
  ff_inner_dim=int(self.inner_dim * mlp_ratio),
310
310
  cross_attention_dim=cross_attention_dim,
311
- qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
311
+ qk_norm=True, # See https://huggingface.co/papers/2302.05442 for details.
312
312
  skip=layer > num_layers // 2,
313
313
  )
314
314
  for layer in range(num_layers)
@@ -1,4 +1,4 @@
1
- # Copyright 2024 the Latte Team and The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 the Latte Team and The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -18,10 +18,9 @@ import torch
18
18
  from torch import nn
19
19
 
20
20
  from ...configuration_utils import ConfigMixin, register_to_config
21
- from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
22
21
  from ..attention import BasicTransformerBlock
23
22
  from ..cache_utils import CacheMixin
24
- from ..embeddings import PatchEmbed
23
+ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
25
24
  from ..modeling_outputs import Transformer2DModelOutput
26
25
  from ..modeling_utils import ModelMixin
27
26
  from ..normalization import AdaLayerNormSingle
@@ -31,7 +30,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
31
30
  _supports_gradient_checkpointing = True
32
31
 
33
32
  """
34
- A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, offical code:
33
+ A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code:
35
34
  https://github.com/Vchitect/Latte
36
35
 
37
36
  Parameters:
@@ -217,7 +216,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
217
216
  )
218
217
  num_patches = height * width
219
218
 
220
- hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings
219
+ hidden_states = self.pos_embed(hidden_states) # already add positional embeddings
221
220
 
222
221
  added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
223
222
  timestep, embedded_timestep = self.adaln_single(
@@ -1,4 +1,4 @@
- # Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+ # Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -43,7 +43,7 @@ class LuminaNextDiTBlock(nn.Module):
  num_kv_heads (`int`):
  Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
  multiple_of (`int`): The number of multiple of ffn layer.
- ffn_dim_multiplier (`float`): The multipier factor of ffn layer dimension.
+ ffn_dim_multiplier (`float`): The multiplier factor of ffn layer dimension.
  norm_eps (`float`): The eps for norm layer.
  qk_norm (`bool`): normalization for query and key.
  cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states.
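The corrected docstring describes LLaMA-style feed-forward sizing: `ffn_dim_multiplier` scales the hidden width, and `multiple_of` rounds it up to a hardware-friendly shape. A hedged sketch of how these parameters typically combine (illustrative; Lumina's exact constants may differ):

```python
from typing import Optional

def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]) -> int:
    """Illustrative LLaMA-style FFN width: scale, then round up to `multiple_of`."""
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)  # compensate for the gated (SwiGLU) branch
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(ffn_hidden_dim(2304, 256, None))  # 6144
```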
@@ -1,4 +1,4 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -31,8 +31,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name

  class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
  r"""
- A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
- https://arxiv.org/abs/2403.04692).
+ A 2D Transformer model as introduced in PixArt family of models (https://huggingface.co/papers/2310.00426,
+ https://huggingface.co/papers/2403.04692).

  Parameters:
  num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
@@ -61,7 +61,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
  added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
  Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
  product between the text embedding and image embedding as proposed in the unclip paper
- https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
+ https://huggingface.co/papers/2204.06125 If it is `None`, no additional embeddings will be prepended.
  time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
  If None, will be set to `num_attention_heads * attention_head_dim`
  embedding_proj_dim (`int`, *optional*, default to None):
@@ -1,4 +1,4 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -483,6 +483,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
  encoder_attention_mask: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.Tensor] = None,
  attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
  return_dict: bool = True,
  ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
  if attention_kwargs is not None:
@@ -546,7 +547,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig

  # 2. Transformer blocks
  if torch.is_grad_enabled() and self.gradient_checkpointing:
- for block in self.transformer_blocks:
+ for index_block, block in enumerate(self.transformer_blocks):
  hidden_states = self._gradient_checkpointing_func(
  block,
  hidden_states,
@@ -557,9 +558,11 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
  post_patch_height,
  post_patch_width,
  )
+ if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+ hidden_states = hidden_states + controlnet_block_samples[index_block - 1]

  else:
- for block in self.transformer_blocks:
+ for index_block, block in enumerate(self.transformer_blocks):
  hidden_states = block(
  hidden_states,
  attention_mask,
@@ -569,6 +572,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
  post_patch_height,
  post_patch_width,
  )
+ if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+ hidden_states = hidden_states + controlnet_block_samples[index_block - 1]

  # 3. Normalization
  hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
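The hunks above are the substance of this file's change: `SanaTransformer2DModel.forward` gains a `controlnet_block_samples` argument, and after each transformer block except the first, the matching control residual is added to the hidden states. A self-contained sketch of the injection logic as added above (the blocks and tensor shapes here are placeholders, not the real Sana dimensions):

```python
import torch

def apply_controlnet_residuals(hidden_states, blocks, controlnet_block_samples=None):
    # After block i (for i >= 1), add the (i-1)-th controlnet residual, if provided.
    for index_block, block in enumerate(blocks):
        hidden_states = block(hidden_states)
        if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
            hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
    return hidden_states

blocks = [torch.nn.Identity() for _ in range(4)]          # stand-ins for transformer blocks
x = torch.randn(1, 16, 8)
residuals = tuple(torch.zeros(1, 16, 8) for _ in range(3))  # one residual per controlled block
out = apply_controlnet_residuals(x, blocks, residuals)
```

Note the off-by-one indexing: block `i` consumes residual `i - 1`, so the first block runs without a control signal.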
@@ -1,4 +1,4 @@
- # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+ # Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -21,16 +21,12 @@ import torch.nn as nn
  import torch.utils.checkpoint

  from ...configuration_utils import ConfigMixin, register_to_config
- from ...models.attention import FeedForward
- from ...models.attention_processor import (
- Attention,
- AttentionProcessor,
- StableAudioAttnProcessor2_0,
- )
- from ...models.modeling_utils import ModelMixin
- from ...models.transformers.transformer_2d import Transformer2DModelOutput
  from ...utils import logging
  from ...utils.torch_utils import maybe_allow_in_graph
+ from ..attention import FeedForward
+ from ..attention_processor import Attention, AttentionProcessor, StableAudioAttnProcessor2_0
+ from ..modeling_utils import ModelMixin
+ from ..transformers.transformer_2d import Transformer2DModelOutput


  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -1,4 +1,4 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -390,7 +390,7 @@ class T5LayerNorm(nn.Module):

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
- # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
+ # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
  # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
  # half-precision inputs is done in fp32

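A minimal sketch consistent with the comment being updated: variance computed without mean subtraction, no bias, and fp32 accumulation for half-precision inputs (the class name here is illustrative):

```python
import torch
import torch.nn as nn

class RMSLayerNorm(nn.Module):
    """Illustrative RMS layer norm (https://huggingface.co/papers/1910.07467)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # scale only, no bias
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Accumulate in fp32, and note there is no mean subtraction.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.variance_epsilon)
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)  # cast back for half-precision weights
        return self.weight * hidden_states
```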
@@ -407,7 +407,7 @@
  class NewGELUActivation(nn.Module):
  """
  Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
- the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
  """

  def forward(self, input: torch.Tensor) -> torch.Tensor:
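For reference, the tanh approximation of GELU described in the linked paper, which `NewGELUActivation`-style implementations typically compute:

```python
import math
import torch

def new_gelu(x: torch.Tensor) -> torch.Tensor:
    # tanh approximation of GELU (https://huggingface.co/papers/1606.08415),
    # as used in the BERT/GPT codebases.
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
```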
@@ -1,4 +1,4 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
- # Copyright 2024 The RhymesAI and The HuggingFace Team.
+ # Copyright 2025 The RhymesAI and The HuggingFace Team.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.