diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
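Highlights in the list above: new model and pipeline families (Wan, Lumina2, CogView4, OmniGen, ConsisID, EasyAnimate, SANA Sprint), a new `diffusers/hooks` package (group offloading, layerwise casting, pyramid attention broadcast, FasterCache), an `AutoModel` loader, remote-decoding utilities, and a Quanto quantization backend. As quick orientation before the hunks, here is a hedged sketch of exercising one of the new pipelines; the class name matches the new `diffusers/pipelines/lumina2` module (entry 227), but the checkpoint id and sampling settings are assumptions rather than anything this diff specifies.

```python
# Hedged sketch: trying the Lumina2 pipeline added in 0.33.0.
# The hub id and settings below are assumptions; check the release notes.
import torch
from diffusers import Lumina2Pipeline  # backed by pipelines/lumina2/pipeline_lumina2.py

pipe = Lumina2Pipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe(
    prompt="a red panda reading a book under a maple tree",
    num_inference_steps=30,
    guidance_scale=4.0,
).images[0]
image.save("lumina2.png")
```

The first hunk below is the brand-new `diffusers/models/transformers/transformer_lumina2.py` (entry 82, +548 lines); the remaining hunks touch `diffusers/models/transformers/transformer_mochi.py` (entry 83).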
@@ -0,0 +1,548 @@
+ # Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...loaders import PeftAdapterMixin
+ from ...loaders.single_file_model import FromOriginalModelMixin
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+ from ..attention import LuminaFeedForward
+ from ..attention_processor import Attention
+ from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
+ from ..modeling_outputs import Transformer2DModelOutput
+ from ..modeling_utils import ModelMixin
+ from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int = 4096,
+         cap_feat_dim: int = 2048,
+         frequency_embedding_size: int = 256,
+         norm_eps: float = 1e-5,
+     ) -> None:
+         super().__init__()
+
+         self.time_proj = Timesteps(
+             num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
+         )
+
+         self.timestep_embedder = TimestepEmbedding(
+             in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
+         )
+
+         self.caption_embedder = nn.Sequential(
+             RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
+         )
+
+     def forward(
+         self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         timestep_proj = self.time_proj(timestep).type_as(hidden_states)
+         time_embed = self.timestep_embedder(timestep_proj)
+         caption_embed = self.caption_embedder(encoder_hidden_states)
+         return time_embed, caption_embed
+
+
+ class Lumina2AttnProcessor2_0:
+     r"""
+     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+     used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
+     """
+
+     def __init__(self):
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         image_rotary_emb: Optional[torch.Tensor] = None,
+         base_sequence_length: Optional[int] = None,
+     ) -> torch.Tensor:
+         batch_size, sequence_length, _ = hidden_states.shape
+
+         # Get Query-Key-Value Pair
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         query_dim = query.shape[-1]
+         inner_dim = key.shape[-1]
+         head_dim = query_dim // attn.heads
+         dtype = query.dtype
+
+         # Get key-value heads
+         kv_heads = inner_dim // head_dim
+
+         query = query.view(batch_size, -1, attn.heads, head_dim)
+         key = key.view(batch_size, -1, kv_heads, head_dim)
+         value = value.view(batch_size, -1, kv_heads, head_dim)
+
+         # Apply Query-Key Norm if needed
+         if attn.norm_q is not None:
+             query = attn.norm_q(query)
+         if attn.norm_k is not None:
+             key = attn.norm_k(key)
+
+         # Apply RoPE if needed
+         if image_rotary_emb is not None:
+             query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
+             key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
+
+         query, key = query.to(dtype), key.to(dtype)
+
+         # Apply proportional attention if true
+         if base_sequence_length is not None:
+             softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
+         else:
+             softmax_scale = attn.scale
+
+         # perform Grouped-query Attention (GQA)
+         n_rep = attn.heads // kv_heads
+         if n_rep >= 1:
+             key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+             value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+
+         # scaled_dot_product_attention expects attention_mask shape to be
+         # (batch, heads, source_length, target_length)
+         if attention_mask is not None:
+             attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
+
+         query = query.transpose(1, 2)
+         key = key.transpose(1, 2)
+         value = value.transpose(1, 2)
+
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, scale=softmax_scale
+         )
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.type_as(query)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         hidden_states = attn.to_out[1](hidden_states)
+         return hidden_states
+
+
+ class Lumina2TransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         num_kv_heads: int,
+         multiple_of: int,
+         ffn_dim_multiplier: float,
+         norm_eps: float,
+         modulation: bool = True,
+     ) -> None:
+         super().__init__()
+         self.head_dim = dim // num_attention_heads
+         self.modulation = modulation
+
+         self.attn = Attention(
+             query_dim=dim,
+             cross_attention_dim=None,
+             dim_head=dim // num_attention_heads,
+             qk_norm="rms_norm",
+             heads=num_attention_heads,
+             kv_heads=num_kv_heads,
+             eps=1e-5,
+             bias=False,
+             out_bias=False,
+             processor=Lumina2AttnProcessor2_0(),
+         )
+
+         self.feed_forward = LuminaFeedForward(
+             dim=dim,
+             inner_dim=4 * dim,
+             multiple_of=multiple_of,
+             ffn_dim_multiplier=ffn_dim_multiplier,
+         )
+
+         if modulation:
+             self.norm1 = LuminaRMSNormZero(
+                 embedding_dim=dim,
+                 norm_eps=norm_eps,
+                 norm_elementwise_affine=True,
+             )
+         else:
+             self.norm1 = RMSNorm(dim, eps=norm_eps)
+         self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
+
+         self.norm2 = RMSNorm(dim, eps=norm_eps)
+         self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         image_rotary_emb: torch.Tensor,
+         temb: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if self.modulation:
+             norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+             attn_output = self.attn(
+                 hidden_states=norm_hidden_states,
+                 encoder_hidden_states=norm_hidden_states,
+                 attention_mask=attention_mask,
+                 image_rotary_emb=image_rotary_emb,
+             )
+             hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
+             mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
+             hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
+         else:
+             norm_hidden_states = self.norm1(hidden_states)
+             attn_output = self.attn(
+                 hidden_states=norm_hidden_states,
+                 encoder_hidden_states=norm_hidden_states,
+                 attention_mask=attention_mask,
+                 image_rotary_emb=image_rotary_emb,
+             )
+             hidden_states = hidden_states + self.norm2(attn_output)
+             mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
+             hidden_states = hidden_states + self.ffn_norm2(mlp_output)
+
+         return hidden_states
+
+
+ class Lumina2RotaryPosEmbed(nn.Module):
+     def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
+         super().__init__()
+         self.theta = theta
+         self.axes_dim = axes_dim
+         self.axes_lens = axes_lens
+         self.patch_size = patch_size
+
+         self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)
+
+     def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
+         freqs_cis = []
+         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+         for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
+             emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
+             freqs_cis.append(emb)
+         return freqs_cis
+
+     def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
+         device = ids.device
+         if ids.device.type == "mps":
+             ids = ids.to("cpu")
+
+         result = []
+         for i in range(len(self.axes_dim)):
+             freqs = self.freqs_cis[i].to(ids.device)
+             index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
+             result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
+         return torch.cat(result, dim=-1).to(device)
+
+     def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
+         batch_size, channels, height, width = hidden_states.shape
+         p = self.patch_size
+         post_patch_height, post_patch_width = height // p, width // p
+         image_seq_len = post_patch_height * post_patch_width
+         device = hidden_states.device
+
+         encoder_seq_len = attention_mask.shape[1]
+         l_effective_cap_len = attention_mask.sum(dim=1).tolist()
+         seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
+         max_seq_len = max(seq_lengths)
+
+         # Create position IDs
+         position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
+
+         for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+             # add caption position ids
+             position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
+             position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
+
+             # add image position ids
+             row_ids = (
+                 torch.arange(post_patch_height, dtype=torch.int32, device=device)
+                 .view(-1, 1)
+                 .repeat(1, post_patch_width)
+                 .flatten()
+             )
+             col_ids = (
+                 torch.arange(post_patch_width, dtype=torch.int32, device=device)
+                 .view(1, -1)
+                 .repeat(post_patch_height, 1)
+                 .flatten()
+             )
+             position_ids[i, cap_seq_len:seq_len, 1] = row_ids
+             position_ids[i, cap_seq_len:seq_len, 2] = col_ids
+
+         # Get combined rotary embeddings
+         freqs_cis = self._get_freqs_cis(position_ids)
+
+         # create separate rotary embeddings for captions and images
+         cap_freqs_cis = torch.zeros(
+             batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+         )
+         img_freqs_cis = torch.zeros(
+             batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+         )
+
+         for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+             cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
+             img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]
+
+         # image patch embeddings
+         hidden_states = (
+             hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
+             .permute(0, 2, 4, 3, 5, 1)
+             .flatten(3)
+             .flatten(1, 2)
+         )
+
+         return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
+
+
+ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+     r"""
+     Lumina2NextDiT: Diffusion model with a Transformer backbone.
+
+     Parameters:
+         sample_size (`int`): The width of the latent images. This is fixed during training since
+             it is used to learn a number of position embeddings.
+         patch_size (`int`, *optional*, defaults to 2):
+             The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
+         in_channels (`int`, *optional*, defaults to 4):
+             The number of input channels for the model. Typically, this matches the number of channels in the input
+             images.
+         hidden_size (`int`, *optional*, defaults to 4096):
+             The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+             hidden representations.
+         num_layers (`int`, *optional*, defaults to 32):
+             The number of layers in the model. This defines the depth of the neural network.
+         num_attention_heads (`int`, *optional*, defaults to 32):
+             The number of attention heads in each attention layer. This parameter specifies how many separate attention
+             mechanisms are used.
+         num_kv_heads (`int`, *optional*, defaults to 8):
+             The number of key-value heads in the attention mechanism, if different from the number of attention heads.
+             If None, it defaults to num_attention_heads.
+         multiple_of (`int`, *optional*, defaults to 256):
+             A factor that the hidden size should be a multiple of. This can help optimize certain hardware
+             configurations.
+         ffn_dim_multiplier (`float`, *optional*):
+             A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
+             the model configuration.
+         norm_eps (`float`, *optional*, defaults to 1e-5):
+             A small value added to the denominator for numerical stability in normalization layers.
+         scaling_factor (`float`, *optional*, defaults to 1.0):
+             A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
+             overall scale of the model's operations.
+     """
+
+     _supports_gradient_checkpointing = True
+     _no_split_modules = ["Lumina2TransformerBlock"]
+     _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
+
+     @register_to_config
+     def __init__(
+         self,
+         sample_size: int = 128,
+         patch_size: int = 2,
+         in_channels: int = 16,
+         out_channels: Optional[int] = None,
+         hidden_size: int = 2304,
+         num_layers: int = 26,
+         num_refiner_layers: int = 2,
+         num_attention_heads: int = 24,
+         num_kv_heads: int = 8,
+         multiple_of: int = 256,
+         ffn_dim_multiplier: Optional[float] = None,
+         norm_eps: float = 1e-5,
+         scaling_factor: float = 1.0,
+         axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
+         axes_lens: Tuple[int, int, int] = (300, 512, 512),
+         cap_feat_dim: int = 1024,
+     ) -> None:
+         super().__init__()
+         self.out_channels = out_channels or in_channels
+
+         # 1. Positional, patch & conditional embeddings
+         self.rope_embedder = Lumina2RotaryPosEmbed(
+             theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
+         )
+
+         self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)
+
+         self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
+             hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
+         )
+
+         # 2. Noise and context refinement blocks
+         self.noise_refiner = nn.ModuleList(
+             [
+                 Lumina2TransformerBlock(
+                     hidden_size,
+                     num_attention_heads,
+                     num_kv_heads,
+                     multiple_of,
+                     ffn_dim_multiplier,
+                     norm_eps,
+                     modulation=True,
+                 )
+                 for _ in range(num_refiner_layers)
+             ]
+         )
+
+         self.context_refiner = nn.ModuleList(
+             [
+                 Lumina2TransformerBlock(
+                     hidden_size,
+                     num_attention_heads,
+                     num_kv_heads,
+                     multiple_of,
+                     ffn_dim_multiplier,
+                     norm_eps,
+                     modulation=False,
+                 )
+                 for _ in range(num_refiner_layers)
+             ]
+         )
+
+         # 3. Transformer blocks
+         self.layers = nn.ModuleList(
+             [
+                 Lumina2TransformerBlock(
+                     hidden_size,
+                     num_attention_heads,
+                     num_kv_heads,
+                     multiple_of,
+                     ffn_dim_multiplier,
+                     norm_eps,
+                     modulation=True,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         # 4. Output norm & projection
+         self.norm_out = LuminaLayerNormContinuous(
+             embedding_dim=hidden_size,
+             conditioning_embedding_dim=min(hidden_size, 1024),
+             elementwise_affine=False,
+             eps=1e-6,
+             bias=True,
+             out_dim=patch_size * patch_size * self.out_channels,
+         )
+
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+         return_dict: bool = True,
+     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+         if attention_kwargs is not None:
+             attention_kwargs = attention_kwargs.copy()
+             lora_scale = attention_kwargs.pop("scale", 1.0)
+         else:
+             lora_scale = 1.0
+
+         if USE_PEFT_BACKEND:
+             # weight the lora layers by setting `lora_scale` for each PEFT layer
+             scale_lora_layers(self, lora_scale)
+         else:
+             if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                 logger.warning(
+                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                 )
+
+         # 1. Condition, positional & patch embedding
+         batch_size, _, height, width = hidden_states.shape
+
+         temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
+
+         (
+             hidden_states,
+             context_rotary_emb,
+             noise_rotary_emb,
+             rotary_emb,
+             encoder_seq_lengths,
+             seq_lengths,
+         ) = self.rope_embedder(hidden_states, encoder_attention_mask)
+
+         hidden_states = self.x_embedder(hidden_states)
+
+         # 2. Context & noise refinement
+         for layer in self.context_refiner:
+             encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
+
+         for layer in self.noise_refiner:
+             hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)
+
+         # 3. Joint Transformer blocks
+         max_seq_len = max(seq_lengths)
+         use_mask = len(set(seq_lengths)) > 1
+
+         attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
+         joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
+         for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+             attention_mask[i, :seq_len] = True
+             joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
+             joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]
+
+         hidden_states = joint_hidden_states
+
+         for layer in self.layers:
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+                 hidden_states = self._gradient_checkpointing_func(
+                     layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
+                 )
+             else:
+                 hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)
+
+         # 4. Output norm & projection
+         hidden_states = self.norm_out(hidden_states, temb)
+
+         # 5. Unpatchify
+         p = self.config.patch_size
+         output = []
+         for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+             output.append(
+                 hidden_states[i][encoder_seq_len:seq_len]
+                 .view(height // p, width // p, p, p, self.out_channels)
+                 .permute(4, 0, 2, 1, 3)
+                 .flatten(3, 4)
+                 .flatten(1, 2)
+             )
+         output = torch.stack(output, dim=0)
+
+         if USE_PEFT_BACKEND:
+             # remove `lora_scale` from each PEFT layer
+             unscale_lora_layers(self, lora_scale)
+
+         if not return_dict:
+             return (output,)
+         return Transformer2DModelOutput(sample=output)
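That completes the new `transformer_lumina2.py`. Below is a minimal smoke test of the model class; it is a sketch of mine rather than part of the release, with tiny config values chosen as assumptions so the shapes line up (the RoPE split requires `sum(axes_dim_rope) == hidden_size // num_attention_heads`).

```python
# Hedged smoke test for Lumina2Transformer2DModel (assumed tiny config).
import torch
from diffusers import Lumina2Transformer2DModel

model = Lumina2Transformer2DModel(
    patch_size=2,
    in_channels=4,
    hidden_size=96,
    num_layers=2,
    num_refiner_layers=1,
    num_attention_heads=4,
    num_kv_heads=2,
    axes_dim_rope=(8, 8, 8),  # 8 + 8 + 8 == 96 // 4 == head_dim
    axes_lens=(32, 32, 32),
    cap_feat_dim=16,
)

hidden_states = torch.randn(1, 4, 32, 32)  # latents (B, C, H, W)
timestep = torch.tensor([0.5])  # flow-matching time in [0, 1]
encoder_hidden_states = torch.randn(1, 8, 16)  # caption features
encoder_attention_mask = torch.ones(1, 8, dtype=torch.bool)

with torch.no_grad():
    sample = model(
        hidden_states=hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
    ).sample

print(sample.shape)  # torch.Size([1, 4, 32, 32])
```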
@@ -21,10 +21,11 @@ import torch.nn as nn
  from ...configuration_utils import ConfigMixin, register_to_config
  from ...loaders import PeftAdapterMixin
  from ...loaders.single_file_model import FromOriginalModelMixin
- from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
  from ...utils.torch_utils import maybe_allow_in_graph
  from ..attention import FeedForward
  from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
+ from ..cache_utils import CacheMixin
  from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
  from ..modeling_outputs import Transformer2DModelOutput
  from ..modeling_utils import ModelMixin
@@ -305,7 +306,7 @@ class MochiRoPE(nn.Module):


  @maybe_allow_in_graph
- class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
      r"""
      A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).

@@ -336,6 +337,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri

      _supports_gradient_checkpointing = True
      _no_split_modules = ["MochiTransformerBlock"]
+     _skip_layerwise_casting_patterns = ["patch_embed", "norm"]

      @register_to_config
      def __init__(
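Two small but consequential changes appear in these Mochi hunks: the class now inherits the new `CacheMixin`, and `_skip_layerwise_casting_patterns` tells the new layerwise-casting hook which modules to keep in full precision. A hedged sketch of how both surface to users follows; the method names come from the new `cache_utils` and `hooks` modules in this release, but the exact config values are assumptions.

```python
# Hedged sketch: CacheMixin and layerwise casting on the Mochi transformer.
import torch
from diffusers import MochiPipeline, PyramidAttentionBroadcastConfig

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# CacheMixin: reuse attention outputs across nearby timesteps (PAB).
pipe.transformer.enable_cache(
    PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,  # recompute attention every other block
        spatial_attention_timestep_skip_range=(100, 800),  # only cache in this window
        current_timestep_callback=lambda: pipe.current_timestep,
    )
)

# Layerwise casting: store weights in fp8, compute in bf16; modules matching
# ["patch_embed", "norm"] are skipped per _skip_layerwise_casting_patterns.
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```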
@@ -402,10 +404,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri

          self.gradient_checkpointing = False

-     def _set_gradient_checkpointing(self, module, value=False):
-         if hasattr(module, "gradient_checkpointing"):
-             module.gradient_checkpointing = value
-
      def forward(
          self,
          hidden_states: torch.Tensor,
@@ -458,22 +456,13 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri

          for i, block in enumerate(self.transformer_blocks):
              if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                 def create_custom_forward(module):
-                     def custom_forward(*inputs):
-                         return module(*inputs)
-
-                     return custom_forward
-
-                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                 hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(block),
+                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                     block,
                      hidden_states,
                      encoder_hidden_states,
                      temb,
                      encoder_attention_mask,
                      image_rotary_emb,
-                     **ckpt_kwargs,
                  )
              else:
                  hidden_states, encoder_hidden_states = block(