diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoders/autoencoder_kl_ltx.py

@@ -196,6 +196,55 @@ class LTXVideoResnetBlock3d(nn.Module):
         return hidden_states


+class LTXVideoDownsampler3d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: Union[int, Tuple[int, int, int]] = 1,
+        is_causal: bool = True,
+        padding_mode: str = "zeros",
+    ) -> None:
+        super().__init__()
+
+        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
+        self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels
+
+        out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])
+
+        self.conv = LTXVideoCausalConv3d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=1,
+            is_causal=is_causal,
+            padding_mode=padding_mode,
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
+
+        residual = (
+            hidden_states.unflatten(4, (-1, self.stride[2]))
+            .unflatten(3, (-1, self.stride[1]))
+            .unflatten(2, (-1, self.stride[0]))
+        )
+        residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
+        residual = residual.unflatten(1, (-1, self.group_size))
+        residual = residual.mean(dim=2)
+
+        hidden_states = self.conv(hidden_states)
+        hidden_states = (
+            hidden_states.unflatten(4, (-1, self.stride[2]))
+            .unflatten(3, (-1, self.stride[1]))
+            .unflatten(2, (-1, self.stride[0]))
+        )
+        hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class LTXVideoUpsampler3d(nn.Module):
     def __init__(
         self,
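The new LTXVideoDownsampler3d replaces a strided convolution's residual path with a pixel-unshuffle-style rearrangement: space-to-depth folds each stride block into the channel dimension, then channel groups are averaged down to the target width. A minimal standalone sketch of that residual rearrangement (the tensor sizes here are illustrative assumptions, not values taken from the release):

import torch

batch, channels, frames, height, width = 1, 128, 9, 32, 32
stride = (1, 2, 2)  # matches the "spatial" downsample_type wired up later in this diff
out_channels = 512
group_size = channels * stride[0] * stride[1] * stride[2] // out_channels  # 1 here

x = torch.randn(batch, channels, frames, height, width)
# Causal padding: repeat the first stride[0] - 1 frames so the frame count divides stride[0]
x = torch.cat([x[:, :, : stride[0] - 1], x], dim=2)

# Space-to-depth: split frame/height/width into (out, stride) pairs, move the
# stride factors next to the channel axis, and fold them in
r = x.unflatten(4, (-1, stride[2])).unflatten(3, (-1, stride[1])).unflatten(2, (-1, stride[0]))
r = r.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
# Average channel groups down to out_channels
r = r.unflatten(1, (-1, group_size)).mean(dim=2)
print(r.shape)  # torch.Size([1, 512, 9, 16, 16])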
@@ -204,6 +253,7 @@ class LTXVideoUpsampler3d(nn.Module):
         is_causal: bool = True,
         residual: bool = False,
         upscale_factor: int = 1,
+        padding_mode: str = "zeros",
     ) -> None:
         super().__init__()

@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module):
             kernel_size=3,
             stride=1,
             is_causal=is_causal,
+            padding_mode=padding_mode,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -338,16 +389,122 @@ class LTXVideoDownBlock3D(nn.Module):

         for i, resnet in enumerate(self.resnets):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
+            else:
+                hidden_states = resnet(hidden_states, temb, generator)

-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)

-                    return create_forward
+        if self.conv_out is not None:
+            hidden_states = self.conv_out(hidden_states, temb, generator)
+
+        return hidden_states
+
+
+class LTXVideo095DownBlock3D(nn.Module):
+    r"""
+    Down block used in the LTXVideo model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`, *optional*):
+            Number of output channels. If None, defaults to `in_channels`.
+        num_layers (`int`, defaults to `1`):
+            Number of resnet layers.
+        dropout (`float`, defaults to `0.0`):
+            Dropout rate.
+        resnet_eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        resnet_act_fn (`str`, defaults to `"swish"`):
+            Activation function to use.
+        spatio_temporal_scale (`bool`, defaults to `True`):
+            Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
+            Whether or not to downsample across temporal dimension.
+        is_causal (`bool`, defaults to `True`):
+            Whether this layer behaves causally (future frames depend only on past frames) or not.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        num_layers: int = 1,
+        dropout: float = 0.0,
+        resnet_eps: float = 1e-6,
+        resnet_act_fn: str = "swish",
+        spatio_temporal_scale: bool = True,
+        is_causal: bool = True,
+        downsample_type: str = "conv",
+    ):
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+
+        resnets = []
+        for _ in range(num_layers):
+            resnets.append(
+                LTXVideoResnetBlock3d(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    dropout=dropout,
+                    eps=resnet_eps,
+                    non_linearity=resnet_act_fn,
+                    is_causal=is_causal,
+                )
+            )
+        self.resnets = nn.ModuleList(resnets)
+
+        self.downsamplers = None
+        if spatio_temporal_scale:
+            self.downsamplers = nn.ModuleList()

-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, generator
+            if downsample_type == "conv":
+                self.downsamplers.append(
+                    LTXVideoCausalConv3d(
+                        in_channels=in_channels,
+                        out_channels=in_channels,
+                        kernel_size=3,
+                        stride=(2, 2, 2),
+                        is_causal=is_causal,
+                    )
+                )
+            elif downsample_type == "spatial":
+                self.downsamplers.append(
+                    LTXVideoDownsampler3d(
+                        in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
+                    )
+                )
+            elif downsample_type == "temporal":
+                self.downsamplers.append(
+                    LTXVideoDownsampler3d(
+                        in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
+                    )
+                )
+            elif downsample_type == "spatiotemporal":
+                self.downsamplers.append(
+                    LTXVideoDownsampler3d(
+                        in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
+                    )
                 )
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+    ) -> torch.Tensor:
+        r"""Forward method of the `LTXDownBlock3D` class."""
+
+        for i, resnet in enumerate(self.resnets):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
             else:
                 hidden_states = resnet(hidden_states, temb, generator)

@@ -355,9 +512,6 @@ class LTXVideoDownBlock3D(nn.Module):
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)

-        if self.conv_out is not None:
-            hidden_states = self.conv_out(hidden_states, temb, generator)
-
         return hidden_states


@@ -438,16 +592,7 @@ class LTXVideoMidBlock3d(nn.Module):

         for i, resnet in enumerate(self.resnets):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, generator
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
             else:
                 hidden_states = resnet(hidden_states, temb, generator)

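The same cleanup repeats in the up-block, encoder, and decoder hunks below: the per-call create_custom_forward closure plus torch.utils.checkpoint.checkpoint is collapsed into self._gradient_checkpointing_func, which appears to be centralized on the model base class in 0.33.0 (this file's own _set_gradient_checkpointing override is deleted in a later hunk). For reference, a condensed sketch of what the removed boilerplate was doing; the helper name here mirrors the deleted code and is not a public API:

import torch
import torch.utils.checkpoint
import torch.nn as nn

def checkpointed(module: nn.Module, *inputs):
    # Equivalent of the removed create_custom_forward + checkpoint pattern:
    # recompute `module`'s activations during backward instead of storing them.
    def create_forward(*args):
        return module(*args)

    return torch.utils.checkpoint.checkpoint(create_forward, *inputs, use_reentrant=False)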
@@ -573,16 +718,7 @@ class LTXVideoUpBlock3d(nn.Module):

         for i, resnet in enumerate(self.resnets):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, generator
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
             else:
                 hidden_states = resnet(hidden_states, temb, generator)

@@ -620,8 +756,15 @@ class LTXVideoEncoder3d(nn.Module):
         in_channels: int = 3,
         out_channels: int = 128,
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+        down_block_types: Tuple[str, ...] = (
+            "LTXVideoDownBlock3D",
+            "LTXVideoDownBlock3D",
+            "LTXVideoDownBlock3D",
+            "LTXVideoDownBlock3D",
+        ),
         spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
         layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
+        downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
         patch_size: int = 4,
         patch_size_t: int = 1,
         resnet_norm_eps: float = 1e-6,
@@ -644,20 +787,37 @@ class LTXVideoEncoder3d(nn.Module):
         )

         # down blocks
-        num_block_out_channels = len(block_out_channels)
+        is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
+        num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
         self.down_blocks = nn.ModuleList([])
         for i in range(num_block_out_channels):
             input_channel = output_channel
-            output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
-
-            down_block = LTXVideoDownBlock3D(
-                in_channels=input_channel,
-                out_channels=output_channel,
-                num_layers=layers_per_block[i],
-                resnet_eps=resnet_norm_eps,
-                spatio_temporal_scale=spatio_temporal_scaling[i],
-                is_causal=is_causal,
-            )
+            if not is_ltx_095:
+                output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
+            else:
+                output_channel = block_out_channels[i + 1]
+
+            if down_block_types[i] == "LTXVideoDownBlock3D":
+                down_block = LTXVideoDownBlock3D(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    num_layers=layers_per_block[i],
+                    resnet_eps=resnet_norm_eps,
+                    spatio_temporal_scale=spatio_temporal_scaling[i],
+                    is_causal=is_causal,
+                )
+            elif down_block_types[i] == "LTXVideo095DownBlock3D":
+                down_block = LTXVideo095DownBlock3D(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    num_layers=layers_per_block[i],
+                    resnet_eps=resnet_norm_eps,
+                    spatio_temporal_scale=spatio_temporal_scaling[i],
+                    is_causal=is_causal,
+                    downsample_type=downsample_type[i],
+                )
+            else:
+                raise ValueError(f"Unknown down block type: {down_block_types[i]}")

             self.down_blocks.append(down_block)

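A hypothetical configuration to illustrate the new branching; these values only demonstrate the indexing and are not the actual LTX-Video 0.9.5 checkpoint config:

down_block_types = (
    "LTXVideoDownBlock3D",
    "LTXVideo095DownBlock3D",
    "LTXVideo095DownBlock3D",
    "LTXVideo095DownBlock3D",
)
block_out_channels = (128, 256, 512, 512, 1024)

# The 0.9.5 layout is detected from the LAST entry; one fewer block than
# len(block_out_channels) is built, and each block's output channel comes
# from the NEXT entry rather than the current one.
is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
print(is_ltx_095, num_block_out_channels)  # True 4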
@@ -697,17 +857,10 @@ class LTXVideoEncoder3d(nn.Module):
         hidden_states = self.conv_in(hidden_states)

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def create_forward(*inputs):
-                    return module(*inputs)
-
-                return create_forward
-
             for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), hidden_states)
+                hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)

-            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
         else:
             for down_block in self.down_blocks:
                 hidden_states = down_block(hidden_states)
@@ -828,7 +981,9 @@ class LTXVideoDecoder3d(nn.Module):
         # timestep embedding
         self.time_embedder = None
         self.scale_shift_table = None
+        self.timestep_scale_multiplier = None
         if timestep_conditioning:
+            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
             self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
             self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)

@@ -837,20 +992,14 @@ class LTXVideoDecoder3d(nn.Module):
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.conv_in(hidden_states)

-        if torch.is_grad_enabled() and self.gradient_checkpointing:
+        if self.timestep_scale_multiplier is not None:
+            temb = temb * self.timestep_scale_multiplier

-            def create_custom_forward(module):
-                def create_forward(*inputs):
-                    return module(*inputs)
-
-                return create_forward
-
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb
-            )
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)

             for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
+                hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb)
         else:
             hidden_states = self.mid_block(hidden_states, temb)

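When timestep conditioning is enabled, the decoder now owns a learnable scalar (initialized to 1000.0) that rescales the timestep before the embedder and blocks consume it. A minimal sketch of that path; the input shape and value range here are assumptions:

import torch
import torch.nn as nn

timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
temb = torch.rand(1)                      # e.g. a decode timestep in [0, 1]
temb = temb * timestep_scale_multiplier   # rescaled to the 0-1000 range typical of timestep embedders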
@@ -934,12 +1083,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         out_channels: int = 3,
         latent_channels: int = 128,
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+        down_block_types: Tuple[str, ...] = (
+            "LTXVideoDownBlock3D",
+            "LTXVideoDownBlock3D",
+            "LTXVideoDownBlock3D",
+            "LTXVideoDownBlock3D",
+        ),
         decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
         decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
         spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
         decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
         decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
+        downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
         upsample_residual: Tuple[bool, ...] = (False, False, False, False),
         upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
         timestep_conditioning: bool = False,
@@ -949,6 +1105,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         scaling_factor: float = 1.0,
         encoder_causal: bool = True,
         decoder_causal: bool = False,
+        spatial_compression_ratio: int = None,
+        temporal_compression_ratio: int = None,
     ) -> None:
         super().__init__()

@@ -956,8 +1114,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             in_channels=in_channels,
             out_channels=latent_channels,
             block_out_channels=block_out_channels,
+            down_block_types=down_block_types,
             spatio_temporal_scaling=spatio_temporal_scaling,
             layers_per_block=layers_per_block,
+            downsample_type=downsample_type,
             patch_size=patch_size,
             patch_size_t=patch_size_t,
             resnet_norm_eps=resnet_norm_eps,
@@ -984,8 +1144,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.register_buffer("latents_mean", latents_mean, persistent=True)
         self.register_buffer("latents_std", latents_std, persistent=True)

-        self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
-        self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
+        self.spatial_compression_ratio = (
+            patch_size * 2 ** sum(spatio_temporal_scaling)
+            if spatial_compression_ratio is None
+            else spatial_compression_ratio
+        )
+        self.temporal_compression_ratio = (
+            patch_size_t * 2 ** sum(spatio_temporal_scaling)
+            if temporal_compression_ratio is None
+            else temporal_compression_ratio
+        )

         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
         # to perform decoding of a single video latent at a time.
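With the class defaults shown in this diff, the derived fallbacks work out as below; the two new constructor arguments only take effect when passed explicitly.

# Worked example of the fallback arithmetic above, using the defaults from this diff.
patch_size, patch_size_t = 4, 1
spatio_temporal_scaling = (True, True, True, False)  # sum(...) == 3
print(patch_size * 2 ** sum(spatio_temporal_scaling))    # 32  (spatial)
print(patch_size_t * 2 ** sum(spatio_temporal_scaling))  # 8   (temporal)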
@@ -1010,21 +1178,21 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 512
         self.tile_sample_min_width = 512
+        self.tile_sample_min_num_frames = 16

         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
-            module.gradient_checkpointing = value
+        self.tile_sample_stride_num_frames = 8

     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
         tile_sample_min_width: Optional[int] = None,
+        tile_sample_min_num_frames: Optional[int] = None,
         tile_sample_stride_height: Optional[float] = None,
         tile_sample_stride_width: Optional[float] = None,
+        tile_sample_stride_num_frames: Optional[float] = None,
     ) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1046,8 +1214,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.use_tiling = True
         self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

     def disable_tiling(self) -> None:
         r"""
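A sketch of calling the extended signature; the checkpoint id is a placeholder assumption, and the values shown simply restate the new defaults from this diff:

from diffusers import AutoencoderKLLTXVideo

vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae")
vae.enable_tiling(
    tile_sample_min_height=512,
    tile_sample_min_width=512,
    tile_sample_min_num_frames=16,    # new in 0.33.0
    tile_sample_stride_height=448,
    tile_sample_stride_width=448,
    tile_sample_stride_num_frames=8,  # new in 0.33.0
)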
@@ -1073,18 +1243,13 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape

+        if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
+            return self._temporal_tiled_encode(x)
+
         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
             return self.tiled_encode(x)

-        if self.use_framewise_encoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        else:
-            enc = self.encoder(x)
+        enc = self.encoder(x)

         return enc

@@ -1121,19 +1286,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+
+        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+            return self._temporal_tiled_decode(z, temb, return_dict=return_dict)

         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, temb, return_dict=return_dict)

-        if self.use_framewise_decoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        else:
-            dec = self.decoder(z, temb)
+        dec = self.decoder(z, temb)

         if not return_dict:
             return (dec,)
@@ -1189,6 +1350,14 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         )
         return b

+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
     def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.

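blend_t mirrors the existing blend_v/blend_h helpers but cross-fades along the frame axis (dim -3): frame x of the incoming tile is weighted x / blend_extent while the overlapping tail of the previous tile fades out. A toy check of that linear ramp; the tensor shapes are arbitrary:

import torch

blend_extent = 4
a = torch.ones(1, 3, 8, 4, 4)   # previous tile
b = torch.zeros(1, 3, 8, 4, 4)  # current tile
for x in range(blend_extent):
    b[:, :, x] = a[:, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, x] * (x / blend_extent)
print(b[0, 0, :blend_extent, 0, 0])  # tensor([1.0000, 0.7500, 0.5000, 0.2500])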
@@ -1217,17 +1386,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, self.tile_sample_stride_height):
             row = []
             for j in range(0, width, self.tile_sample_stride_width):
-                if self.use_framewise_encoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.encoder(
-                        x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
-                    )
+                time = self.encoder(
+                    x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                )

                 row.append(time)
             rows.append(row)
@@ -1283,17 +1444,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, tile_latent_stride_height):
             row = []
             for j in range(0, width, tile_latent_stride_width):
-                if self.use_framewise_decoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.decoder(
-                        z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
-                    )
+                time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)

                 row.append(time)
             rows.append(row)
@@ -1318,6 +1469,74 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):

         return DecoderOutput(sample=dec)

+    def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        batch_size, num_channels, num_frames, height, width = x.shape
+        latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
+
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, self.tile_sample_stride_num_frames):
+            tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
+            if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
+                tile = self.tiled_encode(tile)
+            else:
+                tile = self.encoder(tile)
+            if i > 0:
+                tile = tile[:, :, 1:, :, :]
+            row.append(tile)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
+            else:
+                result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
+
+        enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
+        return enc
+
+    def _temporal_tiled_decode(
+        self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, tile_latent_stride_num_frames):
+            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
+            if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
+                decoded = self.tiled_decode(tile, temb, return_dict=True).sample
+            else:
+                decoded = self.decoder(tile, temb)
+            if i > 0:
+                decoded = decoded[:, :, :-1, :, :]
+            row.append(decoded)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
+                result_row.append(tile)
+            else:
+                result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
+
+        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
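The tiled loops above lean on the VAE's "1 + multiple of the ratio" frame convention when slicing tiles and trimming overlaps. A quick check of that bookkeeping, assuming the default temporal_compression_ratio of 8:

ratio = 8                            # default temporal_compression_ratio
num_sample_frames = 49               # a typical "8k + 1" clip length
latent_num_frames = (num_sample_frames - 1) // ratio + 1
assert latent_num_frames == 7
assert (latent_num_frames - 1) * ratio + 1 == num_sample_frames  # round-trips exactly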
@@ -1334,5 +1553,5 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         z = posterior.mode()
         dec = self.decode(z, temb)
         if not return_dict:
-            return (dec,)
+            return (dec.sample,)
         return dec
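A sketch of the behavior change in this last hunk: with return_dict=False, forward() now unwraps the DecoderOutput that decode() returns, so the tuple carries the decoded tensor directly. Here `vae` and `video` are placeholders for a loaded AutoencoderKLLTXVideo and a (batch, channels, frames, height, width) tensor:

decoded = vae(video, return_dict=False)[0]
# 0.33.0: `decoded` is the torch.Tensor itself
# 0.32.1: the tuple carried the DecoderOutput wrapper instead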