diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389) hide show
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +595 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
  384. diffusers-0.33.1.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,469 @@
1
+ # Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...utils import logging
24
+ from ..attention_processor import Attention
25
+ from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
26
+ from ..modeling_outputs import Transformer2DModelOutput
27
+ from ..modeling_utils import ModelMixin
28
+ from ..normalization import AdaLayerNorm, RMSNorm
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
class OmniGenFeedForward(nn.Module):
    """Gated feed-forward block.

    A single bias-free linear layer produces a fused tensor that is split in
    half along the last dimension into a gate and an "up" part. The up part is
    multiplied by ``SiLU(gate)`` and projected back to ``hidden_size``.

    Args:
        hidden_size: Size of the last dimension of the input and the output.
        intermediate_size: Width of each half of the fused gate/up projection.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()

        # The fused projection is twice as wide because it carries both halves.
        fused_width = 2 * intermediate_size
        self.gate_up_proj = nn.Linear(hidden_size, fused_width, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.activation_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection and return a tensor with last dim ``hidden_size``."""
        # First half of the fused output is the gate, second half is the value path.
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        gated = up * self.activation_fn(gate)
        return self.down_proj(gated)
47
+
48
+
49
class OmniGenPatchEmbed(nn.Module):
    """Turn image latents into patch tokens and add positional embeddings.

    Two strided convolutions with identical hyper-parameters are kept, and
    ``is_input_image`` selects which one is applied. A square positional
    embedding grid of side ``pos_embed_max_size`` is built once with
    ``get_2d_sincos_pos_embed`` and stored as a persistent buffer; at call time
    a centred window matching the patch grid of the latent is cut out of it.

    Args:
        patch_size: Kernel size and stride of the patchifying convolutions.
        in_channels: Number of channels of the incoming latents.
        embed_dim: Number of output channels of the convolutions.
        bias: Whether the convolutions use a bias term.
        interpolation_scale: Forwarded to ``get_2d_sincos_pos_embed``.
        pos_embed_max_size: Side length of the stored positional grid.
        base_size: Forwarded to ``get_2d_sincos_pos_embed``.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        embed_dim: int = 768,
        bias: bool = True,
        interpolation_scale: float = 1,
        pos_embed_max_size: int = 192,
        base_size: int = 64,
    ):
        super().__init__()

        # Both projections share the same convolution settings.
        conv_kwargs = {"kernel_size": (patch_size, patch_size), "stride": patch_size, "bias": bias}
        self.output_image_proj = nn.Conv2d(in_channels, embed_dim, **conv_kwargs)
        self.input_image_proj = nn.Conv2d(in_channels, embed_dim, **conv_kwargs)

        self.patch_size = patch_size
        self.interpolation_scale = interpolation_scale
        self.pos_embed_max_size = pos_embed_max_size

        full_grid = get_2d_sincos_pos_embed(
            embed_dim,
            self.pos_embed_max_size,
            base_size=base_size,
            interpolation_scale=self.interpolation_scale,
            output_type="pt",
        )
        # Stored with a leading singleton dimension and saved with the state dict.
        self.register_buffer("pos_embed", full_grid.float().unsqueeze(0), persistent=True)

    def _cropped_pos_embed(self, height, width):
        """Return the centred window of the positional grid for a ``height`` x ``width`` latent.

        ``height`` and ``width`` are given in latent pixels and are divided by
        ``patch_size`` before cropping. The result has shape
        ``(1, rows * cols, channels)``.

        Raises:
            ValueError: If ``pos_embed_max_size`` is ``None`` or either patch-grid
                dimension exceeds it.
        """
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        # Convert latent pixel sizes into patch-grid sizes.
        height = height // self.patch_size
        width = width // self.patch_size
        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        side = self.pos_embed_max_size
        row_start = (side - height) // 2
        col_start = (side - width) // 2
        row_stop = row_start + height
        col_stop = col_start + width

        # View the flat buffer as a 2D grid, cut the window, then flatten it again.
        grid = self.pos_embed.reshape(1, side, side, -1)
        window = grid[:, row_start:row_stop, col_start:col_stop, :]
        return window.reshape(1, -1, window.shape[-1])

    def _patch_embeddings(self, hidden_states: torch.Tensor, is_input_image: bool) -> torch.Tensor:
        """Project latents with the selected convolution and flatten the spatial dims into tokens."""
        projection = self.input_image_proj if is_input_image else self.output_image_proj
        tokens = projection(hidden_states)
        # (B, C, H, W) -> (B, H*W, C)
        return tokens.flatten(2).transpose(1, 2)

    def forward(
        self, hidden_states: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None
    ) -> torch.Tensor:
        """Embed one latent tensor, or each latent of a list.

        Args:
            hidden_states: A latent tensor, or a list of latent tensors.
            is_input_image: Selects ``input_image_proj`` when true, otherwise
                ``output_image_proj``.
            padding_latent: Only used for list input. One entry per latent;
                a non-``None`` entry is moved to the latent's device and
                concatenated after its tokens along the second-to-last dim.

        Returns:
            A tensor for tensor input, or a list of tensors for list input.
        """
        if not isinstance(hidden_states, list):
            height, width = hidden_states.shape[-2:]
            pos_embed = self._cropped_pos_embed(height, width)
            return self._patch_embeddings(hidden_states, is_input_image) + pos_embed

        paddings = [None] * len(hidden_states) if padding_latent is None else padding_latent

        patched_latents = []
        for latent, extra in zip(hidden_states, paddings):
            height, width = latent.shape[-2:]
            tokens = self._patch_embeddings(latent, is_input_image)
            tokens = tokens + self._cropped_pos_embed(height, width)
            if extra is not None:
                tokens = torch.cat([tokens, extra.to(tokens.device)], dim=-2)
            patched_latents.append(tokens)
        return patched_latents
135
+
136
+
137
class OmniGenSuScaledRotaryEmbedding(nn.Module):
    """
    Rotary position embedding whose inverse frequencies are divided by per-dimension scaling factors.

    `long_factor` is used when the largest position id plus one exceeds `original_max_position_embeddings`,
    otherwise `short_factor` is used.

    Args:
        dim (`int`):
            Rotary dimension; `dim / 2` frequencies are generated.
        max_position_embeddings (`int`, defaults to `131072`):
            Used with `original_max_position_embeddings` to derive the magnitude scaling applied to cos/sin.
        original_max_position_embeddings (`int`, defaults to `4096`):
            Sequence-length threshold that selects between `short_factor` and `long_factor`.
        base (`int`, defaults to `10000`):
            Base of the geometric frequency progression.
        rope_scaling (`Dict`):
            Must contain `short_factor` and `long_factor`, each convertible to a tensor that broadcasts against the
            `dim / 2` frequencies.

    Raises:
        ValueError: If `rope_scaling` is `None`.
    """

    def __init__(
        self, dim, max_position_embeddings=131072, original_max_position_embeddings=4096, base=10000, rope_scaling=None
    ):
        super().__init__()

        if rope_scaling is None:
            # Fail early with an actionable message instead of a "NoneType is not subscriptable" TypeError.
            raise ValueError("`rope_scaling` must be a dict containing `short_factor` and `long_factor`.")

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

        self.short_factor = rope_scaling["short_factor"]
        self.long_factor = rope_scaling["long_factor"]
        self.original_max_position_embeddings = original_max_position_embeddings

    def forward(self, hidden_states, position_ids):
        """
        Compute the rotary cos/sin tables for `position_ids`.

        Args:
            hidden_states (`torch.Tensor`):
                Only its device is used.
            position_ids (`torch.Tensor`):
                Position ids indexed as `(batch, sequence)`.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: `(cos, sin)` computed from the FIRST batch entry only, each of
            shape `(sequence, dim)`.
        """
        device = hidden_states.device

        seq_len = torch.max(position_ids) + 1
        if seq_len > self.original_max_position_embeddings:
            factors = self.long_factor
        else:
            factors = self.short_factor
        ext_factors = torch.tensor(factors, dtype=torch.float32, device=device)

        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim
        # Keep the scaled frequencies local: re-assigning the registered `inv_freq` buffer on every call would
        # make the forward pass mutate module state.
        inv_freq = 1.0 / (ext_factors * self.base**exponents)

        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the frequencies to cover the full rotary dimension and keep the first batch entry.
            emb = torch.cat((freqs, freqs), dim=-1)[0]

            scale = self.max_position_embeddings / self.original_max_position_embeddings
            if scale <= 1.0:
                scaling_factor = 1.0
            else:
                scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))

            cos = emb.cos() * scaling_factor
            sin = emb.sin() * scaling_factor
        return cos, sin
186
+
187
+
188
class OmniGenAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the OmniGen model.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "OmniGenAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run attention using the projections stored on `attn`.

        Args:
            attn:
                Attention module providing `to_q`, `to_k`, `to_v`, `to_out`, `heads` and `out_dim`.
            hidden_states (`torch.Tensor`):
                Query source, unpacked as `(batch, query_length, channels)`.
            encoder_hidden_states (`torch.Tensor`):
                Key/value source.
            attention_mask (`torch.Tensor`, *optional*):
                Passed unchanged to `F.scaled_dot_product_attention` as `attn_mask`.
            image_rotary_emb (*optional*):
                Rotary embedding applied to query and key when provided.

        Returns:
            `torch.Tensor`: Attention output of shape `(batch, query_length, attn.out_dim)` after `attn.to_out[0]`.

        Raises:
            ValueError: If the number of query heads is not a multiple of the number of key/value heads.
        """
        batch_size, q_len, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = query.shape[-1] // attn.heads
        # The key projection may be narrower than the query projection (fewer key/value heads).
        kv_heads = key.shape[-1] // head_dim

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from ..embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb, use_real_unbind_dim=-2)
            key = apply_rotary_emb(key, image_rotary_emb, use_real_unbind_dim=-2)

        if kv_heads != attn.heads:
            if attn.heads % kv_heads != 0:
                raise ValueError(
                    f"Number of attention heads ({attn.heads}) must be a multiple of the number of key/value heads"
                    f" ({kv_heads})."
                )
            # Scaled dot-product attention needs matching head counts, so each key/value head is repeated for the
            # group of query heads that shares it.
            group_size = attn.heads // kv_heads
            key = key.repeat_interleave(group_size, dim=1)
            value = value.repeat_interleave(group_size, dim=1)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).type_as(query)
        hidden_states = hidden_states.reshape(batch_size, q_len, attn.out_dim)
        hidden_states = attn.to_out[0](hidden_states)
        return hidden_states
