diffusers 0.33.0__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff compares the contents of two package versions as publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in that public registry.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +17 -12
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. diffusers-0.33.0.dist-info/RECORD +0 -608
  475. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  476. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
  477. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0

diffusers/models/transformers/transformer_cosmos.py

@@ -0,0 +1,579 @@
+ # Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...utils import is_torchvision_available
+ from ..attention import FeedForward
+ from ..attention_processor import Attention
+ from ..embeddings import Timesteps
+ from ..modeling_outputs import Transformer2DModelOutput
+ from ..modeling_utils import ModelMixin
+ from ..normalization import RMSNorm
+
+
+ if is_torchvision_available():
+     from torchvision import transforms
+
+
+ class CosmosPatchEmbed(nn.Module):
+     def __init__(
+         self, in_channels: int, out_channels: int, patch_size: Tuple[int, int, int], bias: bool = True
+     ) -> None:
+         super().__init__()
+         self.patch_size = patch_size
+
+         self.proj = nn.Linear(in_channels * patch_size[0] * patch_size[1] * patch_size[2], out_channels, bias=bias)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         batch_size, num_channels, num_frames, height, width = hidden_states.shape
+         p_t, p_h, p_w = self.patch_size
+         hidden_states = hidden_states.reshape(
+             batch_size, num_channels, num_frames // p_t, p_t, height // p_h, p_h, width // p_w, p_w
+         )
+         hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7)
+         hidden_states = self.proj(hidden_states)
+         return hidden_states
+
+
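
For orientation, CosmosPatchEmbed folds each (p_t, p_h, p_w) patch into the channel axis before the linear projection. A minimal shape walk-through with illustrative sizes (not part of the diff):

import torch

# [B, C, T, H, W] -> [B, T/p_t, H/p_h, W/p_w, C * p_t * p_h * p_w]
B, C, T, H, W = 1, 16, 4, 8, 8
p_t, p_h, p_w = 1, 2, 2
x = torch.randn(B, C, T, H, W)
x = x.reshape(B, C, T // p_t, p_t, H // p_h, p_h, W // p_w, p_w)
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7)
print(x.shape)  # torch.Size([1, 4, 4, 4, 64])
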
+ class CosmosTimestepEmbedding(nn.Module):
+     def __init__(self, in_features: int, out_features: int) -> None:
+         super().__init__()
+         self.linear_1 = nn.Linear(in_features, out_features, bias=False)
+         self.activation = nn.SiLU()
+         self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
+
+     def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
+         emb = self.linear_1(timesteps)
+         emb = self.activation(emb)
+         emb = self.linear_2(emb)
+         return emb
+
+
+ class CosmosEmbedding(nn.Module):
+     def __init__(self, embedding_dim: int, condition_dim: int) -> None:
+         super().__init__()
+
+         self.time_proj = Timesteps(embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
+         self.t_embedder = CosmosTimestepEmbedding(embedding_dim, condition_dim)
+         self.norm = RMSNorm(embedding_dim, eps=1e-6, elementwise_affine=True)
+
+     def forward(self, hidden_states: torch.Tensor, timestep: torch.LongTensor) -> torch.Tensor:
+         timesteps_proj = self.time_proj(timestep).type_as(hidden_states)
+         temb = self.t_embedder(timesteps_proj)
+         embedded_timestep = self.norm(timesteps_proj)
+         return temb, embedded_timestep
+
+
+ class CosmosAdaLayerNorm(nn.Module):
+     def __init__(self, in_features: int, hidden_features: int) -> None:
+         super().__init__()
+         self.embedding_dim = in_features
+
+         self.activation = nn.SiLU()
+         self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6)
+         self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)
+         self.linear_2 = nn.Linear(hidden_features, 2 * in_features, bias=False)
+
+     def forward(
+         self, hidden_states: torch.Tensor, embedded_timestep: torch.Tensor, temb: Optional[torch.Tensor] = None
+     ) -> torch.Tensor:
+         embedded_timestep = self.activation(embedded_timestep)
+         embedded_timestep = self.linear_1(embedded_timestep)
+         embedded_timestep = self.linear_2(embedded_timestep)
+
+         if temb is not None:
+             embedded_timestep = embedded_timestep + temb[..., : 2 * self.embedding_dim]
+
+         shift, scale = embedded_timestep.chunk(2, dim=-1)
+         hidden_states = self.norm(hidden_states)
+
+         if embedded_timestep.ndim == 2:
+             shift, scale = (x.unsqueeze(1) for x in (shift, scale))
+
+         hidden_states = hidden_states * (1 + scale) + shift
+         return hidden_states
+
+
+ class CosmosAdaLayerNormZero(nn.Module):
+     def __init__(self, in_features: int, hidden_features: Optional[int] = None) -> None:
+         super().__init__()
+
+         self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6)
+         self.activation = nn.SiLU()
+
+         if hidden_features is None:
+             self.linear_1 = nn.Identity()
+         else:
+             self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)
+
+         self.linear_2 = nn.Linear(hidden_features, 3 * in_features, bias=False)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         embedded_timestep: torch.Tensor,
+         temb: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         embedded_timestep = self.activation(embedded_timestep)
+         embedded_timestep = self.linear_1(embedded_timestep)
+         embedded_timestep = self.linear_2(embedded_timestep)
+
+         if temb is not None:
+             embedded_timestep = embedded_timestep + temb
+
+         shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
+         hidden_states = self.norm(hidden_states)
+
+         if embedded_timestep.ndim == 2:
+             shift, scale, gate = (x.unsqueeze(1) for x in (shift, scale, gate))
+
+         hidden_states = hidden_states * (1 + scale) + shift
+         return hidden_states, gate
+
+
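
Both AdaLayerNorm variants above follow the adaLN(-Zero) conditioning pattern: the embedded timestep is projected into a shift and scale (plus a gate in the Zero variant) that modulate a non-affine LayerNorm, and the gate scales the residual branch. A minimal standalone sketch of the pattern, with illustrative names that are not from the diff:

import torch
import torch.nn as nn

dim = 8
norm = nn.LayerNorm(dim, elementwise_affine=False)
to_mod = nn.Linear(dim, 3 * dim, bias=False)  # produces shift, scale, gate

x = torch.randn(2, 5, dim)    # [batch, seq, dim]
cond = torch.randn(2, dim)    # embedded timestep
shift, scale, gate = to_mod(nn.functional.silu(cond)).chunk(3, dim=-1)
shift, scale, gate = (t.unsqueeze(1) for t in (shift, scale, gate))
modulated = norm(x) * (1 + scale) + shift
out = x + gate * modulated    # in the real block, attention/FF runs on `modulated` first
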
+ class CosmosAttnProcessor2_0:
+     def __init__(self):
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         image_rotary_emb: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         # 1. QKV projections
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+         value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+         # 2. QK normalization
+         query = attn.norm_q(query)
+         key = attn.norm_k(key)
+
+         # 3. Apply RoPE
+         if image_rotary_emb is not None:
+             from ..embeddings import apply_rotary_emb
+
+             query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
+             key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
+
+         # 4. Prepare for GQA
+         query_idx = torch.tensor(query.size(3), device=query.device)
+         key_idx = torch.tensor(key.size(3), device=key.device)
+         value_idx = torch.tensor(value.size(3), device=value.device)
+         key = key.repeat_interleave(query_idx // key_idx, dim=3)
+         value = value.repeat_interleave(query_idx // value_idx, dim=3)
+
+         # 5. Attention
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
+
+         # 6. Output projection
+         hidden_states = attn.to_out[0](hidden_states)
+         hidden_states = attn.to_out[1](hidden_states)
+
+         return hidden_states
+
+
+ class CosmosTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         cross_attention_dim: int,
+         mlp_ratio: float = 4.0,
+         adaln_lora_dim: int = 256,
+         qk_norm: str = "rms_norm",
+         out_bias: bool = False,
+     ) -> None:
+         super().__init__()
+
+         hidden_size = num_attention_heads * attention_head_dim
+
+         self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
+         self.attn1 = Attention(
+             query_dim=hidden_size,
+             cross_attention_dim=None,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             qk_norm=qk_norm,
+             elementwise_affine=True,
+             out_bias=out_bias,
+             processor=CosmosAttnProcessor2_0(),
+         )
+
+         self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
+         self.attn2 = Attention(
+             query_dim=hidden_size,
+             cross_attention_dim=cross_attention_dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             qk_norm=qk_norm,
+             elementwise_affine=True,
+             out_bias=out_bias,
+             processor=CosmosAttnProcessor2_0(),
+         )
+
+         self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
+         self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu", bias=out_bias)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         embedded_timestep: torch.Tensor,
+         temb: Optional[torch.Tensor] = None,
+         image_rotary_emb: Optional[torch.Tensor] = None,
+         extra_pos_emb: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if extra_pos_emb is not None:
+             hidden_states = hidden_states + extra_pos_emb
+
+         # 1. Self Attention
+         norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, temb)
+         attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
+         hidden_states = hidden_states + gate * attn_output
+
+         # 2. Cross Attention
+         norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, temb)
+         attn_output = self.attn2(
+             norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+         )
+         hidden_states = hidden_states + gate * attn_output
+
+         # 3. Feed Forward
+         norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, temb)
+         ff_output = self.ff(norm_hidden_states)
+         hidden_states = hidden_states + gate * ff_output
+
+         return hidden_states
+
+
+ class CosmosRotaryPosEmbed(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         max_size: Tuple[int, int, int] = (128, 240, 240),
+         patch_size: Tuple[int, int, int] = (1, 2, 2),
+         base_fps: int = 24,
+         rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
+     ) -> None:
+         super().__init__()
+
+         self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
+         self.patch_size = patch_size
+         self.base_fps = base_fps
+
+         self.dim_h = hidden_size // 6 * 2
+         self.dim_w = hidden_size // 6 * 2
+         self.dim_t = hidden_size - self.dim_h - self.dim_w
+
+         self.h_ntk_factor = rope_scale[1] ** (self.dim_h / (self.dim_h - 2))
+         self.w_ntk_factor = rope_scale[2] ** (self.dim_w / (self.dim_w - 2))
+         self.t_ntk_factor = rope_scale[0] ** (self.dim_t / (self.dim_t - 2))
+
+     def forward(self, hidden_states: torch.Tensor, fps: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+         batch_size, num_channels, num_frames, height, width = hidden_states.shape
+         pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
+         device = hidden_states.device
+
+         h_theta = 10000.0 * self.h_ntk_factor
+         w_theta = 10000.0 * self.w_ntk_factor
+         t_theta = 10000.0 * self.t_ntk_factor
+
+         seq = torch.arange(max(self.max_size), device=device, dtype=torch.float32)
+         dim_h_range = (
+             torch.arange(0, self.dim_h, 2, device=device, dtype=torch.float32)[: (self.dim_h // 2)] / self.dim_h
+         )
+         dim_w_range = (
+             torch.arange(0, self.dim_w, 2, device=device, dtype=torch.float32)[: (self.dim_w // 2)] / self.dim_w
+         )
+         dim_t_range = (
+             torch.arange(0, self.dim_t, 2, device=device, dtype=torch.float32)[: (self.dim_t // 2)] / self.dim_t
+         )
+         h_spatial_freqs = 1.0 / (h_theta**dim_h_range)
+         w_spatial_freqs = 1.0 / (w_theta**dim_w_range)
+         temporal_freqs = 1.0 / (t_theta**dim_t_range)
+
+         emb_h = torch.outer(seq[: pe_size[1]], h_spatial_freqs)[None, :, None, :].repeat(pe_size[0], 1, pe_size[2], 1)
+         emb_w = torch.outer(seq[: pe_size[2]], w_spatial_freqs)[None, None, :, :].repeat(pe_size[0], pe_size[1], 1, 1)
+
+         # Apply sequence scaling in temporal dimension
+         if fps is None:
+             # Images
+             emb_t = torch.outer(seq[: pe_size[0]], temporal_freqs)
+         else:
+             # Videos
+             emb_t = torch.outer(seq[: pe_size[0]] / fps * self.base_fps, temporal_freqs)
+
+         emb_t = emb_t[:, None, None, :].repeat(1, pe_size[1], pe_size[2], 1)
+         freqs = torch.cat([emb_t, emb_h, emb_w] * 2, dim=-1).flatten(0, 2).float()
+         cos = torch.cos(freqs)
+         sin = torch.sin(freqs)
+         return cos, sin
+
+
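
CosmosRotaryPosEmbed splits each attention head's channels across the temporal, height, and width axes and tiles per-axis frequencies over the flattened T·H·W token grid. A quick check of the split arithmetic at the model's default head size (illustrative, not part of the diff):

head_dim = 128
dim_h = head_dim // 6 * 2          # 42 channels rotate with the height position
dim_w = head_dim // 6 * 2          # 42 channels rotate with the width position
dim_t = head_dim - dim_h - dim_w   # remaining 44 channels rotate with time
# cat([emb_t, emb_h, emb_w] * 2) has (dim_t/2 + dim_h/2 + dim_w/2) * 2 channels,
# which re-assembles the full head dimension:
assert dim_t + dim_h + dim_w == head_dim
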
+ class CosmosLearnablePositionalEmbed(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         max_size: Tuple[int, int, int],
+         patch_size: Tuple[int, int, int],
+         eps: float = 1e-6,
+     ) -> None:
+         super().__init__()
+
+         self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
+         self.patch_size = patch_size
+         self.eps = eps
+
+         self.pos_emb_t = nn.Parameter(torch.zeros(self.max_size[0], hidden_size))
+         self.pos_emb_h = nn.Parameter(torch.zeros(self.max_size[1], hidden_size))
+         self.pos_emb_w = nn.Parameter(torch.zeros(self.max_size[2], hidden_size))
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         batch_size, num_channels, num_frames, height, width = hidden_states.shape
+         pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
+
+         emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)
+         emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].repeat(batch_size, pe_size[0], 1, pe_size[2], 1)
+         emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].repeat(batch_size, pe_size[0], pe_size[1], 1, 1)
+         emb = emb_t + emb_h + emb_w
+         emb = emb.flatten(1, 3)
+
+         norm = torch.linalg.vector_norm(emb, dim=-1, keepdim=True, dtype=torch.float32)
+         norm = torch.add(self.eps, norm, alpha=np.sqrt(norm.numel() / emb.numel()))
+         return (emb / norm).type_as(hidden_states)
+
+
+ class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
+     r"""
+     A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
+
+     Args:
+         in_channels (`int`, defaults to `16`):
+             The number of channels in the input.
+         out_channels (`int`, defaults to `16`):
+             The number of channels in the output.
+         num_attention_heads (`int`, defaults to `32`):
+             The number of heads to use for multi-head attention.
+         attention_head_dim (`int`, defaults to `128`):
+             The number of channels in each attention head.
+         num_layers (`int`, defaults to `28`):
+             The number of layers of transformer blocks to use.
+         mlp_ratio (`float`, defaults to `4.0`):
+             The ratio of the hidden layer size to the input size in the feedforward network.
+         text_embed_dim (`int`, defaults to `4096`):
+             Input dimension of text embeddings from the text encoder.
+         adaln_lora_dim (`int`, defaults to `256`):
+             The hidden dimension of the Adaptive LayerNorm LoRA layer.
+         max_size (`Tuple[int, int, int]`, defaults to `(128, 240, 240)`):
+             The maximum size of the input latent tensors in the temporal, height, and width dimensions.
+         patch_size (`Tuple[int, int, int]`, defaults to `(1, 2, 2)`):
+             The patch size to use for patchifying the input latent tensors in the temporal, height, and width
+             dimensions.
+         rope_scale (`Tuple[float, float, float]`, defaults to `(2.0, 1.0, 1.0)`):
+             The scaling factor to use for RoPE in the temporal, height, and width dimensions.
+         concat_padding_mask (`bool`, defaults to `True`):
+             Whether to concatenate the padding mask to the input latent tensors.
+         extra_pos_embed_type (`str`, *optional*, defaults to `learnable`):
+             The type of extra positional embeddings to use. Can be one of `None` or `learnable`.
+     """
+
+     _supports_gradient_checkpointing = True
+     _skip_layerwise_casting_patterns = ["patch_embed", "final_layer", "norm"]
+     _no_split_modules = ["CosmosTransformerBlock"]
+     _keep_in_fp32_modules = ["learnable_pos_embed"]
+
+     @register_to_config
+     def __init__(
+         self,
+         in_channels: int = 16,
+         out_channels: int = 16,
+         num_attention_heads: int = 32,
+         attention_head_dim: int = 128,
+         num_layers: int = 28,
+         mlp_ratio: float = 4.0,
+         text_embed_dim: int = 1024,
+         adaln_lora_dim: int = 256,
+         max_size: Tuple[int, int, int] = (128, 240, 240),
+         patch_size: Tuple[int, int, int] = (1, 2, 2),
+         rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
+         concat_padding_mask: bool = True,
+         extra_pos_embed_type: Optional[str] = "learnable",
+     ) -> None:
+         super().__init__()
+         hidden_size = num_attention_heads * attention_head_dim
+
+         # 1. Patch Embedding
+         patch_embed_in_channels = in_channels + 1 if concat_padding_mask else in_channels
+         self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, patch_size, bias=False)
+
+         # 2. Positional Embedding
+         self.rope = CosmosRotaryPosEmbed(
+             hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
+         )
+
+         self.learnable_pos_embed = None
+         if extra_pos_embed_type == "learnable":
+             self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
+                 hidden_size=hidden_size,
+                 max_size=max_size,
+                 patch_size=patch_size,
+             )
+
+         # 3. Time Embedding
+         self.time_embed = CosmosEmbedding(hidden_size, hidden_size)
+
+         # 4. Transformer Blocks
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 CosmosTransformerBlock(
+                     num_attention_heads=num_attention_heads,
+                     attention_head_dim=attention_head_dim,
+                     cross_attention_dim=text_embed_dim,
+                     mlp_ratio=mlp_ratio,
+                     adaln_lora_dim=adaln_lora_dim,
+                     qk_norm="rms_norm",
+                     out_bias=False,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         # 5. Output norm & projection
+         self.norm_out = CosmosAdaLayerNorm(hidden_size, adaln_lora_dim)
+         self.proj_out = nn.Linear(
+             hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
+         )
+
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         fps: Optional[int] = None,
+         condition_mask: Optional[torch.Tensor] = None,
+         padding_mask: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ) -> torch.Tensor:
+         batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+         # 1. Concatenate padding mask if needed & prepare attention mask
+         if condition_mask is not None:
+             hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
+
+         if self.config.concat_padding_mask:
+             padding_mask = transforms.functional.resize(
+                 padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
+             )
+             hidden_states = torch.cat(
+                 [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
+             )
+
+         if attention_mask is not None:
+             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, S]
+
+         # 2. Generate positional embeddings
+         image_rotary_emb = self.rope(hidden_states, fps=fps)
+         extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None
+
+         # 3. Patchify input
+         p_t, p_h, p_w = self.config.patch_size
+         post_patch_num_frames = num_frames // p_t
+         post_patch_height = height // p_h
+         post_patch_width = width // p_w
+         hidden_states = self.patch_embed(hidden_states)
+         hidden_states = hidden_states.flatten(1, 3)  # [B, T, H, W, C] -> [B, THW, C]
+
+         # 4. Timestep embeddings
+         if timestep.ndim == 1:
+             temb, embedded_timestep = self.time_embed(hidden_states, timestep)
+         elif timestep.ndim == 5:
+             assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
+                 f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
+             )
+             timestep = timestep.flatten()
+             temb, embedded_timestep = self.time_embed(hidden_states, timestep)
+             # We can do this because num_frames == post_patch_num_frames, as p_t is 1
+             temb, embedded_timestep = (
+                 x.view(batch_size, post_patch_num_frames, 1, 1, -1)
+                 .expand(-1, -1, post_patch_height, post_patch_width, -1)
+                 .flatten(1, 3)
+                 for x in (temb, embedded_timestep)
+             )  # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C]
+         else:
+             assert False
+
+         # 5. Transformer blocks
+         for block in self.transformer_blocks:
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+                 hidden_states = self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     embedded_timestep,
+                     temb,
+                     image_rotary_emb,
+                     extra_pos_emb,
+                     attention_mask,
+                 )
+             else:
+                 hidden_states = block(
+                     hidden_states=hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     embedded_timestep=embedded_timestep,
+                     temb=temb,
+                     image_rotary_emb=image_rotary_emb,
+                     extra_pos_emb=extra_pos_emb,
+                     attention_mask=attention_mask,
+                 )
+
+         # 6. Output norm & projection & unpatchify
+         hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
+         hidden_states = self.proj_out(hidden_states)
+         hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
+         hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
+         # NOTE: The permutation order here is not the inverse operation of what happens when patching as usually expected.
+         # It might be a source of confusion to the reader, but this is correct
+         hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
+         hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+         if not return_dict:
+             return (hidden_states,)
+
+         return Transformer2DModelOutput(sample=hidden_states)
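
As a quick smoke test of the new model class, a deliberately tiny configuration can be run on random tensors. This is an illustrative sketch, not part of the diff; it assumes CosmosTransformer3DModel is importable from the diffusers top level in 0.34.0 (as the __init__.py changes in the file list suggest) and that torchvision is installed, since the default concat_padding_mask=True resizes the padding mask with torchvision:

import torch
from diffusers import CosmosTransformer3DModel  # assumed 0.34.0 top-level export

# Tiny illustrative config; released Cosmos checkpoints are far larger.
model = CosmosTransformer3DModel(
    in_channels=4,
    out_channels=4,
    num_attention_heads=2,
    attention_head_dim=24,
    num_layers=2,
    text_embed_dim=32,
)

latents = torch.randn(1, 4, 1, 32, 32)    # [B, C, T, H, W]
timestep = torch.rand(1)                  # one timestep per sample
text = torch.randn(1, 7, 32)              # [B, S, text_embed_dim]
padding_mask = torch.zeros(1, 1, 32, 32)  # required when concat_padding_mask=True

with torch.no_grad():
    sample = model(latents, timestep, text, padding_mask=padding_mask).sample
print(sample.shape)  # torch.Size([1, 4, 1, 32, 32])
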
diffusers/models/transformers/transformer_flux.py

@@ -1,4 +1,4 @@
- # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+ # Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -21,22 +21,22 @@ import torch.nn as nn
 
  from ...configuration_utils import ConfigMixin, register_to_config
  from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
- from ...models.attention import FeedForward
- from ...models.attention_processor import (
+ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+ from ...utils.import_utils import is_torch_npu_available
+ from ...utils.torch_utils import maybe_allow_in_graph
+ from ..attention import FeedForward
+ from ..attention_processor import (
      Attention,
      AttentionProcessor,
      FluxAttnProcessor2_0,
      FluxAttnProcessor2_0_NPU,
      FusedFluxAttnProcessor2_0,
  )
- from ...models.modeling_utils import ModelMixin
- from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
- from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
- from ...utils.import_utils import is_torch_npu_available
- from ...utils.torch_utils import maybe_allow_in_graph
  from ..cache_utils import CacheMixin
  from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
  from ..modeling_outputs import Transformer2DModelOutput
+ from ..modeling_utils import ModelMixin
+ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 
 
  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -241,7 +241,7 @@ class FluxTransformer2DModel(
          joint_attention_dim: int = 4096,
          pooled_projection_dim: int = 768,
          guidance_embeds: bool = False,
-         axes_dims_rope: Tuple[int] = (16, 56, 56),
+         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
      ):
          super().__init__()
          self.out_channels = out_channels or in_channels
@@ -447,8 +447,6 @@
          timestep = timestep.to(hidden_states.dtype) * 1000
          if guidance is not None:
              guidance = guidance.to(hidden_states.dtype) * 1000
-         else:
-             guidance = None
 
          temb = (
              self.time_text_embed(timestep, pooled_projections)
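
The final hunk removes a dead branch: when the condition fails, guidance is already None, so reassigning it was a no-op. A minimal sketch confirming the unchanged behavior (the function name is illustrative, not from the file):

import torch

def scale_conditioning(timestep, guidance, dtype=torch.float32):
    # Mirrors the simplified logic above: only scale guidance when it is given.
    timestep = timestep.to(dtype) * 1000
    if guidance is not None:
        guidance = guidance.to(dtype) * 1000
    return timestep, guidance

t, g = scale_conditioning(torch.tensor([0.5]), None)
assert g is None  # same result as before the redundant else was removed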