diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl

Files changed (445)
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +41 -40
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.1.dist-info/RECORD +0 -399
  443. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
diffusers/models/attention.py (entry 24 above, +605 -18):

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple

 import torch
 import torch.nn.functional as F
@@ -19,10 +19,10 @@ from torch import nn

 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU
-from .attention_processor import Attention
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
+from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX


 logger = logging.get_logger(__name__)
@@ -85,6 +85,178 @@ class GatedSelfAttentionDense(nn.Module):
         return x


+@maybe_allow_in_graph
+class JointTransformerBlock(nn.Module):
+    r"""
+    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+    Reference: https://arxiv.org/abs/2403.03206
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
+            processing of `context` conditions.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        context_pre_only: bool = False,
+        qk_norm: Optional[str] = None,
+        use_dual_attention: bool = False,
+    ):
+        super().__init__()
+
+        self.use_dual_attention = use_dual_attention
+        self.context_pre_only = context_pre_only
+        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
+
+        if use_dual_attention:
+            self.norm1 = SD35AdaLayerNormZeroX(dim)
+        else:
+            self.norm1 = AdaLayerNormZero(dim)
+
+        if context_norm_type == "ada_norm_continous":
+            self.norm1_context = AdaLayerNormContinuous(
+                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
+            )
+        elif context_norm_type == "ada_norm_zero":
+            self.norm1_context = AdaLayerNormZero(dim)
+        else:
+            raise ValueError(
+                f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
+            )
+
+        if hasattr(F, "scaled_dot_product_attention"):
+            processor = JointAttnProcessor2_0()
+        else:
+            raise ValueError(
+                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+            )
+
+        self.attn = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            added_kv_proj_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            context_pre_only=context_pre_only,
+            bias=True,
+            processor=processor,
+            qk_norm=qk_norm,
+            eps=1e-6,
+        )
+
+        if use_dual_attention:
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=None,
+                dim_head=attention_head_dim,
+                heads=num_attention_heads,
+                out_dim=dim,
+                bias=True,
+                processor=processor,
+                qk_norm=qk_norm,
+                eps=1e-6,
+            )
+        else:
+            self.attn2 = None
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+        if not context_pre_only:
+            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+        else:
+            self.norm2_context = None
+            self.ff_context = None
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
+        if self.use_dual_attention:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+                hidden_states, emb=temb
+            )
+        else:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+        if self.context_pre_only:
+            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+        else:
+            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+                encoder_hidden_states, emb=temb
+            )
+
+        # Attention.
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        if self.use_dual_attention:
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
+            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+            hidden_states = hidden_states + attn_output2
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        # Process attention outputs for the `encoder_hidden_states`.
+        if self.context_pre_only:
+            encoder_hidden_states = None
+        else:
+            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+            encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+            if self._chunk_size is not None:
+                # "feed_forward_chunk_size" can be used to save memory
+                context_ff_output = _chunked_feed_forward(
+                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+                )
+            else:
+                context_ff_output = self.ff_context(norm_encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return encoder_hidden_states, hidden_states
+
+
 @maybe_allow_in_graph
 class BasicTransformerBlock(nn.Module):
     r"""
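The new `JointTransformerBlock` is self-contained enough to exercise in isolation. A minimal sketch (shapes and sizes are illustrative, not taken from this diff; assumes PyTorch 2.x so `F.scaled_dot_product_attention` exists):

    import torch
    from diffusers.models.attention import JointTransformerBlock

    # dim = num_attention_heads * attention_head_dim is the typical configuration
    block = JointTransformerBlock(dim=64, num_attention_heads=2, attention_head_dim=32)
    hidden_states = torch.randn(1, 16, 64)         # image/latent tokens
    encoder_hidden_states = torch.randn(1, 8, 64)  # text/context tokens
    temb = torch.randn(1, 64)                      # pooled conditioning embedding
    # forward returns (encoder_hidden_states, hidden_states)
    ctx_out, x_out = block(hidden_states, encoder_hidden_states, temb)
    print(x_out.shape, ctx_out.shape)  # torch.Size([1, 16, 64]) torch.Size([1, 8, 64])

With `context_pre_only=True` the context stream is consumed and the first return value is `None`.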
@@ -148,6 +320,17 @@ class BasicTransformerBlock(nn.Module):
         attention_out_bias: bool = True,
     ):
         super().__init__()
+        self.dim = dim
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        self.dropout = dropout
+        self.cross_attention_dim = cross_attention_dim
+        self.activation_fn = activation_fn
+        self.attention_bias = attention_bias
+        self.double_self_attention = double_self_attention
+        self.norm_elementwise_affine = norm_elementwise_affine
+        self.positional_embeddings = positional_embeddings
+        self.num_positional_embeddings = num_positional_embeddings
         self.only_cross_attention = only_cross_attention

         # We keep these boolean flags for backward-compatibility.
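These new assignments expose the constructor arguments on the block instance. That is what lets helper code elsewhere in this release (for example the FreeNoise utilities added in diffusers/pipelines/free_noise_utils.py) rebuild an equivalent block from a live module. A sketch of the pattern, with illustrative sizes:

    from diffusers.models.attention import BasicTransformerBlock

    block = BasicTransformerBlock(dim=64, num_attention_heads=2, attention_head_dim=32)
    # read the stored config back off the instance and construct a clone
    clone = BasicTransformerBlock(
        dim=block.dim,
        num_attention_heads=block.num_attention_heads,
        attention_head_dim=block.attention_head_dim,
    )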
@@ -235,7 +418,10 @@ class BasicTransformerBlock(nn.Module):
                 out_bias=attention_out_bias,
             )  # is self-attn if encoder_hidden_states is none
         else:
-            self.norm2 = None
+            if norm_type == "ada_norm_single":  # For Latte
+                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+            else:
+                self.norm2 = None
             self.attn2 = None

         # 3. Feed-forward
@@ -249,7 +435,7 @@ class BasicTransformerBlock(nn.Module):
                 "layer_norm",
             )

-        elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]:
+        elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
             self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
         elif norm_type == "layer_norm_i2vgen":
             self.norm3 = None
@@ -282,18 +468,18 @@ class BasicTransformerBlock(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         timestep: Optional[torch.LongTensor] = None,
         cross_attention_kwargs: Dict[str, Any] = None,
         class_labels: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Self-Attention
@@ -315,7 +501,6 @@ class BasicTransformerBlock(nn.Module):
             ).chunk(6, dim=1)
             norm_hidden_states = self.norm1(hidden_states)
             norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
-            norm_hidden_states = norm_hidden_states.squeeze(1)
         else:
             raise ValueError("Incorrect norm used")

@@ -332,6 +517,7 @@ class BasicTransformerBlock(nn.Module):
             attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
+
         if self.norm_type == "ada_norm_zero":
             attn_output = gate_msa.unsqueeze(1) * attn_output
         elif self.norm_type == "ada_norm_single":
@@ -403,6 +589,56 @@ class BasicTransformerBlock(nn.Module):
         return hidden_states


+class LuminaFeedForward(nn.Module):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (`int`):
+            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+            hidden representations.
+        inner_dim (`int`): The intermediate dimension of the feedforward layer.
+        multiple_of (`int`, *optional*): Value to ensure the hidden dimension is a multiple
+            of this value.
+        ffn_dim_multiplier (`float`, *optional*): Custom multiplier for the hidden
+            dimension. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        inner_dim: int,
+        multiple_of: Optional[int] = 256,
+        ffn_dim_multiplier: Optional[float] = None,
+    ):
+        super().__init__()
+        inner_dim = int(2 * inner_dim / 3)
+        # custom hidden_size factor multiplier
+        if ffn_dim_multiplier is not None:
+            inner_dim = int(ffn_dim_multiplier * inner_dim)
+        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
+
+        self.linear_1 = nn.Linear(
+            dim,
+            inner_dim,
+            bias=False,
+        )
+        self.linear_2 = nn.Linear(
+            inner_dim,
+            dim,
+            bias=False,
+        )
+        self.linear_3 = nn.Linear(
+            dim,
+            inner_dim,
+            bias=False,
+        )
+        self.silu = FP32SiLU()
+
+    def forward(self, x):
+        return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
+
+
 @maybe_allow_in_graph
 class TemporalBasicTransformerBlock(nn.Module):
     r"""
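The inner-dimension arithmetic in `LuminaFeedForward.__init__` mirrors the LLaMA-style gated-MLP sizing: scale by 2/3, apply the optional multiplier, then round up to a multiple of `multiple_of`. A worked example with illustrative numbers:

    dim, inner_dim, multiple_of = 2048, 4 * 2048, 256
    inner_dim = int(2 * inner_dim / 3)                                        # 8192 -> 5461
    inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)  # round up -> 5632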
@@ -477,10 +713,10 @@ class TemporalBasicTransformerBlock(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         num_frames: int,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Self-Attention
         batch_size = hidden_states.shape[0]
@@ -605,6 +841,354 @@ class SkipFFTransformerBlock(nn.Module):
         return hidden_states


+@maybe_allow_in_graph
+class FreeNoiseTransformerBlock(nn.Module):
+    r"""
+    A FreeNoise Transformer block.
+
+    Parameters:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        cross_attention_dim (`int`, *optional*):
+            The size of the encoder_hidden_states vector for cross attention.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`):
+            Activation function to be used in feed-forward.
+        num_embeds_ada_norm (`int`, *optional*):
+            The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (`bool`, defaults to `False`):
+            Configure if the attentions should contain a bias parameter.
+        only_cross_attention (`bool`, defaults to `False`):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, defaults to `False`):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        upcast_attention (`bool`, defaults to `False`):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+        final_dropout (`bool`, defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
+        positional_embeddings (`str`, *optional*):
+            The type of positional embeddings to apply.
+        num_positional_embeddings (`int`, *optional*, defaults to `None`):
+            The maximum number of positional embeddings to apply.
+        ff_inner_dim (`int`, *optional*):
+            Hidden dimension of feed-forward MLP.
+        ff_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in feed-forward MLP.
+        attention_out_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in the attention output projection layer.
+        context_length (`int`, defaults to `16`):
+            The maximum number of frames that the FreeNoise block processes at once.
+        context_stride (`int`, defaults to `4`):
+            The number of frames to be skipped before starting to process a new batch of `context_length` frames.
+        weighting_scheme (`str`, defaults to `"pyramid"`):
+            The weighting scheme to use for weighted averaging of processed latent frames. As described in
+            Equation 9 of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default
+            setting used.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout: float = 0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        double_self_attention: bool = False,
+        upcast_attention: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_type: str = "layer_norm",
+        norm_eps: float = 1e-5,
+        final_dropout: bool = False,
+        positional_embeddings: Optional[str] = None,
+        num_positional_embeddings: Optional[int] = None,
+        ff_inner_dim: Optional[int] = None,
+        ff_bias: bool = True,
+        attention_out_bias: bool = True,
+        context_length: int = 16,
+        context_stride: int = 4,
+        weighting_scheme: str = "pyramid",
+    ):
+        super().__init__()
+        self.dim = dim
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        self.dropout = dropout
+        self.cross_attention_dim = cross_attention_dim
+        self.activation_fn = activation_fn
+        self.attention_bias = attention_bias
+        self.double_self_attention = double_self_attention
+        self.norm_elementwise_affine = norm_elementwise_affine
+        self.positional_embeddings = positional_embeddings
+        self.num_positional_embeddings = num_positional_embeddings
+        self.only_cross_attention = only_cross_attention
+
+        self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
+
+        # We keep these boolean flags for backward-compatibility.
+        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+        self.use_layer_norm = norm_type == "layer_norm"
+        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
+        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+            )
+
+        self.norm_type = norm_type
+        self.num_embeds_ada_norm = num_embeds_ada_norm
+
+        if positional_embeddings and (num_positional_embeddings is None):
+            raise ValueError(
+                "If `positional_embedding` type is defined, `num_positional_embeddings` must also be defined."
+            )
+
+        if positional_embeddings == "sinusoidal":
+            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+        else:
+            self.pos_embed = None
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
+            out_bias=attention_out_bias,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None or double_self_attention:
+            self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+                out_bias=attention_out_bias,
+            )  # is self-attn if encoder_hidden_states is none
+
+        # 3. Feed-forward
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+            inner_dim=ff_inner_dim,
+            bias=ff_bias,
+        )
+
+        self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
+        frame_indices = []
+        for i in range(0, num_frames - self.context_length + 1, self.context_stride):
+            window_start = i
+            window_end = min(num_frames, i + self.context_length)
+            frame_indices.append((window_start, window_end))
+        return frame_indices
+
+    def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
+        if weighting_scheme == "flat":
+            weights = [1.0] * num_frames
+
+        elif weighting_scheme == "pyramid":
+            if num_frames % 2 == 0:
+                # num_frames = 4 => [1, 2, 2, 1]
+                mid = num_frames // 2
+                weights = list(range(1, mid + 1))
+                weights = weights + weights[::-1]
+            else:
+                # num_frames = 5 => [1, 2, 3, 2, 1]
+                mid = (num_frames + 1) // 2
+                weights = list(range(1, mid))
+                weights = weights + [mid] + weights[::-1]
+
+        elif weighting_scheme == "delayed_reverse_sawtooth":
+            if num_frames % 2 == 0:
+                # num_frames = 4 => [0.01, 2, 2, 1]
+                mid = num_frames // 2
+                weights = [0.01] * (mid - 1) + [mid]
+                weights = weights + list(range(mid, 0, -1))
+            else:
+                # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
+                mid = (num_frames + 1) // 2
+                weights = [0.01] * mid
+                weights = weights + list(range(mid, 0, -1))
+        else:
+            raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
+
+        return weights
+
+    def set_free_noise_properties(
+        self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
+    ) -> None:
+        self.context_length = context_length
+        self.context_stride = context_stride
+        self.weighting_scheme = weighting_scheme
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+        # hidden_states: [B x H x W, F, C]
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        num_frames = hidden_states.size(1)
+        frame_indices = self._get_frame_indices(num_frames)
+        frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
+        frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
+        is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
+
+        # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length.
+        # For example, for num_frames=25, context_length=16, context_stride=4, we expect the ranges:
+        # [(0, 16), (4, 20), (8, 24), (9, 25)]
+        if not is_last_frame_batch_complete:
+            if num_frames < self.context_length:
+                raise ValueError(f"Expected {num_frames=} to be greater than or equal to {self.context_length=}")
+            last_frame_batch_length = num_frames - frame_indices[-1][1]
+            frame_indices.append((num_frames - self.context_length, num_frames))
+
+        num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
+        accumulated_values = torch.zeros_like(hidden_states)
+
+        for i, (frame_start, frame_end) in enumerate(frame_indices):
+            # Slicing the weights to (frame_end - frame_start) handles cases like
+            # frame_indices=[(0, 16), (16, 20)], i.e. when the user provides a video whose frame
+            # count is not a multiple of `context_length`.
+            weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
+            weights *= frame_weights
+
+            hidden_states_chunk = hidden_states[:, frame_start:frame_end]
+
+            # Notice that normalization is always applied before the real computation in the following blocks.
+            # 1. Self-Attention
+            norm_hidden_states = self.norm1(hidden_states_chunk)
+
+            if self.pos_embed is not None:
+                norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+            attn_output = self.attn1(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+                attention_mask=attention_mask,
+                **cross_attention_kwargs,
+            )
+
+            hidden_states_chunk = attn_output + hidden_states_chunk
+            if hidden_states_chunk.ndim == 4:
+                hidden_states_chunk = hidden_states_chunk.squeeze(1)
+
+            # 2. Cross-Attention
+            if self.attn2 is not None:
+                norm_hidden_states = self.norm2(hidden_states_chunk)
+
+                if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+                    norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+                attn_output = self.attn2(
+                    norm_hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=encoder_attention_mask,
+                    **cross_attention_kwargs,
+                )
+                hidden_states_chunk = attn_output + hidden_states_chunk
+
+            if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
+                accumulated_values[:, -last_frame_batch_length:] += (
+                    hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
+                )
+                num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
+            else:
+                accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
+                num_times_accumulated[:, frame_start:frame_end] += weights
+
+        # TODO(aryan): Maybe this could be done in a better way.
+        #
+        # Previously, this was:
+        #     hidden_states = torch.where(
+        #         num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        #     )
+        #
+        # The reasoning for the change here is that `torch.where` became a bottleneck at some point when golfing memory
+        # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
+        # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
+        # looked into this deeply because other memory optimizations led to more pronounced reductions.
+        hidden_states = torch.cat(
+            [
+                torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+                for accumulated_split, num_times_split in zip(
+                    accumulated_values.split(self.context_length, dim=1),
+                    num_times_accumulated.split(self.context_length, dim=1),
+                )
+            ],
+            dim=1,
+        ).to(dtype)
+
+        # 3. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self._chunk_size is not None:
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        hidden_states = ff_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        return hidden_states
+
+
 class FeedForward(nn.Module):
     r"""
     A feed-forward layer.
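To make the FreeNoise windowing concrete, the sketch below mirrors `_get_frame_indices` plus the out-of-bounds handling in `forward`, and a condensed (but equivalent) form of the "pyramid" branch of `_get_frame_weights`:

    def frame_indices(num_frames: int, context_length: int = 16, context_stride: int = 4):
        windows = [
            (i, min(num_frames, i + context_length))
            for i in range(0, num_frames - context_length + 1, context_stride)
        ]
        if windows[-1][1] != num_frames:
            # pad with one final full-length window, as the forward pass does
            windows.append((num_frames - context_length, num_frames))
        return windows

    def pyramid_weights(n: int):
        mid = (n + 1) // 2
        return list(range(1, mid + 1)) + list(range(n - mid, 0, -1))

    print(frame_indices(25))   # [(0, 16), (4, 20), (8, 24), (9, 25)]
    print(pyramid_weights(4))  # [1, 2, 2, 1]
    print(pyramid_weights(5))  # [1, 2, 3, 2, 1]

Overlapping windows are then blended by accumulating weighted chunks and dividing by the per-frame weight totals, as the block's forward pass shows above.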
@@ -634,7 +1218,6 @@ class FeedForward(nn.Module):
         if inner_dim is None:
             inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        linear_cls = nn.Linear

         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim, bias=bias)
@@ -644,6 +1227,10 @@ class FeedForward(nn.Module):
             act_fn = GEGLU(dim, inner_dim, bias=bias)
         elif activation_fn == "geglu-approximate":
             act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+        elif activation_fn == "swiglu":
+            act_fn = SwiGLU(dim, inner_dim, bias=bias)
+        elif activation_fn == "linear-silu":
+            act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")

         self.net = nn.ModuleList([])
         # project in
@@ -651,7 +1238,7 @@ class FeedForward(nn.Module):
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
+        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
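With the new branches in place, the added activations are reachable through the existing `FeedForward` constructor. A minimal sketch (sizes illustrative):

    import torch
    from diffusers.models.attention import FeedForward

    ff = FeedForward(dim=64, activation_fn="swiglu")  # or "linear-silu"
    out = ff(torch.randn(1, 16, 64))
    print(out.shape)                                  # torch.Size([1, 16, 64])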