236
+
237
+
238
class OmniGenBlock(nn.Module):
    """
    One transformer layer: RMS-normalised self-attention followed by an RMS-normalised feed-forward network, each
    wrapped in a residual connection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int,
        rms_norm_eps: float,
    ) -> None:
        super().__init__()

        head_dim = hidden_size // num_attention_heads

        # Attention sub-layer (pre-norm); the query and key/value inputs both have width `hidden_size`.
        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.self_attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=hidden_size,
            dim_head=head_dim,
            heads=num_attention_heads,
            kv_heads=num_key_value_heads,
            bias=False,
            out_dim=hidden_size,
            out_bias=False,
            processor=OmniGenAttnProcessor2_0(),
        )

        # Feed-forward sub-layer (pre-norm).
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.mlp = OmniGenFeedForward(hidden_size, intermediate_size)

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply the attention and feed-forward sub-layers.

        Args:
            hidden_states (`torch.Tensor`): Input token sequence.
            attention_mask (`torch.Tensor`): Mask forwarded to the attention module.
            image_rotary_emb (`torch.Tensor`): Rotary embedding forwarded to the attention module.

        Returns:
            `torch.Tensor`: The updated token sequence.
        """
        # 1. Attention: the normalised sequence serves as both query and key/value source.
        normed = self.input_layernorm(hidden_states)
        attended = self.self_attn(
            hidden_states=normed,
            encoder_hidden_states=normed,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attended

        # 2. Feed Forward
        normed = self.post_attention_layernorm(hidden_states)
        return hidden_states + self.mlp(normed)
282
+
283
+
284
class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
    """
    The Transformer model introduced in OmniGen (https://arxiv.org/pdf/2409.11340).

    Parameters:
        in_channels (`int`, defaults to `4`):
            The number of channels in the input. The output has the same number of channels.
        patch_size (`int`, defaults to `2`):
            The size of the spatial patches to use in the patch embedding layer.
        hidden_size (`int`, defaults to `3072`):
            The dimensionality of the hidden layers in the model.
        rms_norm_eps (`float`, defaults to `1e-5`):
            Eps for RMSNorm layer.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        num_key_value_heads (`int`, defaults to `32`):
            The number of heads to use for keys and values in multi-head attention.
        intermediate_size (`int`, defaults to `8192`):
            Dimension of the hidden layer in FeedForward layers.
        num_layers (`int`, defaults to `32`):
            The number of layers of transformer blocks to use.
        pad_token_id (`int`, defaults to `32000`):
            The id of the padding token.
        vocab_size (`int`, defaults to `32064`):
            The size of the token embedding vocabulary.
        max_position_embeddings (`int`, defaults to `131072`):
            Passed to the rotary embedding, where its ratio to `original_max_position_embeddings` determines the
            scaling applied to the cos/sin tables.
        original_max_position_embeddings (`int`, defaults to `4096`):
            Passed to the rotary embedding, where it is the sequence-length threshold that selects between
            `short_factor` and `long_factor`.
        rope_base (`int`, defaults to `10000`):
            The default theta value to use when creating RoPE.
        rope_scaling (`Dict`, optional):
            The scaling factors for the RoPE. Must contain `short_factor` and `long_factor`.
        pos_embed_max_size (`int`, defaults to `192`):
            The maximum size of the positional embeddings.
        time_step_dim (`int`, defaults to `256`):
            Output dimension of timestep embeddings.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin and cos in the positional embeddings when preparing timestep embeddings.
        downscale_freq_shift (`int`, defaults to `0`):
            The frequency shift to use when downscaling the timestep embeddings.
        timestep_activation_fn (`str`, defaults to `silu`):
            The activation function to use for the timestep embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["OmniGenBlock"]
    _skip_layerwise_casting_patterns = ["patch_embedding", "embed_tokens", "norm"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        patch_size: int = 2,
        hidden_size: int = 3072,
        rms_norm_eps: float = 1e-5,
        num_attention_heads: int = 32,
        num_key_value_heads: int = 32,
        intermediate_size: int = 8192,
        num_layers: int = 32,
        pad_token_id: int = 32000,
        vocab_size: int = 32064,
        max_position_embeddings: int = 131072,
        original_max_position_embeddings: int = 4096,
        rope_base: int = 10000,
        rope_scaling: Dict = None,
        pos_embed_max_size: int = 192,
        time_step_dim: int = 256,
        flip_sin_to_cos: bool = True,
        downscale_freq_shift: int = 0,
        timestep_activation_fn: str = "silu",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels

        self.patch_embedding = OmniGenPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            pos_embed_max_size=pos_embed_max_size,
        )

        # One shared sinusoidal projection feeds two separate embedders: `time_token` becomes a sequence token,
        # `t_embedder` conditions the final AdaLayerNorm.
        self.time_proj = Timesteps(time_step_dim, flip_sin_to_cos, downscale_freq_shift)
        self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
        self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)

        self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)
        self.rope = OmniGenSuScaledRotaryEmbedding(
            hidden_size // num_attention_heads,
            max_position_embeddings=max_position_embeddings,
            original_max_position_embeddings=original_max_position_embeddings,
            base=rope_base,
            rope_scaling=rope_scaling,
        )

        self.layers = nn.ModuleList(
            [
                OmniGenBlock(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, rms_norm_eps)
                for _ in range(num_layers)
            ]
        )

        self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
        self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def _get_multimodal_embeddings(
        self,
        input_ids: torch.Tensor,
        input_img_latents: Optional[List[torch.Tensor]],
        input_image_sizes: Optional[Dict[int, List[List[int]]]],
    ) -> Optional[torch.Tensor]:
        """
        Embed the text tokens and overwrite image placeholder spans with patch-embedded input-image latents.

        Args:
            input_ids (`torch.Tensor`):
                Token ids; when `None`, no conditioning is produced.
            input_img_latents (`List[torch.Tensor]`, *optional*):
                Latents of the input images, consumed in order. `None` is treated as an empty list.
            input_image_sizes (`Dict`, *optional*):
                Maps a batch index to a list of `(start, end)` token spans to overwrite. `None` is treated as empty.

        Returns:
            `torch.Tensor` or `None`: The conditioning token embeddings, or `None` when `input_ids` is `None`.
        """
        if input_ids is None:
            return None

        # Text-only conditioning may come without any image inputs.
        if input_img_latents is None:
            input_img_latents = []
        if input_image_sizes is None:
            input_image_sizes = {}

        input_img_latents = [latent.to(self.dtype) for latent in input_img_latents]
        condition_tokens = self.embed_tokens(input_ids)
        input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True)

        image_index = 0
        for batch_index, spans in input_image_sizes.items():
            for start, end in spans:
                # replace the placeholder in text tokens with the image embedding.
                condition_tokens[batch_index, start:end] = input_image_tokens[image_index].to(
                    condition_tokens.dtype
                )
                image_index += 1
        return condition_tokens

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.FloatTensor],
        input_ids: torch.Tensor,
        input_img_latents: List[torch.Tensor],
        input_image_sizes: Dict[int, List[int]],
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[Transformer2DModelOutput, Tuple[torch.Tensor]]:
        """
        Run the OmniGen transformer.

        Args:
            hidden_states (`torch.Tensor`):
                Latents of shape `(batch_size, num_channels, height, width)`.
            timestep (`int`, `float` or `torch.Tensor`):
                Denoising timestep. A scalar, or a tensor holding a single value, is broadcast over the batch.
            input_ids (`torch.Tensor`):
                Conditioning token ids; `None` disables the conditioning tokens.
            input_img_latents (`List[torch.Tensor]`):
                Latents of the input images referenced by `input_image_sizes`.
            input_image_sizes (`Dict`):
                Maps a batch index to the `(start, end)` token spans that are replaced by image embeddings.
            attention_mask (`torch.Tensor`):
                A 3-D mask is converted into an additive mask (1 keeps a position, 0 masks it) and gets a head
                dimension inserted; any other mask is passed to the blocks unchanged.
            position_ids (`torch.Tensor`):
                Position ids for the whole sequence; reshaped to `(-1, sequence_length)`.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput` or `tuple`: The predicted sample of shape
            `(batch_size, channels, height, width)` (patches folded back into the spatial layout).
        """
        batch_size, num_channels, height, width = hidden_states.shape
        p = self.config.patch_size
        post_patch_height, post_patch_width = height // p, width // p

        # 1. Patch & Timestep & Conditional Embedding
        hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
        num_tokens_for_output_image = hidden_states.size(1)

        # The timestep projection needs a 1-D tensor with one entry per sample, so scalars are promoted.
        if not torch.is_tensor(timestep):
            timestep = torch.tensor([timestep], device=hidden_states.device)
        elif timestep.dim() == 0:
            timestep = timestep[None]
        if timestep.shape[0] == 1 and batch_size > 1:
            timestep = timestep.expand(batch_size)

        timestep_proj = self.time_proj(timestep).type_as(hidden_states)
        time_token = self.time_token(timestep_proj).unsqueeze(1)
        temb = self.t_embedder(timestep_proj)

        # Sequence layout: [condition tokens (optional), time token, output-image tokens].
        condition_tokens = self._get_multimodal_embeddings(input_ids, input_img_latents, input_image_sizes)
        if condition_tokens is not None:
            hidden_states = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
        else:
            hidden_states = torch.cat([time_token, hidden_states], dim=1)

        seq_length = hidden_states.size(1)
        position_ids = position_ids.view(-1, seq_length).long()

        # 2. Attention mask preprocessing
        if attention_mask is not None and attention_mask.dim() == 3:
            dtype = hidden_states.dtype
            min_dtype = torch.finfo(dtype).min
            # Cast first: `1 - mask` is undefined for bool masks, and multiplying in a narrower dtype than
            # `dtype` could overflow to inf and produce NaN for kept positions.
            attention_mask = (1 - attention_mask.to(dtype)) * min_dtype
            attention_mask = attention_mask.unsqueeze(1).type_as(hidden_states)

        # 3. Rotary position embedding
        image_rotary_emb = self.rope(hidden_states, position_ids)

        # 4. Transformer blocks
        for block in self.layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, attention_mask, image_rotary_emb
                )
            else:
                hidden_states = block(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb)

        # 5. Output norm & projection
        hidden_states = self.norm(hidden_states)
        # Only the trailing tokens belong to the output image.
        hidden_states = hidden_states[:, -num_tokens_for_output_image:]
        hidden_states = self.norm_out(hidden_states, temb=temb)
        hidden_states = self.proj_out(hidden_states)
        # Fold the per-token patches back into a (batch, channels, height, width) image.
        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, p, p, -1)
        output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)