diffusers 0.33.0__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +17 -12
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. diffusers-0.33.0.dist-info/RECORD +0 -608
  475. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  476. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
  477. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/diffusers/models/transformers/transformer_hidream_image.py
@@ -0,0 +1,942 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.modeling_outputs import Transformer2DModelOutput
+from ...models.modeling_utils import ModelMixin
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import Attention
+from ..embeddings import TimestepEmbedding, Timesteps
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class HiDreamImageFeedForwardSwiGLU(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int = 256,
+        ffn_dim_multiplier: Optional[float] = None,
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+
+
+class HiDreamImagePooledEmbed(nn.Module):
+    def __init__(self, text_emb_dim, hidden_size):
+        super().__init__()
+        self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
+
+    def forward(self, pooled_embed: torch.Tensor) -> torch.Tensor:
+        return self.pooled_embedder(pooled_embed)
+
+
+class HiDreamImageTimestepEmbed(nn.Module):
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
+
+    def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
+        t_emb = self.time_proj(timesteps).to(dtype=wdtype)
+        t_emb = self.timestep_embedder(t_emb)
+        return t_emb
+
+
+class HiDreamImageOutEmbed(nn.Module):
+    def __init__(self, hidden_size, patch_size, out_channels):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
+        shift, scale = self.adaLN_modulation(temb).chunk(2, dim=1)
+        hidden_states = self.norm_final(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        hidden_states = self.linear(hidden_states)
+        return hidden_states
+
+
+class HiDreamImagePatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size=2,
+        in_channels=4,
+        out_channels=1024,
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+        self.out_channels = out_channels
+        self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
+
+    def forward(self, latent):
+        latent = self.proj(latent)
+        return latent
+
+
+def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+    assert dim % 2 == 0, "The dimension must be even."
+
+    is_mps = pos.device.type == "mps"
+    is_npu = pos.device.type == "npu"
+
+    dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+
+    scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
+    omega = 1.0 / (theta**scale)
+
+    batch_size, seq_length = pos.shape
+    out = torch.einsum("...n,d->...nd", pos, omega)
+    cos_out = torch.cos(out)
+    sin_out = torch.sin(out)
+
+    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+    return out.float()
+
+
+class HiDreamImageEmbedND(nn.Module):
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        emb = torch.cat(
+            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+            dim=-3,
+        )
+        return emb.unsqueeze(2)
+
+
+def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+
+@maybe_allow_in_graph
+class HiDreamAttention(Attention):
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        scale_qk: bool = True,
+        eps: float = 1e-5,
+        processor=None,
+        out_dim: int = None,
+        single: bool = False,
+    ):
+        super(Attention, self).__init__()
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.upcast_attention = upcast_attention
+        self.upcast_softmax = upcast_softmax
+        self.out_dim = out_dim if out_dim is not None else query_dim
+
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        self.sliceable_head_dim = heads
+        self.single = single
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim)
+        self.to_k = nn.Linear(self.inner_dim, self.inner_dim)
+        self.to_v = nn.Linear(self.inner_dim, self.inner_dim)
+        self.to_out = nn.Linear(self.inner_dim, self.out_dim)
+        self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps)
+        self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps)
+
+        if not single:
+            self.to_q_t = nn.Linear(query_dim, self.inner_dim)
+            self.to_k_t = nn.Linear(self.inner_dim, self.inner_dim)
+            self.to_v_t = nn.Linear(self.inner_dim, self.inner_dim)
+            self.to_out_t = nn.Linear(self.inner_dim, self.out_dim)
+            self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
+            self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
+
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        norm_hidden_states: torch.Tensor,
+        hidden_states_masks: torch.Tensor = None,
+        norm_encoder_hidden_states: torch.Tensor = None,
+        image_rotary_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        return self.processor(
+            self,
+            hidden_states=norm_hidden_states,
+            hidden_states_masks=hidden_states_masks,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+        )
+
+
+class HiDreamAttnProcessor:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __call__(
+        self,
+        attn: HiDreamAttention,
+        hidden_states: torch.Tensor,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        image_rotary_emb: torch.Tensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        dtype = hidden_states.dtype
+        batch_size = hidden_states.shape[0]
+
+        query_i = attn.q_rms_norm(attn.to_q(hidden_states)).to(dtype=dtype)
+        key_i = attn.k_rms_norm(attn.to_k(hidden_states)).to(dtype=dtype)
+        value_i = attn.to_v(hidden_states)
+
+        inner_dim = key_i.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
+        key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
+        value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
+        if hidden_states_masks is not None:
+            key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)
+
+        if not attn.single:
+            query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype)
+            key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype)
+            value_t = attn.to_v_t(encoder_hidden_states)
+
+            query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
+            key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
+            value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
+
+            num_image_tokens = query_i.shape[1]
+            num_text_tokens = query_t.shape[1]
+            query = torch.cat([query_i, query_t], dim=1)
+            key = torch.cat([key_i, key_t], dim=1)
+            value = torch.cat([value_i, value_t], dim=1)
+        else:
+            query = query_i
+            key = key_i
+            value = value_i
+
+        if query.shape[-1] == image_rotary_emb.shape[-3] * 2:
+            query, key = apply_rope(query, key, image_rotary_emb)
+
+        else:
+            query_1, query_2 = query.chunk(2, dim=-1)
+            key_1, key_2 = key.chunk(2, dim=-1)
+            query_1, key_1 = apply_rope(query_1, key_1, image_rotary_emb)
+            query = torch.cat([query_1, query_2], dim=-1)
+            key = torch.cat([key_1, key_2], dim=-1)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if not attn.single:
+            hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
+            hidden_states_i = attn.to_out(hidden_states_i)
+            hidden_states_t = attn.to_out_t(hidden_states_t)
+            return hidden_states_i, hidden_states_t
+        else:
+            hidden_states = attn.to_out(hidden_states)
+            return hidden_states
+
+
+# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
+class MoEGate(nn.Module):
+    def __init__(
+        self,
+        embed_dim,
+        num_routed_experts=4,
+        num_activated_experts=2,
+        aux_loss_alpha=0.01,
+        _force_inference_output=False,
+    ):
+        super().__init__()
+        self.top_k = num_activated_experts
+        self.n_routed_experts = num_routed_experts
+
+        self.scoring_func = "softmax"
+        self.alpha = aux_loss_alpha
+        self.seq_aux = False
+
+        # topk selection algorithm
+        self.norm_topk_prob = False
+        self.gating_dim = embed_dim
+        self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5)
+
+        self._force_inference_output = _force_inference_output
+
+    def forward(self, hidden_states):
+        bsz, seq_len, h = hidden_states.shape
+        ### compute gating score
+        hidden_states = hidden_states.view(-1, h)
+        logits = F.linear(hidden_states, self.weight, None)
+        if self.scoring_func == "softmax":
+            scores = logits.softmax(dim=-1)
+        else:
+            raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
+
+        ### select top-k experts
+        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
+
+        ### norm gate to sum 1
+        if self.top_k > 1 and self.norm_topk_prob:
+            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weight = topk_weight / denominator
+
+        ### expert-level computation auxiliary loss
+        if self.training and self.alpha > 0.0 and not self._force_inference_output:
+            scores_for_aux = scores
+            aux_topk = self.top_k
+            # always compute aux loss based on the naive greedy topk method
+            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+            if self.seq_aux:
+                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
+                ce.scatter_add_(
+                    1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
+                ).div_(seq_len * aux_topk / self.n_routed_experts)
+                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
+            else:
+                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
+                ce = mask_ce.float().mean(0)
+
+                Pi = scores_for_aux.mean(0)
+                fi = ce * self.n_routed_experts
+                aux_loss = (Pi * fi).sum() * self.alpha
+        else:
+            aux_loss = None
+        return topk_idx, topk_weight, aux_loss
+
+
+# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
+class MOEFeedForwardSwiGLU(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        num_routed_experts: int,
+        num_activated_experts: int,
+        _force_inference_output: bool = False,
+    ):
+        super().__init__()
+        self.shared_experts = HiDreamImageFeedForwardSwiGLU(dim, hidden_dim // 2)
+        self.experts = nn.ModuleList(
+            [HiDreamImageFeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)]
+        )
+        self._force_inference_output = _force_inference_output
+        self.gate = MoEGate(
+            embed_dim=dim,
+            num_routed_experts=num_routed_experts,
+            num_activated_experts=num_activated_experts,
+            _force_inference_output=_force_inference_output,
+        )
+        self.num_activated_experts = num_activated_experts
+
+    def forward(self, x):
+        wtype = x.dtype
+        identity = x
+        orig_shape = x.shape
+        topk_idx, topk_weight, aux_loss = self.gate(x)
+        x = x.view(-1, x.shape[-1])
+        flat_topk_idx = topk_idx.view(-1)
+        if self.training and not self._force_inference_output:
+            x = x.repeat_interleave(self.num_activated_experts, dim=0)
+            y = torch.empty_like(x, dtype=wtype)
+            for i, expert in enumerate(self.experts):
+                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
+            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+            y = y.view(*orig_shape).to(dtype=wtype)
+            # y = AddAuxiliaryLoss.apply(y, aux_loss)
+        else:
+            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
+        y = y + self.shared_experts(identity)
+        return y
+
+    @torch.no_grad()
+    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
+        expert_cache = torch.zeros_like(x)
+        idxs = flat_expert_indices.argsort()
+        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
+        token_idxs = idxs // self.num_activated_experts
+        for i, end_idx in enumerate(tokens_per_expert):
+            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
+            if start_idx == end_idx:
+                continue
+            expert = self.experts[i]
+            exp_token_idx = token_idxs[start_idx:end_idx]
+            expert_tokens = x[exp_token_idx]
+            expert_out = expert(expert_tokens)
+            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
+
+            # for fp16 and other dtype
+            expert_cache = expert_cache.to(expert_out.dtype)
+            expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum")
+        return expert_cache
+
+
+class TextProjection(nn.Module):
+    def __init__(self, in_features, hidden_size):
+        super().__init__()
+        self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
+
+    def forward(self, caption):
+        hidden_states = self.linear(caption)
+        return hidden_states
+
+
+@maybe_allow_in_graph
+class HiDreamImageSingleTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        num_routed_experts: int = 4,
+        num_activated_experts: int = 2,
+        _force_inference_output: bool = False,
+    ):
+        super().__init__()
+        self.num_attention_heads = num_attention_heads
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
+
+        # 1. Attention
+        self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
+        self.attn1 = HiDreamAttention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            processor=HiDreamAttnProcessor(),
+            single=True,
+        )
+
+        # 3. Feed-forward
+        self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
+        if num_routed_experts > 0:
+            self.ff_i = MOEFeedForwardSwiGLU(
+                dim=dim,
+                hidden_dim=4 * dim,
+                num_routed_experts=num_routed_experts,
+                num_activated_experts=num_activated_experts,
+                _force_inference_output=_force_inference_output,
+            )
+        else:
+            self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        wtype = hidden_states.dtype
+        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(temb)[
+            :, None
+        ].chunk(6, dim=-1)
+
+        # 1. MM-Attention
+        norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+        attn_output_i = self.attn1(
+            norm_hidden_states,
+            hidden_states_masks,
+            image_rotary_emb=image_rotary_emb,
+        )
+        hidden_states = gate_msa_i * attn_output_i + hidden_states
+
+        # 2. Feed-forward
+        norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
+        hidden_states = ff_output_i + hidden_states
+        return hidden_states
+
+
+@maybe_allow_in_graph
+class HiDreamImageTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        num_routed_experts: int = 4,
+        num_activated_experts: int = 2,
+        _force_inference_output: bool = False,
+    ):
+        super().__init__()
+        self.num_attention_heads = num_attention_heads
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 12 * dim, bias=True))
+
+        # 1. Attention
+        self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
+        self.norm1_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
+        self.attn1 = HiDreamAttention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            processor=HiDreamAttnProcessor(),
+            single=False,
+        )
+
+        # 3. Feed-forward
+        self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
+        if num_routed_experts > 0:
+            self.ff_i = MOEFeedForwardSwiGLU(
+                dim=dim,
+                hidden_dim=4 * dim,
+                num_routed_experts=num_routed_experts,
+                num_activated_experts=num_activated_experts,
+                _force_inference_output=_force_inference_output,
+            )
+        else:
+            self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
+        self.norm3_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
+        self.ff_t = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        wtype = hidden_states.dtype
+        (
+            shift_msa_i,
+            scale_msa_i,
+            gate_msa_i,
+            shift_mlp_i,
+            scale_mlp_i,
+            gate_mlp_i,
+            shift_msa_t,
+            scale_msa_t,
+            gate_msa_t,
+            shift_mlp_t,
+            scale_mlp_t,
+            gate_mlp_t,
+        ) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)
+
+        # 1. MM-Attention
+        norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+        norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(dtype=wtype)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
+
+        attn_output_i, attn_output_t = self.attn1(
+            norm_hidden_states,
+            hidden_states_masks,
+            norm_encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+        )
+
+        hidden_states = gate_msa_i * attn_output_i + hidden_states
+        encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
+
+        # 2. Feed-forward
+        norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+        norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
+
+        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
+        ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
+        hidden_states = ff_output_i + hidden_states
+        encoder_hidden_states = ff_output_t + encoder_hidden_states
+        return hidden_states, encoder_hidden_states
+
+
+class HiDreamBlock(nn.Module):
+    def __init__(self, block: Union[HiDreamImageTransformerBlock, HiDreamImageSingleTransformerBlock]):
+        super().__init__()
+        self.block = block
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        return self.block(
+            hidden_states=hidden_states,
+            hidden_states_masks=hidden_states_masks,
+            encoder_hidden_states=encoder_hidden_states,
+            temb=temb,
+            image_rotary_emb=image_rotary_emb,
+        )
+
+
+ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
606
+ _supports_gradient_checkpointing = True
607
+ _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
608
+
609
+ @register_to_config
610
+ def __init__(
611
+ self,
612
+ patch_size: Optional[int] = None,
613
+ in_channels: int = 64,
614
+ out_channels: Optional[int] = None,
615
+ num_layers: int = 16,
616
+ num_single_layers: int = 32,
617
+ attention_head_dim: int = 128,
618
+ num_attention_heads: int = 20,
619
+ caption_channels: List[int] = None,
620
+ text_emb_dim: int = 2048,
621
+ num_routed_experts: int = 4,
622
+ num_activated_experts: int = 2,
623
+ axes_dims_rope: Tuple[int, int] = (32, 32),
624
+ max_resolution: Tuple[int, int] = (128, 128),
625
+ llama_layers: List[int] = None,
626
+ force_inference_output: bool = False,
627
+ ):
628
+ super().__init__()
629
+ self.out_channels = out_channels or in_channels
630
+ self.inner_dim = num_attention_heads * attention_head_dim
631
+
632
+ self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
633
+ self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
634
+ self.x_embedder = HiDreamImagePatchEmbed(
635
+ patch_size=patch_size,
636
+ in_channels=in_channels,
637
+ out_channels=self.inner_dim,
638
+ )
639
+ self.pe_embedder = HiDreamImageEmbedND(theta=10000, axes_dim=axes_dims_rope)
640
+
641
+ self.double_stream_blocks = nn.ModuleList(
642
+ [
643
+ HiDreamBlock(
644
+ HiDreamImageTransformerBlock(
645
+ dim=self.inner_dim,
646
+ num_attention_heads=num_attention_heads,
647
+ attention_head_dim=attention_head_dim,
648
+ num_routed_experts=num_routed_experts,
649
+ num_activated_experts=num_activated_experts,
650
+ _force_inference_output=force_inference_output,
651
+ )
652
+ )
653
+ for _ in range(num_layers)
654
+ ]
655
+ )
656
+
657
+ self.single_stream_blocks = nn.ModuleList(
658
+ [
659
+ HiDreamBlock(
660
+ HiDreamImageSingleTransformerBlock(
661
+ dim=self.inner_dim,
662
+ num_attention_heads=num_attention_heads,
663
+ attention_head_dim=attention_head_dim,
664
+ num_routed_experts=num_routed_experts,
665
+ num_activated_experts=num_activated_experts,
666
+ _force_inference_output=force_inference_output,
667
+ )
668
+ )
669
+ for _ in range(num_single_layers)
670
+ ]
671
+ )
672
+
673
+ self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)
674
+
675
+ caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
676
+ caption_projection = []
677
+ for caption_channel in caption_channels:
678
+ caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
679
+ self.caption_projection = nn.ModuleList(caption_projection)
680
+ self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
681
+
682
+ self.gradient_checkpointing = False
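
For orientation, a hedged instantiation sketch; every value below is a placeholder, not the published HiDream checkpoint config. Note the ordering implied by the `caption_projection` setup above: `caption_channels[1]` (the Llama hidden size) feeds the per-block projections, `caption_channels[0]` (the T5 size) feeds the final projection, and `llama_layers` needs one entry per block (`num_layers + num_single_layers` in total):

```python
from diffusers import HiDreamImageTransformer2DModel

# All values are assumptions kept small for illustration.
model = HiDreamImageTransformer2DModel(
    patch_size=2,
    in_channels=16,                  # latent channels (assumption)
    num_layers=2,                    # double-stream blocks (tiny for the sketch)
    num_single_layers=2,             # single-stream blocks
    attention_head_dim=64,
    num_attention_heads=4,
    caption_channels=[4096, 4096],   # [T5 dim, Llama dim] (placeholder sizes)
    llama_layers=[0, 1, 2, 3],       # one tapped Llama hidden state per block
)
```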
+
+    def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
+        if is_training and not self.config.force_inference_output:
+            B, S, F = x.shape
+            C = F // (self.config.patch_size * self.config.patch_size)
+            x = (
+                x.reshape(B, S, self.config.patch_size, self.config.patch_size, C)
+                .permute(0, 4, 1, 2, 3)
+                .reshape(B, C, S, self.config.patch_size * self.config.patch_size)
+            )
+        else:
+            x_arr = []
+            p1 = self.config.patch_size
+            p2 = self.config.patch_size
+            for i, img_size in enumerate(img_sizes):
+                pH, pW = img_size
+                t = x[i, : pH * pW].reshape(1, pH, pW, -1)
+                F_token = t.shape[-1]
+                C = F_token // (p1 * p2)
+                t = t.reshape(1, pH, pW, p1, p2, C)
+                t = t.permute(0, 5, 1, 3, 2, 4)
+                t = t.reshape(1, C, pH * p1, pW * p2)
+                x_arr.append(t)
+            x = torch.cat(x_arr, dim=0)
+        return x
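
A toy check of the inference-path reshapes in `unpatchify`: tokens of width `p1 * p2 * C` fold back into a `(C, pH * p1, pW * p2)` image. All sizes below are made up:

```python
import torch

pH, pW, p, C = 3, 2, 2, 4                      # patch grid, patch size, channels
x = torch.randn(1, pH * pW, p * p * C)         # one sample's image tokens

t = x[0, : pH * pW].reshape(1, pH, pW, -1)
t = t.reshape(1, pH, pW, p, p, C)
t = t.permute(0, 5, 1, 3, 2, 4)                # (1, C, pH, p, pW, p)
img = t.reshape(1, C, pH * p, pW * p)
print(img.shape)                               # torch.Size([1, 4, 6, 4])
```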
+
+    def patchify(self, hidden_states):
+        batch_size, channels, height, width = hidden_states.shape
+        patch_size = self.config.patch_size
+        patch_height, patch_width = height // patch_size, width // patch_size
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        # create img_sizes
+        img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
+        img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
+
+        # create hidden_states_masks
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
+            hidden_states_masks[:, : patch_height * patch_width] = 1.0
+        else:
+            hidden_states_masks = None
+
+        # create img_ids
+        img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
+        row_indices = torch.arange(patch_height, device=device)[:, None]
+        col_indices = torch.arange(patch_width, device=device)[None, :]
+        img_ids[..., 1] = img_ids[..., 1] + row_indices
+        img_ids[..., 2] = img_ids[..., 2] + col_indices
+        img_ids = img_ids.reshape(patch_height * patch_width, -1)
+
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
+            img_ids_pad[: patch_height * patch_width, :] = img_ids
+            img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
+        else:
+            img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)
+
+        # patchify hidden_states
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            out = torch.zeros(
+                (batch_size, channels, self.max_seq, patch_size * patch_size),
+                dtype=dtype,
+                device=device,
+            )
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height * patch_width, patch_size * patch_size
+            )
+            out[:, :, 0 : patch_height * patch_width] = hidden_states
+            hidden_states = out
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+                batch_size, self.max_seq, patch_size * patch_size * channels
+            )
+        else:
+            # Handle square latents
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
+            hidden_states = hidden_states.reshape(
+                batch_size, patch_height * patch_width, patch_size * patch_size * channels
+            )
+
+        return hidden_states, hidden_states_masks, img_sizes, img_ids
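
The square-latent branch of `patchify` and the inference branch of `unpatchify` are exact inverses, which the following round-trip sketch (toy sizes) verifies:

```python
import torch

B, C, H, W, p = 1, 4, 6, 6, 2
pH, pW = H // p, W // p
latents = torch.randn(B, C, H, W)

# patchify (square branch): image -> tokens
tokens = latents.reshape(B, C, pH, p, pW, p).permute(0, 2, 4, 3, 5, 1)
tokens = tokens.reshape(B, pH * pW, p * p * C)

# unpatchify (inference branch): tokens -> image
t = tokens.reshape(B, pH, pW, p, p, C).permute(0, 5, 1, 3, 2, 4)
recovered = t.reshape(B, C, pH * p, pW * p)
print(torch.equal(recovered, latents))        # True
```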
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timesteps: torch.LongTensor = None,
+        encoder_hidden_states_t5: torch.Tensor = None,
+        encoder_hidden_states_llama3: torch.Tensor = None,
+        pooled_embeds: torch.Tensor = None,
+        img_ids: Optional[torch.Tensor] = None,
+        img_sizes: Optional[List[Tuple[int, int]]] = None,
+        hidden_states_masks: Optional[torch.Tensor] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+        **kwargs,
+    ):
+        encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+
+        if encoder_hidden_states is not None:
+            deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
+            deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
+            encoder_hidden_states_t5 = encoder_hidden_states[0]
+            encoder_hidden_states_llama3 = encoder_hidden_states[1]
+
+        if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
+            deprecation_message = (
+                "Passing `img_ids` and `img_sizes` with unpatchified `hidden_states` is deprecated and will be ignored."
+            )
+            deprecate("img_ids", "0.35.0", deprecation_message)
+
+        if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
+            raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
+        elif hidden_states_masks is not None and hidden_states.ndim != 3:
+            raise ValueError(
+                "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensor with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
+            )
+
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        # spatial forward
+        batch_size = hidden_states.shape[0]
+        hidden_states_type = hidden_states.dtype
+
+        # Patchify the input
+        if hidden_states_masks is None:
+            hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
+
+        # Embed the hidden states
+        hidden_states = self.x_embedder(hidden_states)
+
+        # 0. time
+        timesteps = self.t_embedder(timesteps, hidden_states_type)
+        p_embedder = self.p_embedder(pooled_embeds)
+        temb = timesteps + p_embedder
+
+        encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]
+
+        if self.caption_projection is not None:
+            new_encoder_hidden_states = []
+            for i, enc_hidden_state in enumerate(encoder_hidden_states):
+                enc_hidden_state = self.caption_projection[i](enc_hidden_state)
+                enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
+                new_encoder_hidden_states.append(enc_hidden_state)
+            encoder_hidden_states = new_encoder_hidden_states
+            encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
+            encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
+            encoder_hidden_states.append(encoder_hidden_states_t5)
+
+        txt_ids = torch.zeros(
+            batch_size,
+            encoder_hidden_states[-1].shape[1]
+            + encoder_hidden_states[-2].shape[1]
+            + encoder_hidden_states[0].shape[1],
+            3,
+            device=img_ids.device,
+            dtype=img_ids.dtype,
+        )
+        ids = torch.cat((img_ids, txt_ids), dim=1)
+        image_rotary_emb = self.pe_embedder(ids)
+
+        # 2. Blocks
+        block_id = 0
+        initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
+        initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
+        for bid, block in enumerate(self.double_stream_blocks):
+            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
+            cur_encoder_hidden_states = torch.cat(
+                [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
+            )
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    hidden_states_masks,
+                    cur_encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                )
+            else:
+                hidden_states, initial_encoder_hidden_states = block(
+                    hidden_states=hidden_states,
+                    hidden_states_masks=hidden_states_masks,
+                    encoder_hidden_states=cur_encoder_hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )
+            initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
+            block_id += 1
+
+        image_tokens_seq_len = hidden_states.shape[1]
+        hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
+        hidden_states_seq_len = hidden_states.shape[1]
+        if hidden_states_masks is not None:
+            encoder_attention_mask_ones = torch.ones(
+                (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
+                device=hidden_states_masks.device,
+                dtype=hidden_states_masks.dtype,
+            )
+            hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)
+
+        for bid, block in enumerate(self.single_stream_blocks):
+            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
+            hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    hidden_states_masks,
+                    None,
+                    temb,
+                    image_rotary_emb,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    hidden_states_masks=hidden_states_masks,
+                    encoder_hidden_states=None,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )
+            hidden_states = hidden_states[:, :hidden_states_seq_len]
+            block_id += 1
+
+        hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
+        output = self.final_layer(hidden_states, temb)
+        output = self.unpatchify(output, img_sizes, self.training)
+        if hidden_states_masks is not None:
+            hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
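
The two block loops in `forward` do careful sequence-length bookkeeping: double-stream blocks append per-block Llama tokens to a shared text context and re-truncate it after every call, while single-stream blocks append them to the fused image+text sequence and slice them off again. A self-contained sketch of that bookkeeping with toy tensors (the block calls themselves are elided):

```python
import torch

img_len, shared_len, per_block_len, dim = 8, 5, 3, 4
img = torch.randn(1, img_len, dim)                 # image tokens
shared = torch.randn(1, shared_len, dim)           # shared T5 + Llama context
per_block = [torch.randn(1, per_block_len, dim) for _ in range(2)]

# double-stream phase: image and text travel separately
for txt in per_block[:1]:
    ctx = torch.cat([shared, txt], dim=1)          # append per-block tokens
    # ... block(img, ctx) would run here ...
    shared = ctx[:, :shared_len]                   # re-truncate the context

# single-stream phase: everything fused into one sequence
fused = torch.cat([img, shared], dim=1)
fused_len = fused.shape[1]
for txt in per_block[1:]:
    x = torch.cat([fused, txt], dim=1)             # append per-block tokens
    # ... block(x) would run here ...
    fused = x[:, :fused_len]                       # slice them off again

img = fused[:, :img_len]                           # recover the image tokens
print(img.shape)                                   # torch.Size([1, 8, 4])
```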