diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_chroma.py (new file)
@@ -0,0 +1,641 @@
+ # Copyright 2025 Black Forest Labs, The HuggingFace Team and lodestone-rock. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
+ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+ from ...utils.import_utils import is_torch_npu_available
+ from ...utils.torch_utils import maybe_allow_in_graph
+ from ..attention import AttentionMixin, FeedForward
+ from ..cache_utils import CacheMixin
+ from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
+ from ..modeling_outputs import Transformer2DModelOutput
+ from ..modeling_utils import ModelMixin
+ from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
+ from .transformer_flux import FluxAttention, FluxAttnProcessor
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ class ChromaAdaLayerNormZeroPruned(nn.Module):
+     r"""
+     Norm layer adaptive layer norm zero (adaLN-Zero).
+
+     Parameters:
+         embedding_dim (`int`): The size of each embedding vector.
+         num_embeddings (`int`): The size of the embeddings dictionary.
+     """
+
+     def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
+         super().__init__()
+         if num_embeddings is not None:
+             self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+         else:
+             self.emb = None
+
+         if norm_type == "layer_norm":
+             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+         elif norm_type == "fp32_layer_norm":
+             self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+         else:
+             raise ValueError(
+                 f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+             )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         timestep: Optional[torch.Tensor] = None,
+         class_labels: Optional[torch.LongTensor] = None,
+         hidden_dtype: Optional[torch.dtype] = None,
+         emb: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         if self.emb is not None:
+             emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.flatten(1, 2).chunk(6, dim=1)
+         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
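
Editor's note: unlike the stock AdaLayerNormZero, this "pruned" variant owns no linear projection of its own; the six modulation vectors arrive precomputed in `emb`, sliced per block from the `ChromaApproximator` output later in this file. A minimal shape sketch of the chunking, using made-up sizes:

    import torch

    batch, seq, dim = 2, 16, 3072          # hypothetical sizes
    x = torch.randn(batch, seq, dim)
    emb = torch.randn(batch, 6, dim)       # six modulation vectors for this block

    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.flatten(1, 2).chunk(6, dim=1)
    # each chunk has shape (batch, dim); [:, None] broadcasts it over the sequence axis
    out = x * (1 + scale_msa[:, None]) + shift_msa[:, None]   # (batch, seq, dim)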
+
+
+ class ChromaAdaLayerNormZeroSinglePruned(nn.Module):
+     r"""
+     Norm layer adaptive layer norm zero (adaLN-Zero) for single transformer blocks.
+
+     Parameters:
+         embedding_dim (`int`): The size of each embedding vector.
+     """
+
+     def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+         super().__init__()
+
+         if norm_type == "layer_norm":
+             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+         else:
+             raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         emb: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         shift_msa, scale_msa, gate_msa = emb.flatten(1, 2).chunk(3, dim=1)
+         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+         return x, gate_msa
+
+
+ class ChromaAdaLayerNormContinuousPruned(nn.Module):
+     r"""
+     Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+     Args:
+         embedding_dim (`int`): Embedding dimension to use during projection.
+         conditioning_embedding_dim (`int`): Dimension of the input condition.
+         elementwise_affine (`bool`, defaults to `True`):
+             Boolean flag to denote if affine transformation should be applied.
+         eps (`float`, defaults to 1e-5): Epsilon factor.
+         bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
+         norm_type (`str`, defaults to `"layer_norm"`):
+             Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+     """
+
+     def __init__(
+         self,
+         embedding_dim: int,
+         conditioning_embedding_dim: int,
+         # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+         # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+         # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+         # However, this is how it was implemented in the original code, and it's rather likely you should
+         # set `elementwise_affine` to False.
+         elementwise_affine=True,
+         eps=1e-5,
+         bias=True,
+         norm_type="layer_norm",
+     ):
+         super().__init__()
+         if norm_type == "layer_norm":
+             self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+         elif norm_type == "rms_norm":
+             self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+         else:
+             raise ValueError(f"unknown norm_type {norm_type}")
+
+     def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+         # convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT)
+         shift, scale = torch.chunk(emb.flatten(1, 2).to(x.dtype), 2, dim=1)
+         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+         return x
+
+
+ class ChromaCombinedTimestepTextProjEmbeddings(nn.Module):
+     def __init__(self, num_channels: int, out_dim: int):
+         super().__init__()
+
+         self.time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+         self.guidance_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+
+         self.register_buffer(
+             "mod_proj",
+             get_timestep_embedding(
+                 torch.arange(out_dim) * 1000, 2 * num_channels, flip_sin_to_cos=True, downscale_freq_shift=0
+             ),
+             persistent=False,
+         )
+
+     def forward(self, timestep: torch.Tensor) -> torch.Tensor:
+         mod_index_length = self.mod_proj.shape[0]
+         batch_size = timestep.shape[0]
+
+         timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
+         guidance_proj = self.guidance_proj(torch.tensor([0] * batch_size)).to(
+             dtype=timestep.dtype, device=timestep.device
+         )
+
+         mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device).repeat(batch_size, 1, 1)
+         timestep_guidance = (
+             torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
+         )
+         input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
+         return input_vec.to(timestep.dtype)
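
Editor's note: the returned conditioning has shape (batch, out_dim, 4 * num_channels): the sinusoidal timestep projection and an all-zero guidance projection (2 * num_channels together) are tiled along the modulation-index axis and concatenated with the precomputed per-index embedding (another 2 * num_channels). A hedged shape check with the sizes the model below passes in (num_channels=16, out_dim=344):

    import torch

    emb = ChromaCombinedTimestepTextProjEmbeddings(num_channels=16, out_dim=344)
    input_vec = emb(torch.rand(2))     # a batch of two timesteps
    print(input_vec.shape)             # torch.Size([2, 344, 64]) -- 64 matches the approximator's in_dim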
+
+
+ class ChromaApproximator(nn.Module):
+     def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
+         super().__init__()
+         self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
+         self.layers = nn.ModuleList(
+             [PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
+         )
+         self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
+         self.out_proj = nn.Linear(hidden_dim, out_dim)
+
+     def forward(self, x):
+         x = self.in_proj(x)
+
+         for layer, norms in zip(self.layers, self.norms):
+             x = x + layer(norms(x))
+
+         return self.out_proj(x)
+
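
Editor's note: where Flux learns per-block modulation projections directly, Chroma replaces them with this single pre-norm residual MLP that maps each modulation-index vector to an inner_dim-sized modulation vector. A sketch with the default sizes from the config below (in_dim=64, hidden_dim=5120, out_dim = inner_dim = 24 * 128 = 3072):

    import torch

    approximator = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120)
    pooled_temb = approximator(torch.randn(2, 344, 64))
    print(pooled_temb.shape)   # torch.Size([2, 344, 3072]); rows are sliced per block as `temb`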
+
+ @maybe_allow_in_graph
+ class ChromaSingleTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         mlp_ratio: float = 4.0,
+     ):
+         super().__init__()
+         self.mlp_hidden_dim = int(dim * mlp_ratio)
+         self.norm = ChromaAdaLayerNormZeroSinglePruned(dim)
+         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+         self.act_mlp = nn.GELU(approximate="tanh")
+         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+         if is_torch_npu_available():
+             from ..attention_processor import FluxAttnProcessor2_0_NPU
+
+             deprecation_message = (
+                 "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
+                 "should be set explicitly using the `set_attn_processor` method."
+             )
+             deprecate("npu_processor", "0.34.0", deprecation_message)
+             processor = FluxAttnProcessor2_0_NPU()
+         else:
+             processor = FluxAttnProcessor()
+
+         self.attn = FluxAttention(
+             query_dim=dim,
+             dim_head=attention_head_dim,
+             heads=num_attention_heads,
+             out_dim=dim,
+             bias=True,
+             processor=processor,
+             eps=1e-6,
+             pre_only=True,
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         temb: torch.Tensor,
+         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> torch.Tensor:
+         residual = hidden_states
+         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+         joint_attention_kwargs = joint_attention_kwargs or {}
+
+         if attention_mask is not None:
+             attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
+         attn_output = self.attn(
+             hidden_states=norm_hidden_states,
+             image_rotary_emb=image_rotary_emb,
+             attention_mask=attention_mask,
+             **joint_attention_kwargs,
+         )
+
+         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+         gate = gate.unsqueeze(1)
+         hidden_states = gate * self.proj_out(hidden_states)
+         hidden_states = residual + hidden_states
+         if hidden_states.dtype == torch.float16:
+             hidden_states = hidden_states.clip(-65504, 65504)
+
+         return hidden_states
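
Editor's note: the `attention_mask` handling above expands a per-token padding mask of shape (batch, seq_len) into a pairwise (batch, 1, seq_len, seq_len) mask via an outer product, so attention only flows between pairs of unmasked tokens. A small illustration:

    import torch

    mask = torch.tensor([[1, 1, 0]])   # one padded token at the end
    pairwise = mask[:, None, None, :] * mask[:, None, :, None]
    print(pairwise[0, 0])
    # tensor([[1, 1, 0],
    #         [1, 1, 0],
    #         [0, 0, 0]])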
+
+ @maybe_allow_in_graph
+ class ChromaTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         qk_norm: str = "rms_norm",
+         eps: float = 1e-6,
+     ):
+         super().__init__()
+         self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
+         self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
+
+         self.attn = FluxAttention(
+             query_dim=dim,
+             added_kv_proj_dim=dim,
+             dim_head=attention_head_dim,
+             heads=num_attention_heads,
+             out_dim=dim,
+             context_pre_only=False,
+             bias=True,
+             processor=FluxAttnProcessor(),
+             eps=eps,
+         )
+
+         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+         self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+         self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         temb: torch.Tensor,
+         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         temb_img, temb_txt = temb[:, :6], temb[:, 6:]
+         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb_img)
+
+         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+             encoder_hidden_states, emb=temb_txt
+         )
+         joint_attention_kwargs = joint_attention_kwargs or {}
+         if attention_mask is not None:
+             attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
+         # Attention.
+         attention_outputs = self.attn(
+             hidden_states=norm_hidden_states,
+             encoder_hidden_states=norm_encoder_hidden_states,
+             image_rotary_emb=image_rotary_emb,
+             attention_mask=attention_mask,
+             **joint_attention_kwargs,
+         )
+
+         if len(attention_outputs) == 2:
+             attn_output, context_attn_output = attention_outputs
+         elif len(attention_outputs) == 3:
+             attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+         # Process attention outputs for the `hidden_states`.
+         attn_output = gate_msa.unsqueeze(1) * attn_output
+         hidden_states = hidden_states + attn_output
+
+         norm_hidden_states = self.norm2(hidden_states)
+         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+         ff_output = self.ff(norm_hidden_states)
+         ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+         hidden_states = hidden_states + ff_output
+         if len(attention_outputs) == 3:
+             hidden_states = hidden_states + ip_attn_output
+
+         # Process attention outputs for the `encoder_hidden_states`.
+
+         context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+         encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+         norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+         context_ff_output = self.ff_context(norm_encoder_hidden_states)
+         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+         if encoder_hidden_states.dtype == torch.float16:
+             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+         return encoder_hidden_states, hidden_states
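
Editor's note: each dual-stream block consumes a `temb` of shape (batch, 12, dim): the first six rows modulate the image stream, the last six the text stream (see the slicing loop in `ChromaTransformer2DModel.forward` below). A hedged smoke test with toy sizes, assuming the default `FluxAttnProcessor` accepts `image_rotary_emb=None`:

    import torch

    block = ChromaTransformerBlock(dim=128, num_attention_heads=4, attention_head_dim=32)  # toy sizes
    img = torch.randn(1, 16, 128)     # image tokens
    txt = torch.randn(1, 8, 128)      # text tokens
    temb = torch.randn(1, 12, 128)    # 6 image + 6 text modulation vectors
    txt_out, img_out = block(img, txt, temb)
    print(txt_out.shape, img_out.shape)   # torch.Size([1, 8, 128]) torch.Size([1, 16, 128])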
+
+
+ class ChromaTransformer2DModel(
+     ModelMixin,
+     ConfigMixin,
+     PeftAdapterMixin,
+     FromOriginalModelMixin,
+     FluxTransformer2DLoadersMixin,
+     CacheMixin,
+     AttentionMixin,
+ ):
+     """
+     The Transformer model introduced in Flux, modified for Chroma.
+
+     Reference: https://huggingface.co/lodestones/Chroma
+
+     Args:
+         patch_size (`int`, defaults to `1`):
+             Patch size to turn the input data into small patches.
+         in_channels (`int`, defaults to `64`):
+             The number of channels in the input.
+         out_channels (`int`, *optional*, defaults to `None`):
+             The number of channels in the output. If not specified, it defaults to `in_channels`.
+         num_layers (`int`, defaults to `19`):
+             The number of layers of dual stream DiT blocks to use.
+         num_single_layers (`int`, defaults to `38`):
+             The number of layers of single stream DiT blocks to use.
+         attention_head_dim (`int`, defaults to `128`):
+             The number of dimensions to use for each attention head.
+         num_attention_heads (`int`, defaults to `24`):
+             The number of attention heads to use.
+         joint_attention_dim (`int`, defaults to `4096`):
+             The number of dimensions to use for the joint attention (embedding/channel dimension of
+             `encoder_hidden_states`).
+         axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+             The dimensions to use for the rotary positional embeddings.
+     """
+
+     _supports_gradient_checkpointing = True
+     _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
+     _repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
+     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+
+     @register_to_config
+     def __init__(
+         self,
+         patch_size: int = 1,
+         in_channels: int = 64,
+         out_channels: Optional[int] = None,
+         num_layers: int = 19,
+         num_single_layers: int = 38,
+         attention_head_dim: int = 128,
+         num_attention_heads: int = 24,
+         joint_attention_dim: int = 4096,
+         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
+         approximator_num_channels: int = 64,
+         approximator_hidden_dim: int = 5120,
+         approximator_layers: int = 5,
+     ):
+         super().__init__()
+         self.out_channels = out_channels or in_channels
+         self.inner_dim = num_attention_heads * attention_head_dim
+
+         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+
+         self.time_text_embed = ChromaCombinedTimestepTextProjEmbeddings(
+             num_channels=approximator_num_channels // 4,
+             out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
+         )
+         self.distilled_guidance_layer = ChromaApproximator(
+             in_dim=approximator_num_channels,
+             out_dim=self.inner_dim,
+             hidden_dim=approximator_hidden_dim,
+             n_layers=approximator_layers,
+         )
+
+         self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+         self.x_embedder = nn.Linear(in_channels, self.inner_dim)
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 ChromaTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=num_attention_heads,
+                     attention_head_dim=attention_head_dim,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         self.single_transformer_blocks = nn.ModuleList(
+             [
+                 ChromaSingleTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=num_attention_heads,
+                     attention_head_dim=attention_head_dim,
+                 )
+                 for _ in range(num_single_layers)
+             ]
+         )
+
+         self.norm_out = ChromaAdaLayerNormContinuousPruned(
+             self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
+         )
+         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+         self.gradient_checkpointing = False
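
Editor's note: with the defaults above, the embedder's out_dim works out to 3 * 38 (single-stream blocks) + 2 * 6 * 19 (image and text streams of each dual-stream block) + 2 (final norm) = 344 modulation rows. A hedged sanity check of that bookkeeping using a deliberately tiny config (toy values, not the release defaults):

    from diffusers import ChromaTransformer2DModel

    model = ChromaTransformer2DModel(
        num_layers=2,
        num_single_layers=4,
        attention_head_dim=8,
        num_attention_heads=4,
        joint_attention_dim=32,
        axes_dims_rope=(2, 3, 3),
        approximator_hidden_dim=64,
    )
    # out_dim = 3 * 4 + 2 * 6 * 2 + 2 = 38 modulation rows
    assert model.time_text_embed.mod_proj.shape[0] == 38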
475
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ attention_mask: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ return_dict: bool = True,
+ controlnet_blocks_repeat: bool = False,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ The [`ChromaTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ timestep (`torch.LongTensor`):
+ Used to indicate the denoising step.
+ controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
+ A list of tensors that, if specified, are added to the residuals of transformer blocks.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
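A rough usage sketch of this signature (the shapes, the float `timestep`, and the `model` instance from the earlier sketch are assumptions inferred from the argument docs above, not a tested snippet from the diff):

```python
# Hypothetical smoke test; shapes follow the docstring, values are arbitrary.
import torch

batch, img_seq, txt_seq = 1, 1024, 512
hidden_states = torch.randn(batch, img_seq, 64)            # in_channels = 64
encoder_hidden_states = torch.randn(batch, txt_seq, 4096)  # joint_attention_dim = 4096
timestep = torch.tensor([1.0])       # scaled by 1000 inside forward
img_ids = torch.zeros(img_seq, 3)    # 2d; a 3d tensor triggers the deprecation warning below
txt_ids = torch.zeros(txt_seq, 3)

out = model(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
).sample  # (batch, img_seq, patch_size**2 * out_channels)
```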
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+
+ input_vec = self.time_text_embed(timestep)
+ pooled_temb = self.distilled_guidance_layer(input_vec)
+
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` as a 3d torch.Tensor is deprecated. "
+ "Please remove the batch dimension and pass it as a 2d torch.Tensor."
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` as a 3d torch.Tensor is deprecated. "
+ "Please remove the batch dimension and pass it as a 2d torch.Tensor."
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ img_offset = 3 * len(self.single_transformer_blocks)
+ txt_offset = img_offset + 6 * len(self.transformer_blocks)
+ img_modulation = img_offset + 6 * index_block
+ text_modulation = txt_offset + 6 * index_block
+ temb = torch.cat(
+ (
+ pooled_temb[:, img_modulation : img_modulation + 6],
+ pooled_temb[:, text_modulation : text_modulation + 6],
+ ),
+ dim=1,
+ )
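The slicing above assumes `pooled_temb` is laid out along dim 1 as `[3 × num_single_layers single-block params | 6 × num_layers image params | 6 × num_layers text params | 2 final-norm params]`, matching the `out_dim` budget computed in `__init__`. A hypothetical helper (illustration only, not part of the diff) that reproduces the same slices:

```python
# Hypothetical helper mirroring the dual-stream slicing logic above.
def dual_block_slices(index_block: int, num_layers: int = 19, num_single_layers: int = 38):
    img_offset = 3 * num_single_layers        # single-stream blocks consume 3 params each
    txt_offset = img_offset + 6 * num_layers  # image-stream params for all dual blocks
    img_mod = img_offset + 6 * index_block
    txt_mod = txt_offset + 6 * index_block
    return slice(img_mod, img_mod + 6), slice(txt_mod, txt_mod + 6)

img_sl, txt_sl = dual_block_slices(0)
assert (img_sl.start, txt_sl.start) == (114, 228)
```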
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ # For Xlabs ControlNet.
+ if controlnet_blocks_repeat:
+ hidden_states = (
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+ )
+ else:
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
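The `interval_control` mapping above (repeated for the single-stream blocks further down) spreads `len(controlnet_block_samples)` residuals across `len(self.transformer_blocks)` blocks via `ceil`. For example, with the 19 dual blocks of the default config and a hypothetical 4 ControlNet samples:

```python
# Illustrative block-to-sample mapping; the 4-sample count is hypothetical.
import numpy as np

num_blocks, num_samples = 19, 4
interval = int(np.ceil(num_blocks / num_samples))     # ceil(19 / 4) = 5
mapping = [i // interval for i in range(num_blocks)]  # [0]*5 + [1]*5 + [2]*5 + [3]*4
```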
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ start_idx = 3 * index_block
+ temb = pooled_temb[:, start_idx : start_idx + 3]
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ temb,
+ image_rotary_emb,
+ )
+
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ + controlnet_single_block_samples[index_block // interval_control]
+ )
+
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+ temb = pooled_temb[:, -2:]
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
@@ -1,4 +1,4 @@
- # Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -19,18 +19,13 @@ import torch
  import torch.nn as nn

  from ...configuration_utils import ConfigMixin, register_to_config
- from ...models.attention import FeedForward
- from ...models.attention_processor import (
- Attention,
- AttentionProcessor,
- CogVideoXAttnProcessor2_0,
- )
- from ...models.modeling_utils import ModelMixin
- from ...models.normalization import AdaLayerNormContinuous
  from ...utils import logging
+ from ..attention import FeedForward
+ from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0
  from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
  from ..modeling_outputs import Transformer2DModelOutput
- from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
+ from ..modeling_utils import ModelMixin
+ from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage


  logger = logging.get_logger(__name__) # pylint: disable=invalid-name