diffusers 0.33.0__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +17 -12
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. diffusers-0.33.0.dist-info/RECORD +0 -608
  475. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  476. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
  477. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
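
As a quick orientation before the diff below, the headline additions in 0.34.0 are the Chroma, Cosmos, HiDream, VisualCloze, and Wan VACE model/pipeline families. The snippet that follows is an illustrative sketch for probing them in an installed environment: the class names are read off the new files listed above, and their availability as top-level exports is an assumption based on the `diffusers/__init__.py` change (+48 -1).

import diffusers

print(diffusers.__version__)  # expect "0.34.0"

# Names taken from the new files in the list above; top-level export is assumed.
for name in (
    "ChromaTransformer2DModel",
    "ChromaPipeline",
    "AutoencoderKLCosmos",
    "HiDreamImageTransformer2DModel",
    "WanVACEPipeline",
):
    print(name, hasattr(diffusers, name))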
diffusers/models/transformers/transformer_chroma.py (new file)
@@ -0,0 +1,742 @@
+ # Copyright 2025 Black Forest Labs, The HuggingFace Team and loadstone-rock . All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
+ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+ from ...utils.import_utils import is_torch_npu_available
+ from ...utils.torch_utils import maybe_allow_in_graph
+ from ..attention import FeedForward
+ from ..attention_processor import (
+     Attention,
+     AttentionProcessor,
+     FluxAttnProcessor2_0,
+     FluxAttnProcessor2_0_NPU,
+     FusedFluxAttnProcessor2_0,
+ )
+ from ..cache_utils import CacheMixin
+ from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
+ from ..modeling_outputs import Transformer2DModelOutput
+ from ..modeling_utils import ModelMixin
+ from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ class ChromaAdaLayerNormZeroPruned(nn.Module):
+     r"""
+     Norm layer adaptive layer norm zero (adaLN-Zero).
+
+     Parameters:
+         embedding_dim (`int`): The size of each embedding vector.
+         num_embeddings (`int`): The size of the embeddings dictionary.
+     """
+
+     def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
+         super().__init__()
+         if num_embeddings is not None:
+             self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+         else:
+             self.emb = None
+
+         if norm_type == "layer_norm":
+             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+         elif norm_type == "fp32_layer_norm":
+             self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+         else:
+             raise ValueError(
+                 f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+             )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         timestep: Optional[torch.Tensor] = None,
+         class_labels: Optional[torch.LongTensor] = None,
+         hidden_dtype: Optional[torch.dtype] = None,
+         emb: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         if self.emb is not None:
+             emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.flatten(1, 2).chunk(6, dim=1)
+         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+ class ChromaAdaLayerNormZeroSinglePruned(nn.Module):
+     r"""
+     Norm layer adaptive layer norm zero (adaLN-Zero).
+
+     Parameters:
+         embedding_dim (`int`): The size of each embedding vector.
+         num_embeddings (`int`): The size of the embeddings dictionary.
+     """
+
+     def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+         super().__init__()
+
+         if norm_type == "layer_norm":
+             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+         else:
+             raise ValueError(
+                 f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+             )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         emb: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         shift_msa, scale_msa, gate_msa = emb.flatten(1, 2).chunk(3, dim=1)
+         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+         return x, gate_msa
+
+
+ class ChromaAdaLayerNormContinuousPruned(nn.Module):
+     r"""
+     Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+     Args:
+         embedding_dim (`int`): Embedding dimension to use during projection.
+         conditioning_embedding_dim (`int`): Dimension of the input condition.
+         elementwise_affine (`bool`, defaults to `True`):
+             Boolean flag to denote if affine transformation should be applied.
+         eps (`float`, defaults to 1e-5): Epsilon factor.
+         bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
+         norm_type (`str`, defaults to `"layer_norm"`):
+             Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+     """
+
+     def __init__(
+         self,
+         embedding_dim: int,
+         conditioning_embedding_dim: int,
+         # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+         # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+         # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+         # However, this is how it was implemented in the original code, and it's rather likely you should
+         # set `elementwise_affine` to False.
+         elementwise_affine=True,
+         eps=1e-5,
+         bias=True,
+         norm_type="layer_norm",
+     ):
+         super().__init__()
+         if norm_type == "layer_norm":
+             self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+         elif norm_type == "rms_norm":
+             self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+         else:
+             raise ValueError(f"unknown norm_type {norm_type}")
+
+     def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+         # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+         shift, scale = torch.chunk(emb.flatten(1, 2).to(x.dtype), 2, dim=1)
+         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+         return x
+
+
+ class ChromaCombinedTimestepTextProjEmbeddings(nn.Module):
+     def __init__(self, num_channels: int, out_dim: int):
+         super().__init__()
+
+         self.time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+         self.guidance_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+
+         self.register_buffer(
+             "mod_proj",
+             get_timestep_embedding(
+                 torch.arange(out_dim) * 1000, 2 * num_channels, flip_sin_to_cos=True, downscale_freq_shift=0
+             ),
+             persistent=False,
+         )
+
+     def forward(self, timestep: torch.Tensor) -> torch.Tensor:
+         mod_index_length = self.mod_proj.shape[0]
+         batch_size = timestep.shape[0]
+
+         timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
+         guidance_proj = self.guidance_proj(torch.tensor([0] * batch_size)).to(
+             dtype=timestep.dtype, device=timestep.device
+         )
+
+         mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device).repeat(batch_size, 1, 1)
+         timestep_guidance = (
+             torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
+         )
+         input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
+         return input_vec.to(timestep.dtype)
+
+
+ class ChromaApproximator(nn.Module):
+     def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
+         super().__init__()
+         self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
+         self.layers = nn.ModuleList(
+             [PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
+         )
+         self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
+         self.out_proj = nn.Linear(hidden_dim, out_dim)
+
+     def forward(self, x):
+         x = self.in_proj(x)
+
+         for layer, norms in zip(self.layers, self.norms):
+             x = x + layer(norms(x))
+
+         return self.out_proj(x)
+
+
+ @maybe_allow_in_graph
+ class ChromaSingleTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         mlp_ratio: float = 4.0,
+     ):
+         super().__init__()
+         self.mlp_hidden_dim = int(dim * mlp_ratio)
+         self.norm = ChromaAdaLayerNormZeroSinglePruned(dim)
+         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+         self.act_mlp = nn.GELU(approximate="tanh")
+         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+         if is_torch_npu_available():
+             deprecation_message = (
+                 "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
+                 "should be set explicitly using the `set_attn_processor` method."
+             )
+             deprecate("npu_processor", "0.34.0", deprecation_message)
+             processor = FluxAttnProcessor2_0_NPU()
+         else:
+             processor = FluxAttnProcessor2_0()
+
+         self.attn = Attention(
+             query_dim=dim,
+             cross_attention_dim=None,
+             dim_head=attention_head_dim,
+             heads=num_attention_heads,
+             out_dim=dim,
+             bias=True,
+             processor=processor,
+             qk_norm="rms_norm",
+             eps=1e-6,
+             pre_only=True,
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         temb: torch.Tensor,
+         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> torch.Tensor:
+         residual = hidden_states
+         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+         joint_attention_kwargs = joint_attention_kwargs or {}
+
+         if attention_mask is not None:
+             attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
+         attn_output = self.attn(
+             hidden_states=norm_hidden_states,
+             image_rotary_emb=image_rotary_emb,
+             attention_mask=attention_mask,
+             **joint_attention_kwargs,
+         )
+
+         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+         gate = gate.unsqueeze(1)
+         hidden_states = gate * self.proj_out(hidden_states)
+         hidden_states = residual + hidden_states
+         if hidden_states.dtype == torch.float16:
+             hidden_states = hidden_states.clip(-65504, 65504)
+
+         return hidden_states
+
+
+ @maybe_allow_in_graph
+ class ChromaTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         qk_norm: str = "rms_norm",
+         eps: float = 1e-6,
+     ):
+         super().__init__()
+         self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
+         self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
+
+         self.attn = Attention(
+             query_dim=dim,
+             cross_attention_dim=None,
+             added_kv_proj_dim=dim,
+             dim_head=attention_head_dim,
+             heads=num_attention_heads,
+             out_dim=dim,
+             context_pre_only=False,
+             bias=True,
+             processor=FluxAttnProcessor2_0(),
+             qk_norm=qk_norm,
+             eps=eps,
+         )
+
+         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+         self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+         self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         temb: torch.Tensor,
+         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         temb_img, temb_txt = temb[:, :6], temb[:, 6:]
+         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb_img)
+
+         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+             encoder_hidden_states, emb=temb_txt
+         )
+         joint_attention_kwargs = joint_attention_kwargs or {}
+         if attention_mask is not None:
+             attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
+         # Attention.
+         attention_outputs = self.attn(
+             hidden_states=norm_hidden_states,
+             encoder_hidden_states=norm_encoder_hidden_states,
+             image_rotary_emb=image_rotary_emb,
+             attention_mask=attention_mask,
+             **joint_attention_kwargs,
+         )
+
+         if len(attention_outputs) == 2:
+             attn_output, context_attn_output = attention_outputs
+         elif len(attention_outputs) == 3:
+             attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+         # Process attention outputs for the `hidden_states`.
+         attn_output = gate_msa.unsqueeze(1) * attn_output
+         hidden_states = hidden_states + attn_output
+
+         norm_hidden_states = self.norm2(hidden_states)
+         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+         ff_output = self.ff(norm_hidden_states)
+         ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+         hidden_states = hidden_states + ff_output
+         if len(attention_outputs) == 3:
+             hidden_states = hidden_states + ip_attn_output
+
+         # Process attention outputs for the `encoder_hidden_states`.
+
+         context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+         encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+         norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+         context_ff_output = self.ff_context(norm_encoder_hidden_states)
+         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+         if encoder_hidden_states.dtype == torch.float16:
+             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+         return encoder_hidden_states, hidden_states
+
+
+ class ChromaTransformer2DModel(
+     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+ ):
+     """
+     The Transformer model introduced in Flux, modified for Chroma.
+
+     Reference: https://huggingface.co/lodestones/Chroma
+
+     Args:
+         patch_size (`int`, defaults to `1`):
+             Patch size to turn the input data into small patches.
+         in_channels (`int`, defaults to `64`):
+             The number of channels in the input.
+         out_channels (`int`, *optional*, defaults to `None`):
+             The number of channels in the output. If not specified, it defaults to `in_channels`.
+         num_layers (`int`, defaults to `19`):
+             The number of layers of dual stream DiT blocks to use.
+         num_single_layers (`int`, defaults to `38`):
+             The number of layers of single stream DiT blocks to use.
+         attention_head_dim (`int`, defaults to `128`):
+             The number of dimensions to use for each attention head.
+         num_attention_heads (`int`, defaults to `24`):
+             The number of attention heads to use.
+         joint_attention_dim (`int`, defaults to `4096`):
+             The number of dimensions to use for the joint attention (embedding/channel dimension of
+             `encoder_hidden_states`).
+         axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+             The dimensions to use for the rotary positional embeddings.
+     """
+
+     _supports_gradient_checkpointing = True
+     _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
+     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+
+     @register_to_config
+     def __init__(
+         self,
+         patch_size: int = 1,
+         in_channels: int = 64,
+         out_channels: Optional[int] = None,
+         num_layers: int = 19,
+         num_single_layers: int = 38,
+         attention_head_dim: int = 128,
+         num_attention_heads: int = 24,
+         joint_attention_dim: int = 4096,
+         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
+         approximator_num_channels: int = 64,
+         approximator_hidden_dim: int = 5120,
+         approximator_layers: int = 5,
+     ):
+         super().__init__()
+         self.out_channels = out_channels or in_channels
+         self.inner_dim = num_attention_heads * attention_head_dim
+
+         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+
+         self.time_text_embed = ChromaCombinedTimestepTextProjEmbeddings(
+             num_channels=approximator_num_channels // 4,
+             out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
+         )
+         self.distilled_guidance_layer = ChromaApproximator(
+             in_dim=approximator_num_channels,
+             out_dim=self.inner_dim,
+             hidden_dim=approximator_hidden_dim,
+             n_layers=approximator_layers,
+         )
+
+         self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+         self.x_embedder = nn.Linear(in_channels, self.inner_dim)
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 ChromaTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=num_attention_heads,
+                     attention_head_dim=attention_head_dim,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         self.single_transformer_blocks = nn.ModuleList(
+             [
+                 ChromaSingleTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=num_attention_heads,
+                     attention_head_dim=attention_head_dim,
+                 )
+                 for _ in range(num_single_layers)
+             ]
+         )
+
+         self.norm_out = ChromaAdaLayerNormContinuousPruned(
+             self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
+         )
+         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+         self.gradient_checkpointing = False
+
+     @property
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+     def attn_processors(self) -> Dict[str, AttentionProcessor]:
+         r"""
+         Returns:
+             `dict` of attention processors: A dictionary containing all attention processors used in the model with
+             indexed by its weight name.
+         """
+         # set recursively
+         processors = {}
+
+         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+             if hasattr(module, "get_processor"):
+                 processors[f"{name}.processor"] = module.get_processor()
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+             return processors
+
+         for name, module in self.named_children():
+             fn_recursive_add_processors(name, module, processors)
+
+         return processors
+
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+         r"""
+         Sets the attention processor to use to compute attention.
+
+         Parameters:
+             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                 for **all** `Attention` layers.
+
+                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                 processor. This is strongly recommended when setting trainable attention processors.
+
+         """
+         count = len(self.attn_processors.keys())
+
+         if isinstance(processor, dict) and len(processor) != count:
+             raise ValueError(
+                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+             )
+
+         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+             if hasattr(module, "set_processor"):
+                 if not isinstance(processor, dict):
+                     module.set_processor(processor)
+                 else:
+                     module.set_processor(processor.pop(f"{name}.processor"))
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+         for name, module in self.named_children():
+             fn_recursive_attn_processor(name, module, processor)
+
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
+     def fuse_qkv_projections(self):
+         """
+         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+         are fused. For cross-attention modules, key and value projection matrices are fused.
+
+         <Tip warning={true}>
+
+         This API is 🧪 experimental.
+
+         </Tip>
+         """
+         self.original_attn_processors = None
+
+         for _, attn_processor in self.attn_processors.items():
+             if "Added" in str(attn_processor.__class__.__name__):
+                 raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+         self.original_attn_processors = self.attn_processors
+
+         for module in self.modules():
+             if isinstance(module, Attention):
+                 module.fuse_projections(fuse=True)
+
+         self.set_attn_processor(FusedFluxAttnProcessor2_0())
+
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+     def unfuse_qkv_projections(self):
+         """Disables the fused QKV projection if enabled.
+
+         <Tip warning={true}>
+
+         This API is 🧪 experimental.
+
+         </Tip>
+
+         """
+         if self.original_attn_processors is not None:
+             self.set_attn_processor(self.original_attn_processors)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor = None,
+         timestep: torch.LongTensor = None,
+         img_ids: torch.Tensor = None,
+         txt_ids: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+         controlnet_block_samples=None,
+         controlnet_single_block_samples=None,
+         return_dict: bool = True,
+         controlnet_blocks_repeat: bool = False,
+     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+         """
+         The [`FluxTransformer2DModel`] forward method.
+
+         Args:
+             hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+                 Input `hidden_states`.
+             encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+             timestep ( `torch.LongTensor`):
+                 Used to indicate denoising step.
+             block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+                 A list of tensors that if specified are added to the residuals of transformer blocks.
+             joint_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                 tuple.
+
+         Returns:
+             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+             `tuple` where the first element is the sample tensor.
+         """
+         if joint_attention_kwargs is not None:
+             joint_attention_kwargs = joint_attention_kwargs.copy()
+             lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+         else:
+             lora_scale = 1.0
+
+         if USE_PEFT_BACKEND:
+             # weight the lora layers by setting `lora_scale` for each PEFT layer
+             scale_lora_layers(self, lora_scale)
+         else:
+             if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+                 logger.warning(
+                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                 )
+
+         hidden_states = self.x_embedder(hidden_states)
+
+         timestep = timestep.to(hidden_states.dtype) * 1000
+
+         input_vec = self.time_text_embed(timestep)
+         pooled_temb = self.distilled_guidance_layer(input_vec)
+
+         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+         if txt_ids.ndim == 3:
+             logger.warning(
+                 "Passing `txt_ids` 3d torch.Tensor is deprecated."
+                 "Please remove the batch dimension and pass it as a 2d torch Tensor"
+             )
+             txt_ids = txt_ids[0]
+         if img_ids.ndim == 3:
+             logger.warning(
+                 "Passing `img_ids` 3d torch.Tensor is deprecated."
+                 "Please remove the batch dimension and pass it as a 2d torch Tensor"
+             )
+             img_ids = img_ids[0]
+
+         ids = torch.cat((txt_ids, img_ids), dim=0)
+         image_rotary_emb = self.pos_embed(ids)
+
+         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+             ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+             joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+         for index_block, block in enumerate(self.transformer_blocks):
+             img_offset = 3 * len(self.single_transformer_blocks)
+             txt_offset = img_offset + 6 * len(self.transformer_blocks)
+             img_modulation = img_offset + 6 * index_block
+             text_modulation = txt_offset + 6 * index_block
+             temb = torch.cat(
+                 (
+                     pooled_temb[:, img_modulation : img_modulation + 6],
+                     pooled_temb[:, text_modulation : text_modulation + 6],
+                 ),
+                 dim=1,
+             )
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                     block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
+                 )
+
+             else:
+                 encoder_hidden_states, hidden_states = block(
+                     hidden_states=hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     temb=temb,
+                     image_rotary_emb=image_rotary_emb,
+                     attention_mask=attention_mask,
+                     joint_attention_kwargs=joint_attention_kwargs,
+                 )
+
+             # controlnet residual
+             if controlnet_block_samples is not None:
+                 interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+                 interval_control = int(np.ceil(interval_control))
+                 # For Xlabs ControlNet.
+                 if controlnet_blocks_repeat:
+                     hidden_states = (
+                         hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+                     )
+                 else:
+                     hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+         for index_block, block in enumerate(self.single_transformer_blocks):
+             start_idx = 3 * index_block
+             temb = pooled_temb[:, start_idx : start_idx + 3]
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+                 hidden_states = self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     temb,
+                     image_rotary_emb,
+                 )
+
+             else:
+                 hidden_states = block(
+                     hidden_states=hidden_states,
+                     temb=temb,
+                     image_rotary_emb=image_rotary_emb,
+                     attention_mask=attention_mask,
+                     joint_attention_kwargs=joint_attention_kwargs,
+                 )
+
+             # controlnet residual
+             if controlnet_single_block_samples is not None:
+                 interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+                 interval_control = int(np.ceil(interval_control))
+                 hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                     hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                     + controlnet_single_block_samples[index_block // interval_control]
+                 )
+
+         hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+         temb = pooled_temb[:, -2:]
+         hidden_states = self.norm_out(hidden_states, temb)
+         output = self.proj_out(hidden_states)
+
+         if USE_PEFT_BACKEND:
+             # remove `lora_scale` from each PEFT layer
+             unscale_lora_layers(self, lora_scale)
+
+         if not return_dict:
+             return (output,)
+
+         return Transformer2DModelOutput(sample=output)
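
To see the pruned-modulation design above in action: unlike Flux, the per-block modulation MLPs are pruned, and a single ChromaApproximator produces one pooled tensor that is sliced per block (six entries per stream for each dual-stream block, three for each single-stream block, plus two for the final norm), which is why `out_dim` is `3 * num_single_layers + 2 * 6 * num_layers + 2`. Below is a minimal smoke test for the model; the tiny config is an assumption chosen only so the forward pass runs quickly on CPU, not a real Chroma configuration (note that `axes_dims_rope` must sum to `attention_head_dim`, and `nn.RMSNorm` requires a recent PyTorch).

import torch

from diffusers.models.transformers.transformer_chroma import ChromaTransformer2DModel

# A deliberately tiny, illustrative config (not a real Chroma checkpoint configuration).
model = ChromaTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,  # inner_dim = 2 * 8 = 16
    joint_attention_dim=16,
    axes_dims_rope=(2, 3, 3),  # must sum to attention_head_dim
    approximator_num_channels=8,
    approximator_hidden_dim=32,
    approximator_layers=1,
)

batch_size, img_seq, txt_seq = 1, 4, 3
out = model(
    hidden_states=torch.randn(batch_size, img_seq, 4),
    encoder_hidden_states=torch.randn(batch_size, txt_seq, 16),
    timestep=torch.tensor([0.5]),
    img_ids=torch.zeros(img_seq, 3),
    txt_ids=torch.zeros(txt_seq, 3),
)
print(out.sample.shape)  # torch.Size([1, 4, 4])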