diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
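The headline additions in 0.35.0 visible in this list are the Qwen-Image model family (a new VAE, transformer, and four pipelines), Chroma, Cosmos, HiDream, SkyReels-V2, Flux Kontext, the new modular_pipelines package, and the new guiders package. As a rough usage sketch for the new Qwen-Image text-to-image pipeline, assuming the standard DiffusionPipeline API; the checkpoint id below is illustrative, not taken from this diff:

    import torch
    from diffusers import DiffusionPipeline

    # Load the Qwen-Image pipeline added in this release via the generic
    # DiffusionPipeline entry point. "Qwen/Qwen-Image" is an assumed hub id.
    pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    image = pipe(prompt="a watercolor fox in a pine forest", num_inference_steps=30).images[0]
    image.save("fox.png")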
diffusers/models/transformers/transformer_mochi.py
@@ -1,4 +1,4 @@
- # Copyright 2024 The Genmo team and The HuggingFace Team.
+ # Copyright 2025 The Genmo team and The HuggingFace Team.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
diffusers/models/transformers/transformer_omnigen.py
@@ -1,4 +1,4 @@
- # Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+ # Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -283,7 +283,7 @@ class OmniGenBlock(nn.Module):
  
  class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
      """
-     The Transformer model introduced in OmniGen (https://arxiv.org/pdf/2409.11340).
+     The Transformer model introduced in OmniGen (https://huggingface.co/papers/2409.11340).
  
      Parameters:
          in_channels (`int`, defaults to `4`):
diffusers/models/transformers/transformer_qwenimage.py
@@ -0,0 +1,645 @@
+ # Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import functools
+ import math
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+ from ...utils.torch_utils import maybe_allow_in_graph
+ from ..attention import FeedForward
+ from ..attention_dispatch import dispatch_attention_fn
+ from ..attention_processor import Attention
+ from ..cache_utils import CacheMixin
+ from ..embeddings import TimestepEmbedding, Timesteps
+ from ..modeling_outputs import Transformer2DModelOutput
+ from ..modeling_utils import ModelMixin
+ from ..normalization import AdaLayerNormContinuous, RMSNorm
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ def get_timestep_embedding(
+     timesteps: torch.Tensor,
+     embedding_dim: int,
+     flip_sin_to_cos: bool = False,
+     downscale_freq_shift: float = 1,
+     scale: float = 1,
+     max_period: int = 10000,
+ ) -> torch.Tensor:
+     """
+     Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.
+
+     Args:
+         timesteps (`torch.Tensor`):
+             A 1-D tensor of N indices, one per batch element. These may be fractional.
+         embedding_dim (`int`):
+             The dimension of the output.
+         flip_sin_to_cos (`bool`):
+             Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
+         downscale_freq_shift (`float`):
+             Controls the delta between frequencies between dimensions.
+         scale (`float`):
+             Scaling factor applied to the embeddings.
+         max_period (`int`):
+             Controls the maximum frequency of the embeddings.
+
+     Returns:
+         `torch.Tensor`: an [N x dim] tensor of positional embeddings.
+     """
+     assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+     half_dim = embedding_dim // 2
+     exponent = -math.log(max_period) * torch.arange(
+         start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+     )
+     exponent = exponent / (half_dim - downscale_freq_shift)
+
+     emb = torch.exp(exponent).to(timesteps.dtype)
+     emb = timesteps[:, None].float() * emb[None, :]
+
+     # scale embeddings
+     emb = scale * emb
+
+     # concat sine and cosine embeddings
+     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+     # flip sine and cosine embeddings
+     if flip_sin_to_cos:
+         emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+     # zero pad
+     if embedding_dim % 2 == 1:
+         emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+     return emb
+
+
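For reference, a standalone sketch of the layout this function produces, re-deriving the same math with the defaults flip_sin_to_cos=False and downscale_freq_shift=1 rather than importing the module:

    import math
    import torch

    t = torch.tensor([0.0, 1.0, 10.0])  # three timesteps
    half = 4  # embedding_dim = 8
    # w_i = max_period ** (-i / (half_dim - downscale_freq_shift))
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / (half - 1))
    emb = torch.cat([torch.sin(t[:, None] * freqs), torch.cos(t[:, None] * freqs)], dim=-1)
    print(emb.shape)  # torch.Size([3, 8]), laid out as [sin | cos]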
+ def apply_rotary_emb_qwen(
+     x: torch.Tensor,
+     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+     use_real: bool = True,
+     use_real_unbind_dim: int = -1,
+ ) -> torch.Tensor:
+     """
+     Apply rotary embeddings to the given query or key tensor `x` using the provided frequency tensor `freqs_cis`. The
+     input tensor is reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility.
+     The resulting tensor contains rotary embeddings and is returned as a real tensor.
+
+     Args:
+         x (`torch.Tensor`):
+             Query or key tensor to apply rotary embeddings to, of shape [B, S, H, D].
+         freqs_cis (`Tuple[torch.Tensor]`):
+             Precomputed frequency tensor for complex exponentials, ([S, D], [S, D]).
+
+     Returns:
+         `torch.Tensor`: The query or key tensor with rotary embeddings applied.
+     """
+     if use_real:
+         cos, sin = freqs_cis  # [S, D]
+         cos = cos[None, None]
+         sin = sin[None, None]
+         cos, sin = cos.to(x.device), sin.to(x.device)
+
+         if use_real_unbind_dim == -1:
+             # Used for flux, cogvideox, hunyuan-dit
+             x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+         elif use_real_unbind_dim == -2:
+             # Used for Stable Audio, OmniGen, CogView4 and Cosmos
+             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
+             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+         else:
+             raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+         out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+         return out
+     else:
+         x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+         freqs_cis = freqs_cis.unsqueeze(1)
+         x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+         return x_out.type_as(x)
+
+
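A toy run of the complex (use_real=False) branch above, which is the path the Qwen attention processor below uses. Shapes are hypothetical: batch 1, 4 positions, 2 heads, head_dim 8 (so 4 complex pairs per head); the inverse frequencies mirror QwenEmbedRope.rope_params:

    import torch

    pos = torch.arange(4, dtype=torch.float32)
    inv_freq = 1.0 / torch.pow(10000.0, torch.arange(0, 8, 2).float() / 8)
    freqs_cis = torch.polar(torch.ones(4, 4), torch.outer(pos, inv_freq))  # [S, D//2], complex

    x = torch.randn(1, 4, 2, 8)  # [B, S, H, D]
    x_c = torch.view_as_complex(x.float().reshape(1, 4, 2, 4, 2))  # pair up the last dim
    out = torch.view_as_real(x_c * freqs_cis.unsqueeze(1)).flatten(3).type_as(x)
    print(out.shape)  # torch.Size([1, 4, 2, 8])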
+ class QwenTimestepProjEmbeddings(nn.Module):
+     def __init__(self, embedding_dim):
+         super().__init__()
+
+         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
+         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+     def forward(self, timestep, hidden_states):
+         timesteps_proj = self.time_proj(timestep)
+         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))  # (N, D)
+
+         conditioning = timesteps_emb
+
+         return conditioning
+
+
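A quick shape check of the projection module above, assuming diffusers 0.35.0 is installed; note that hidden_states is only consulted for its dtype:

    import torch
    from diffusers.models.transformers.transformer_qwenimage import QwenTimestepProjEmbeddings

    proj = QwenTimestepProjEmbeddings(embedding_dim=3072)
    timestep = torch.tensor([999.0, 500.0])
    # hidden_states only supplies the target dtype for the embedding
    cond = proj(timestep, hidden_states=torch.zeros(2, 16, 3072))
    print(cond.shape)  # torch.Size([2, 3072])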
+ class QwenEmbedRope(nn.Module):
+     def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
+         super().__init__()
+         self.theta = theta
+         self.axes_dim = axes_dim
+         pos_index = torch.arange(4096)
+         neg_index = torch.arange(4096).flip(0) * -1 - 1
+         self.pos_freqs = torch.cat(
+             [
+                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                 self.rope_params(pos_index, self.axes_dim[2], self.theta),
+             ],
+             dim=1,
+         )
+         self.neg_freqs = torch.cat(
+             [
+                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                 self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                 self.rope_params(neg_index, self.axes_dim[2], self.theta),
+             ],
+             dim=1,
+         )
+         self.rope_cache = {}
+
+         # Do not use register_buffer here: it would cause the complex numbers to lose their imaginary part.
+         self.scale_rope = scale_rope
+
+     def rope_params(self, index, dim, theta=10000):
+         """
+         Args:
+             index: 1-D tensor of position indices of the tokens, e.g. [0, 1, 2, 3].
+         """
+         assert dim % 2 == 0
+         freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
+         freqs = torch.polar(torch.ones_like(freqs), freqs)
+         return freqs
+
+     def forward(self, video_fhw, txt_seq_lens, device):
+         """
+         Args:
+             video_fhw: a (frame, height, width) triple, or a list of such triples, giving the shape of each video.
+             txt_seq_lens: a list of `bs` integers giving the length of each text sequence.
+         """
+         if self.pos_freqs.device != device:
+             self.pos_freqs = self.pos_freqs.to(device)
+             self.neg_freqs = self.neg_freqs.to(device)
+
+         if isinstance(video_fhw, list):
+             video_fhw = video_fhw[0]
+         if not isinstance(video_fhw, list):
+             video_fhw = [video_fhw]
+
+         vid_freqs = []
+         max_vid_index = 0
+         for idx, fhw in enumerate(video_fhw):
+             frame, height, width = fhw
+             rope_key = f"{idx}_{height}_{width}"
+
+             if not torch.compiler.is_compiling():
+                 if rope_key not in self.rope_cache:
+                     self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
+                 video_freq = self.rope_cache[rope_key]
+             else:
+                 video_freq = self._compute_video_freqs(frame, height, width, idx)
+             video_freq = video_freq.to(device)
+             vid_freqs.append(video_freq)
+
+             if self.scale_rope:
+                 max_vid_index = max(height // 2, width // 2, max_vid_index)
+             else:
+                 max_vid_index = max(height, width, max_vid_index)
+
+         max_len = max(txt_seq_lens)
+         txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+         vid_freqs = torch.cat(vid_freqs, dim=0)
+
+         return vid_freqs, txt_freqs
+
+     @functools.lru_cache(maxsize=None)
+     def _compute_video_freqs(self, frame, height, width, idx=0):
+         seq_lens = frame * height * width
+         freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+         freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+         freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+         if self.scale_rope:
+             freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+             freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+             freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+             freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+         else:
+             freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+             freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+         freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+         return freqs.clone().contiguous()
+
+
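A shape check for the RoPE table above, assuming diffusers 0.35.0 is installed: one (frame, height, width) video of 1x4x4 latent patches and a text sequence of length 7. Each position gets sum(axes_dim) / 2 = 64 complex frequencies:

    import torch
    from diffusers.models.transformers.transformer_qwenimage import QwenEmbedRope

    rope = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
    vid_freqs, txt_freqs = rope([(1, 4, 4)], txt_seq_lens=[7], device=torch.device("cpu"))
    # 1*4*4 = 16 video positions; text positions start after the video extent
    print(vid_freqs.shape, txt_freqs.shape)  # torch.Size([16, 64]) torch.Size([7, 64])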
+ class QwenDoubleStreamAttnProcessor2_0:
+     """
+     Attention processor for the Qwen double-stream architecture, matching the DoubleStreamLayerMegatron logic. This
+     processor implements joint attention computation where text and image streams are processed together.
+     """
+
+     _attention_backend = None
+
+     def __init__(self):
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError(
+                 "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+             )
+
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.FloatTensor,  # Image stream
+         encoder_hidden_states: torch.FloatTensor = None,  # Text stream
+         encoder_hidden_states_mask: torch.FloatTensor = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         image_rotary_emb: Optional[torch.Tensor] = None,
+     ) -> torch.FloatTensor:
+         if encoder_hidden_states is None:
+             raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
+
+         seq_txt = encoder_hidden_states.shape[1]
+
+         # Compute QKV for image stream (sample projections)
+         img_query = attn.to_q(hidden_states)
+         img_key = attn.to_k(hidden_states)
+         img_value = attn.to_v(hidden_states)
+
+         # Compute QKV for text stream (context projections)
+         txt_query = attn.add_q_proj(encoder_hidden_states)
+         txt_key = attn.add_k_proj(encoder_hidden_states)
+         txt_value = attn.add_v_proj(encoder_hidden_states)
+
+         # Reshape for multi-head attention
+         img_query = img_query.unflatten(-1, (attn.heads, -1))
+         img_key = img_key.unflatten(-1, (attn.heads, -1))
+         img_value = img_value.unflatten(-1, (attn.heads, -1))
+
+         txt_query = txt_query.unflatten(-1, (attn.heads, -1))
+         txt_key = txt_key.unflatten(-1, (attn.heads, -1))
+         txt_value = txt_value.unflatten(-1, (attn.heads, -1))
+
+         # Apply QK normalization
+         if attn.norm_q is not None:
+             img_query = attn.norm_q(img_query)
+         if attn.norm_k is not None:
+             img_key = attn.norm_k(img_key)
+         if attn.norm_added_q is not None:
+             txt_query = attn.norm_added_q(txt_query)
+         if attn.norm_added_k is not None:
+             txt_key = attn.norm_added_k(txt_key)
+
+         # Apply RoPE
+         if image_rotary_emb is not None:
+             img_freqs, txt_freqs = image_rotary_emb
+             img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
+             img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
+             txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
+             txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
+
+         # Concatenate for joint attention
+         # Order: [text, image]
+         joint_query = torch.cat([txt_query, img_query], dim=1)
+         joint_key = torch.cat([txt_key, img_key], dim=1)
+         joint_value = torch.cat([txt_value, img_value], dim=1)
+
+         # Compute joint attention
+         joint_hidden_states = dispatch_attention_fn(
+             joint_query,
+             joint_key,
+             joint_value,
+             attn_mask=attention_mask,
+             dropout_p=0.0,
+             is_causal=False,
+             backend=self._attention_backend,
+         )
+
+         # Reshape back
+         joint_hidden_states = joint_hidden_states.flatten(2, 3)
+         joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
+
+         # Split attention outputs back
+         txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
+         img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part
+
+         # Apply output projections
+         img_attn_output = attn.to_out[0](img_attn_output)
+         if len(attn.to_out) > 1:
+             img_attn_output = attn.to_out[1](img_attn_output)  # dropout
+
+         txt_attn_output = attn.to_add_out(txt_attn_output)
+
+         return img_attn_output, txt_attn_output
+
+
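To see the concatenate-attend-split pattern the processor implements in isolation, here is a toy sketch with hypothetical shapes. Plain scaled_dot_product_attention stands in for dispatch_attention_fn, which takes [B, S, H, D] inputs while SDPA wants [B, H, S, D]:

    import torch
    import torch.nn.functional as F

    B, H, D = 1, 2, 8
    txt_q = txt_k = txt_v = torch.randn(B, 3, H, D)  # 3 text tokens, [B, S, H, D]
    img_q = img_k = img_v = torch.randn(B, 5, H, D)  # 5 image tokens

    # Order: [text, image], as in the processor above
    q = torch.cat([txt_q, img_q], dim=1).transpose(1, 2)  # -> [B, H, S, D] for SDPA
    k = torch.cat([txt_k, img_k], dim=1).transpose(1, 2)
    v = torch.cat([txt_v, img_v], dim=1).transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)  # back to [B, S, H, D]

    txt_out, img_out = out[:, :3], out[:, 3:]  # split at the text length
    print(txt_out.shape, img_out.shape)  # torch.Size([1, 3, 2, 8]) torch.Size([1, 5, 2, 8])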
356
+ @maybe_allow_in_graph
357
+ class QwenImageTransformerBlock(nn.Module):
358
+ def __init__(
359
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
360
+ ):
361
+ super().__init__()
362
+
363
+ self.dim = dim
364
+ self.num_attention_heads = num_attention_heads
365
+ self.attention_head_dim = attention_head_dim
366
+
367
+ # Image processing modules
368
+ self.img_mod = nn.Sequential(
369
+ nn.SiLU(),
370
+ nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
371
+ )
372
+ self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
373
+ self.attn = Attention(
374
+ query_dim=dim,
375
+ cross_attention_dim=None, # Enable cross attention for joint computation
376
+ added_kv_proj_dim=dim, # Enable added KV projections for text stream
377
+ dim_head=attention_head_dim,
378
+ heads=num_attention_heads,
379
+ out_dim=dim,
380
+ context_pre_only=False,
381
+ bias=True,
382
+ processor=QwenDoubleStreamAttnProcessor2_0(),
383
+ qk_norm=qk_norm,
384
+ eps=eps,
385
+ )
386
+ self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
387
+ self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
388
+
389
+ # Text processing modules
390
+ self.txt_mod = nn.Sequential(
391
+ nn.SiLU(),
392
+ nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
393
+ )
394
+ self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
395
+ # Text doesn't need separate attention - it's handled by img_attn joint computation
396
+ self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
397
+ self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
398
+
399
+ def _modulate(self, x, mod_params):
400
+ """Apply modulation to input tensor"""
401
+ shift, scale, gate = mod_params.chunk(3, dim=-1)
402
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
403
+
404
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         encoder_hidden_states_mask: torch.Tensor,
+         temb: torch.Tensor,
+         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Get modulation parameters for both streams
+         img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+         txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
+ 
+         # Split modulation parameters for norm1 and norm2
+         img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
+         txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
+ 
+         # Process image stream - norm1 + modulation
+         img_normed = self.img_norm1(hidden_states)
+         img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
+ 
+         # Process text stream - norm1 + modulation
+         txt_normed = self.txt_norm1(encoder_hidden_states)
+         txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
+ 
+         # Use QwenDoubleStreamAttnProcessor2_0 for the joint attention computation.
+         # This directly implements the DoubleStreamLayerMegatron logic:
+         # 1. Computes QKV for both streams
+         # 2. Applies QK normalization and RoPE
+         # 3. Concatenates and runs joint attention
+         # 4. Splits the result back into separate streams
+         joint_attention_kwargs = joint_attention_kwargs or {}
+         attn_output = self.attn(
+             hidden_states=img_modulated,  # Image stream (processed as "sample")
+             encoder_hidden_states=txt_modulated,  # Text stream (processed as "context")
+             encoder_hidden_states_mask=encoder_hidden_states_mask,
+             image_rotary_emb=image_rotary_emb,
+             **joint_attention_kwargs,
+         )
+ 
+         # QwenDoubleStreamAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
+         img_attn_output, txt_attn_output = attn_output
+ 
+         # Apply attention gates and add the residual (as in Megatron)
+         hidden_states = hidden_states + img_gate1 * img_attn_output
+         encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+ 
+         # Process image stream - norm2 + MLP
+         img_normed2 = self.img_norm2(hidden_states)
+         img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+         img_mlp_output = self.img_mlp(img_modulated2)
+         hidden_states = hidden_states + img_gate2 * img_mlp_output
+ 
+         # Process text stream - norm2 + MLP
+         txt_normed2 = self.txt_norm2(encoder_hidden_states)
+         txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
+         txt_mlp_output = self.txt_mlp(txt_modulated2)
+         encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
+ 
+         # Clip to prevent overflow in fp16
+         if encoder_hidden_states.dtype == torch.float16:
+             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+         if hidden_states.dtype == torch.float16:
+             hidden_states = hidden_states.clip(-65504, 65504)
+ 
+         return encoder_hidden_states, hidden_states
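The four numbered steps in the comment above reduce to a concatenate-attend-split pattern over the two token streams. A plain-torch sketch of that pattern follows; the stream ordering and the omitted QK-norm/RoPE details are simplifications for illustration, not the processor's exact code:

import torch
import torch.nn.functional as F

B, H, img_len, txt_len, d = 1, 4, 64, 8, 32
img_q, img_k, img_v = (torch.randn(B, H, img_len, d) for _ in range(3))
txt_q, txt_k, txt_v = (torch.randn(B, H, txt_len, d) for _ in range(3))

# 3. Concatenate text and image tokens and run one attention over the joint sequence
q = torch.cat([txt_q, img_q], dim=2)
k = torch.cat([txt_k, img_k], dim=2)
v = torch.cat([txt_v, img_v], dim=2)
joint = F.scaled_dot_product_attention(q, k, v)

# 4. Split the result back into per-stream outputs
txt_out, img_out = joint.split([txt_len, img_len], dim=2)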
+ 
+ 
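For orientation, here is a sketch of driving a single block with dummy tensors. The module path, relying on the constructor defaults for `qk_norm` and `eps`, and passing `image_rotary_emb=None` (skipping RoPE) are all assumptions on my part, not a tested recipe:

import torch
from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformerBlock

block = QwenImageTransformerBlock(dim=128, num_attention_heads=4, attention_head_dim=32)

B, img_len, txt_len = 1, 64, 8
hidden_states = torch.randn(B, img_len, 128)          # image stream tokens
encoder_hidden_states = torch.randn(B, txt_len, 128)  # text stream tokens, already projected to `dim`
mask = torch.ones(B, txt_len, dtype=torch.long)
temb = torch.randn(B, 128)                            # timestep conditioning

encoder_hidden_states, hidden_states = block(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    encoder_hidden_states_mask=mask,
    temb=temb,
    image_rotary_emb=None,  # the full model supplies this from QwenEmbedRope
)
print(hidden_states.shape, encoder_hidden_states.shape)  # [1, 64, 128], [1, 8, 128]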
+ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+     """
+     The Transformer model introduced in Qwen.
+ 
+     Args:
+         patch_size (`int`, defaults to `2`):
+             Patch size to turn the input data into small patches.
+         in_channels (`int`, defaults to `64`):
+             The number of channels in the input.
+         out_channels (`int`, *optional*, defaults to `16`):
+             The number of channels in the output. If `None`, it defaults to `in_channels`.
+         num_layers (`int`, defaults to `60`):
+             The number of layers of dual-stream DiT blocks to use.
+         attention_head_dim (`int`, defaults to `128`):
+             The number of dimensions to use for each attention head.
+         num_attention_heads (`int`, defaults to `24`):
+             The number of attention heads to use.
+         joint_attention_dim (`int`, defaults to `3584`):
+             The number of dimensions to use for the joint attention (embedding/channel dimension of
+             `encoder_hidden_states`).
+         guidance_embeds (`bool`, defaults to `False`):
+             Whether to use guidance embeddings for the guidance-distilled variant of the model.
+         axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+             The dimensions to use for the rotary positional embeddings.
+     """
+ 
+     _supports_gradient_checkpointing = True
+     _no_split_modules = ["QwenImageTransformerBlock"]
+     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+     _repeated_blocks = ["QwenImageTransformerBlock"]
+ 
+     @register_to_config
+     def __init__(
+         self,
+         patch_size: int = 2,
+         in_channels: int = 64,
+         out_channels: Optional[int] = 16,
+         num_layers: int = 60,
+         attention_head_dim: int = 128,
+         num_attention_heads: int = 24,
+         joint_attention_dim: int = 3584,
+         guidance_embeds: bool = False,  # TODO: this should probably be removed
+         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+     ):
+         super().__init__()
+         self.out_channels = out_channels or in_channels
+         self.inner_dim = num_attention_heads * attention_head_dim
+ 
+         self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
+ 
+         self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+ 
+         self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
+ 
+         self.img_in = nn.Linear(in_channels, self.inner_dim)
+         self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
+ 
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 QwenImageTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=num_attention_heads,
+                     attention_head_dim=attention_head_dim,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+ 
+         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+ 
+         self.gradient_checkpointing = False
+ 
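With the defaults above, `inner_dim = num_attention_heads * attention_head_dim = 24 * 128 = 3072`, and `proj_out` maps each token back to `patch_size**2 * out_channels = 4 * 16 = 64` values. A test-sized instantiation for experimentation (the values below are arbitrary, and the top-level import assumes the export added in this release):

from diffusers import QwenImageTransformer2DModel

tiny = QwenImageTransformer2DModel(
    patch_size=2,
    in_channels=16,
    out_channels=4,
    num_layers=2,
    attention_head_dim=32,
    num_attention_heads=2,
    joint_attention_dim=32,
)
print(tiny.inner_dim)  # 2 * 32 = 64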
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor = None,
+         encoder_hidden_states_mask: torch.Tensor = None,
+         timestep: torch.LongTensor = None,
+         img_shapes: Optional[List[Tuple[int, int, int]]] = None,
+         txt_seq_lens: Optional[List[int]] = None,
+         guidance: torch.Tensor = None,  # TODO: this should probably be removed
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+         return_dict: bool = True,
+     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+         """
+         The [`QwenImageTransformer2DModel`] forward method.
+ 
+         Args:
+             hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+                 Input `hidden_states`.
+             encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+             encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
+                 Mask of the input conditions.
+             timestep (`torch.LongTensor`):
+                 Used to indicate the denoising step.
+             attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                 tuple.
+ 
+         Returns:
+             If `return_dict` is `True`, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned; otherwise a
+             `tuple` where the first element is the sample tensor.
+         """
+         if attention_kwargs is not None:
+             attention_kwargs = attention_kwargs.copy()
+             lora_scale = attention_kwargs.pop("scale", 1.0)
+         else:
+             lora_scale = 1.0
+ 
+         if USE_PEFT_BACKEND:
+             # weight the lora layers by setting `lora_scale` for each PEFT layer
+             scale_lora_layers(self, lora_scale)
+         else:
+             if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                 logger.warning(
+                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                 )
+ 
+         hidden_states = self.img_in(hidden_states)
+ 
+         timestep = timestep.to(hidden_states.dtype)
+         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+         encoder_hidden_states = self.txt_in(encoder_hidden_states)
+ 
+         if guidance is not None:
+             guidance = guidance.to(hidden_states.dtype) * 1000
+ 
+         temb = (
+             self.time_text_embed(timestep, hidden_states)
+             if guidance is None
+             else self.time_text_embed(timestep, guidance, hidden_states)
+         )
+ 
+         image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+ 
+         for index_block, block in enumerate(self.transformer_blocks):
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     encoder_hidden_states_mask,
+                     temb,
+                     image_rotary_emb,
+                 )
+             else:
+                 encoder_hidden_states, hidden_states = block(
+                     hidden_states=hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     encoder_hidden_states_mask=encoder_hidden_states_mask,
+                     temb=temb,
+                     image_rotary_emb=image_rotary_emb,
+                     joint_attention_kwargs=attention_kwargs,
+                 )
+ 
+         # Use only the image part (hidden_states) from the dual-stream blocks
+         hidden_states = self.norm_out(hidden_states, temb)
+         output = self.proj_out(hidden_states)
+ 
+         if USE_PEFT_BACKEND:
+             # remove `lora_scale` from each PEFT layer
+             unscale_lora_layers(self, lora_scale)
+ 
+         if not return_dict:
+             return (output,)
+ 
+         return Transformer2DModelOutput(sample=output)
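Finally, a sketch of a full forward pass on a test-sized model with random inputs. The `(frames, height_patches, width_patches)` reading of `img_shapes` is inferred from how `QwenEmbedRope` consumes it above, not documented in this diff, so treat the shapes as assumptions:

import torch
from diffusers import QwenImageTransformer2DModel

model = QwenImageTransformer2DModel(
    patch_size=2, in_channels=16, out_channels=4, num_layers=2,
    attention_head_dim=32, num_attention_heads=2, joint_attention_dim=32,
)

B, gh, gw, txt_len = 1, 8, 8, 6  # an 8x8 patch grid -> 64 image tokens

out = model(
    hidden_states=torch.randn(B, gh * gw, 16),          # in_channels=16
    encoder_hidden_states=torch.randn(B, txt_len, 32),  # joint_attention_dim=32
    encoder_hidden_states_mask=torch.ones(B, txt_len, dtype=torch.long),
    timestep=torch.tensor([500]),
    img_shapes=[(1, gh, gw)],  # assumed (frames, height_patches, width_patches) per sample
    txt_seq_lens=[txt_len],
)
print(out.sample.shape)  # [1, 64, patch_size**2 * out_channels] = [1, 64, 16]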