diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -34,6 +34,103 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 CACHE_T = 2
 
 
+class AvgDown3D(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        factor_t,
+        factor_s=1,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.factor_t = factor_t
+        self.factor_s = factor_s
+        self.factor = self.factor_t * self.factor_s * self.factor_s
+
+        assert in_channels * self.factor % out_channels == 0
+        self.group_size = in_channels * self.factor // out_channels
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
+        pad = (0, 0, 0, 0, pad_t, 0)
+        x = F.pad(x, pad)
+        B, C, T, H, W = x.shape
+        x = x.view(
+            B,
+            C,
+            T // self.factor_t,
+            self.factor_t,
+            H // self.factor_s,
+            self.factor_s,
+            W // self.factor_s,
+            self.factor_s,
+        )
+        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+        x = x.view(
+            B,
+            C * self.factor,
+            T // self.factor_t,
+            H // self.factor_s,
+            W // self.factor_s,
+        )
+        x = x.view(
+            B,
+            self.out_channels,
+            self.group_size,
+            T // self.factor_t,
+            H // self.factor_s,
+            W // self.factor_s,
+        )
+        x = x.mean(dim=2)
+        return x
+
+
+class DupUp3D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        factor_t,
+        factor_s=1,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.factor_t = factor_t
+        self.factor_s = factor_s
+        self.factor = self.factor_t * self.factor_s * self.factor_s
+
+        assert out_channels * self.factor % in_channels == 0
+        self.repeats = out_channels * self.factor // in_channels
+
+    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
+        x = x.repeat_interleave(self.repeats, dim=1)
+        x = x.view(
+            x.size(0),
+            self.out_channels,
+            self.factor_t,
+            self.factor_s,
+            self.factor_s,
+            x.size(2),
+            x.size(3),
+            x.size(4),
+        )
+        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
+        x = x.view(
+            x.size(0),
+            self.out_channels,
+            x.size(2) * self.factor_t,
+            x.size(4) * self.factor_s,
+            x.size(6) * self.factor_s,
+        )
+        if first_chunk:
+            x = x[:, :, self.factor_t - 1 :, :, :]
+        return x
+
+
 class WanCausalConv3d(nn.Conv3d):
     r"""
     A custom 3D causal convolution layer with feature caching support.
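For orientation: `AvgDown3D` folds each `factor_t × factor_s × factor_s` neighborhood into the channel dimension and group-averages the result down to `out_channels`; `DupUp3D` inverts the layout by channel duplication. A minimal shape sketch, assuming `diffusers==0.35.0` is installed (both classes are internal helpers of `autoencoder_kl_wan`, not public API):

```python
# Shape round-trip for the new pooled shortcut helpers (internal API, may change).
import torch
from diffusers.models.autoencoders.autoencoder_kl_wan import AvgDown3D, DupUp3D

down = AvgDown3D(in_channels=96, out_channels=192, factor_t=2, factor_s=2)
up = DupUp3D(in_channels=192, out_channels=96, factor_t=2, factor_s=2)

x = torch.randn(1, 96, 8, 32, 32)  # [B, C, T, H, W]
y = down(x)                        # 2x2x2 neighborhoods fold into channels, then group-mean
print(y.shape)                     # torch.Size([1, 192, 4, 16, 16])
print(up(y).shape)                 # torch.Size([1, 96, 8, 32, 32])
```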
@@ -134,19 +231,25 @@ class WanResample(nn.Module):
             - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
     """
 
-    def __init__(self, dim: int, mode: str) -> None:
+    def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
         super().__init__()
         self.dim = dim
         self.mode = mode
 
+        # default to dim //2
+        if upsample_out_dim is None:
+            upsample_out_dim = dim // 2
+
         # layers
         if mode == "upsample2d":
             self.resample = nn.Sequential(
-                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
             )
         elif mode == "upsample3d":
             self.resample = nn.Sequential(
-                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
             )
             self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
 
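The new `upsample_out_dim` keyword defaults to the previous hard-coded `dim // 2`, so existing checkpoints load unchanged; `WanResidualUpBlock` below passes `out_dim` explicitly. A quick check, assuming `diffusers==0.35.0`:

```python
# The default preserves the old conv width; the keyword overrides it (internal API).
from diffusers.models.autoencoders.autoencoder_kl_wan import WanResample

r_old = WanResample(dim=256, mode="upsample2d")                        # conv: 256 -> 128
r_new = WanResample(dim=256, mode="upsample2d", upsample_out_dim=256)  # conv: 256 -> 256
print(r_old.resample[1].out_channels, r_new.resample[1].out_channels)  # 128 256
```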
@@ -363,6 +466,42 @@ class WanMidBlock(nn.Module):
         return x
 
 
+class WanResidualDownBlock(nn.Module):
+    def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False):
+        super().__init__()
+
+        # Shortcut path with downsample
+        self.avg_shortcut = AvgDown3D(
+            in_dim,
+            out_dim,
+            factor_t=2 if temperal_downsample else 1,
+            factor_s=2 if down_flag else 1,
+        )
+
+        # Main path with residual blocks and downsample
+        resnets = []
+        for _ in range(num_res_blocks):
+            resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
+            in_dim = out_dim
+        self.resnets = nn.ModuleList(resnets)
+
+        # Add the final downsample block
+        if down_flag:
+            mode = "downsample3d" if temperal_downsample else "downsample2d"
+            self.downsampler = WanResample(out_dim, mode=mode)
+        else:
+            self.downsampler = None
+
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        x_copy = x.clone()
+        for resnet in self.resnets:
+            x = resnet(x, feat_cache, feat_idx)
+        if self.downsampler is not None:
+            x = self.downsampler(x, feat_cache, feat_idx)
+
+        return x + self.avg_shortcut(x_copy)
+
+
 class WanEncoder3d(nn.Module):
     r"""
     A 3D encoder module.
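`WanResidualDownBlock` wraps an entire encoder stage so the pooled `AvgDown3D` shortcut can bypass the resnet/downsample path. A spatial-only sketch, assuming `diffusers==0.35.0`; temporal downsampling is left off because it runs through the feature cache that `AutoencoderKLWan` manages internally:

```python
import torch
from diffusers.models.autoencoders.autoencoder_kl_wan import WanResidualDownBlock

block = WanResidualDownBlock(96, 192, dropout=0.0, num_res_blocks=2,
                             temperal_downsample=False, down_flag=True)
x = torch.randn(1, 96, 5, 64, 64)  # [B, C, T, H, W]
print(block(x).shape)              # torch.Size([1, 192, 5, 32, 32]): main path + pooled shortcut
```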
@@ -380,6 +519,7 @@ class WanEncoder3d(nn.Module):
 
     def __init__(
         self,
+        in_channels: int = 3,
         dim=128,
         z_dim=4,
         dim_mult=[1, 2, 4, 4],
@@ -388,6 +528,7 @@
         temperal_downsample=[True, True, False],
         dropout=0.0,
         non_linearity: str = "silu",
+        is_residual: bool = False,  # wan 2.2 vae use a residual downblock
     ):
         super().__init__()
         self.dim = dim
@@ -403,23 +544,35 @@
         scale = 1.0
 
         # init block
-        self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
+        self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
 
         # downsample blocks
         self.down_blocks = nn.ModuleList([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
             # residual (+attention) blocks
-            for _ in range(num_res_blocks):
-                self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
-                if scale in attn_scales:
-                    self.down_blocks.append(WanAttentionBlock(out_dim))
-                in_dim = out_dim
-
-            # downsample block
-            if i != len(dim_mult) - 1:
-                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
-                self.down_blocks.append(WanResample(out_dim, mode=mode))
-                scale /= 2.0
+            if is_residual:
+                self.down_blocks.append(
+                    WanResidualDownBlock(
+                        in_dim,
+                        out_dim,
+                        dropout,
+                        num_res_blocks,
+                        temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
+                        down_flag=i != len(dim_mult) - 1,
+                    )
+                )
+            else:
+                for _ in range(num_res_blocks):
+                    self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
+                    if scale in attn_scales:
+                        self.down_blocks.append(WanAttentionBlock(out_dim))
+                    in_dim = out_dim
+
+                # downsample block
+                if i != len(dim_mult) - 1:
+                    mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+                    self.down_blocks.append(WanResample(out_dim, mode=mode))
+                    scale /= 2.0
 
         # middle blocks
         self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
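With `is_residual=True`, each stage collapses into a single `WanResidualDownBlock` instead of the flat Wan 2.1 list of resnet/attention/resample modules. A sketch with the default config, assuming `diffusers==0.35.0`:

```python
from diffusers.models.autoencoders.autoencoder_kl_wan import WanEncoder3d

enc_21 = WanEncoder3d(is_residual=False)  # Wan 2.1 layout: flat module list
enc_22 = WanEncoder3d(is_residual=True)   # Wan 2.2 layout: one residual block per stage
print(len(enc_21.down_blocks), len(enc_22.down_blocks))  # 11 4
```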
@@ -470,6 +623,94 @@ class WanEncoder3d(nn.Module):
         return x
 
 
+class WanResidualUpBlock(nn.Module):
+    """
+    A block that handles upsampling for the WanVAE decoder.
+
+    Args:
+        in_dim (int): Input dimension
+        out_dim (int): Output dimension
+        num_res_blocks (int): Number of residual blocks
+        dropout (float): Dropout rate
+        temperal_upsample (bool): Whether to upsample on temporal dimension
+        up_flag (bool): Whether to upsample or not
+        non_linearity (str): Type of non-linearity to use
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_res_blocks: int,
+        dropout: float = 0.0,
+        temperal_upsample: bool = False,
+        up_flag: bool = False,
+        non_linearity: str = "silu",
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+
+        if up_flag:
+            self.avg_shortcut = DupUp3D(
+                in_dim,
+                out_dim,
+                factor_t=2 if temperal_upsample else 1,
+                factor_s=2,
+            )
+        else:
+            self.avg_shortcut = None
+
+        # create residual blocks
+        resnets = []
+        current_dim = in_dim
+        for _ in range(num_res_blocks + 1):
+            resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
+            current_dim = out_dim
+
+        self.resnets = nn.ModuleList(resnets)
+
+        # Add upsampling layer if needed
+        if up_flag:
+            upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
+            self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
+        else:
+            self.upsampler = None
+
+        self.gradient_checkpointing = False
+
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+        """
+        Forward pass through the upsampling block.
+
+        Args:
+            x (torch.Tensor): Input tensor
+            feat_cache (list, optional): Feature cache for causal convolutions
+            feat_idx (list, optional): Feature index for cache management
+
+        Returns:
+            torch.Tensor: Output tensor
+        """
+        x_copy = x.clone()
+
+        for resnet in self.resnets:
+            if feat_cache is not None:
+                x = resnet(x, feat_cache, feat_idx)
+            else:
+                x = resnet(x)
+
+        if self.upsampler is not None:
+            if feat_cache is not None:
+                x = self.upsampler(x, feat_cache, feat_idx)
+            else:
+                x = self.upsampler(x)
+
+        if self.avg_shortcut is not None:
+            x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
+
+        return x
+
+
 class WanUpBlock(nn.Module):
     """
     A block that handles upsampling for the WanVAE decoder.
@@ -513,7 +754,7 @@ class WanUpBlock(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
         """
         Forward pass through the upsampling block.
 
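`first_chunk` exists because `DupUp3D` upsamples time by duplicating frames: on the first chunk of a causal decode, the leading `factor_t - 1` duplicates are sliced off so the first frame is not repeated. `WanUpBlock.forward` only accepts the flag to keep a uniform call signature with `WanResidualUpBlock`. A small demo of the trim, assuming `diffusers==0.35.0`:

```python
import torch
from diffusers.models.autoencoders.autoencoder_kl_wan import DupUp3D

up = DupUp3D(in_channels=8, out_channels=8, factor_t=2, factor_s=1)
z = torch.randn(1, 8, 4, 16, 16)
print(up(z).shape)                    # torch.Size([1, 8, 8, 16, 16])
print(up(z, first_chunk=True).shape)  # torch.Size([1, 8, 7, 16, 16]): leading duplicate dropped
```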
@@ -564,6 +805,8 @@ class WanDecoder3d(nn.Module):
         temperal_upsample=[False, True, True],
         dropout=0.0,
         non_linearity: str = "silu",
+        out_channels: int = 3,
+        is_residual: bool = False,
     ):
         super().__init__()
         self.dim = dim
@@ -577,7 +820,6 @@ class WanDecoder3d(nn.Module):
 
         # dimensions
         dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
-        scale = 1.0 / 2 ** (len(dim_mult) - 2)
 
         # init block
         self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
@@ -589,36 +831,47 @@ class WanDecoder3d(nn.Module):
         self.up_blocks = nn.ModuleList([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
             # residual (+attention) blocks
-            if i > 0:
+            if i > 0 and not is_residual:
+                # wan vae 2.1
                 in_dim = in_dim // 2
 
-            # Determine if we need upsampling
+            # determine if we need upsampling
+            up_flag = i != len(dim_mult) - 1
+            # determine upsampling mode, if not upsampling, set to None
             upsample_mode = None
-            if i != len(dim_mult) - 1:
-                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
-
+            if up_flag and temperal_upsample[i]:
+                upsample_mode = "upsample3d"
+            elif up_flag:
+                upsample_mode = "upsample2d"
             # Create and add the upsampling block
-            up_block = WanUpBlock(
-                in_dim=in_dim,
-                out_dim=out_dim,
-                num_res_blocks=num_res_blocks,
-                dropout=dropout,
-                upsample_mode=upsample_mode,
-                non_linearity=non_linearity,
-            )
+            if is_residual:
+                up_block = WanResidualUpBlock(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    num_res_blocks=num_res_blocks,
+                    dropout=dropout,
+                    temperal_upsample=temperal_upsample[i] if up_flag else False,
+                    up_flag=up_flag,
+                    non_linearity=non_linearity,
+                )
+            else:
+                up_block = WanUpBlock(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    num_res_blocks=num_res_blocks,
+                    dropout=dropout,
+                    upsample_mode=upsample_mode,
+                    non_linearity=non_linearity,
+                )
             self.up_blocks.append(up_block)
 
-            # Update scale for next iteration
-            if upsample_mode is not None:
-                scale *= 2.0
-
         # output blocks
         self.norm_out = WanRMS_norm(out_dim, images=False)
-        self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
+        self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
         ## conv1
         if feat_cache is not None:
             idx = feat_idx[0]
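The decoder mirrors the encoder's switch: `is_residual=True` selects `WanResidualUpBlock` for every stage (and skips the Wan 2.1 channel halving), while the old `WanUpBlock` path is built unchanged otherwise. Sketch, assuming `diffusers==0.35.0`:

```python
from diffusers.models.autoencoders.autoencoder_kl_wan import WanDecoder3d

dec = WanDecoder3d(is_residual=True)
print({type(b).__name__ for b in dec.up_blocks})  # {'WanResidualUpBlock'}
```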
@@ -637,7 +890,7 @@ class WanDecoder3d(nn.Module):
 
         ## upsamples
         for up_block in self.up_blocks:
-            x = up_block(x, feat_cache, feat_idx)
+            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
 
         ## head
         x = self.norm_out(x)
@@ -656,6 +909,49 @@ class WanDecoder3d(nn.Module):
         return x
 
 
+def patchify(x, patch_size):
+    if patch_size == 1:
+        return x
+
+    if x.dim() != 5:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+    # x shape: [batch_size, channels, frames, height, width]
+    batch_size, channels, frames, height, width = x.shape
+
+    # Ensure height and width are divisible by patch_size
+    if height % patch_size != 0 or width % patch_size != 0:
+        raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
+
+    # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
+    x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
+
+    # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
+    x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
+    x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
+
+    return x
+
+
+def unpatchify(x, patch_size):
+    if patch_size == 1:
+        return x
+
+    if x.dim() != 5:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+    # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
+    batch_size, c_patches, frames, height, width = x.shape
+    channels = c_patches // (patch_size * patch_size)
+
+    # Reshape to [b, c, patch_size, patch_size, f, h, w]
+    x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
+
+    # Rearrange to [b, c, f, h * patch_size, w * patch_size]
+    x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
+    x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
+
+    return x
+
+
 class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
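`patchify`/`unpatchify` are exact inverses built from `view`/`permute` alone; they trade spatial resolution for channels when the config sets `patch_size` (the Wan 2.2 VAE). Round-trip check, assuming `diffusers==0.35.0`:

```python
import torch
from diffusers.models.autoencoders.autoencoder_kl_wan import patchify, unpatchify

x = torch.randn(1, 3, 5, 32, 32)         # [B, C, T, H, W]
p = patchify(x, patch_size=2)            # space-to-channel: C * 4, H / 2, W / 2
print(p.shape)                           # torch.Size([1, 12, 5, 16, 16])
print(torch.equal(unpatchify(p, 2), x))  # True
```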
@@ -671,6 +967,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def __init__(
         self,
         base_dim: int = 96,
+        decoder_base_dim: Optional[int] = None,
         z_dim: int = 16,
         dim_mult: Tuple[int] = [1, 2, 4, 4],
         num_res_blocks: int = 2,
@@ -713,6 +1010,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             2.8251,
             1.9160,
         ],
+        is_residual: bool = False,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        patch_size: Optional[int] = None,
+        scale_factor_temporal: Optional[int] = 4,
+        scale_factor_spatial: Optional[int] = 8,
     ) -> None:
         super().__init__()
 
@@ -720,37 +1023,135 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
720
1023
  self.temperal_downsample = temperal_downsample
721
1024
  self.temperal_upsample = temperal_downsample[::-1]
722
1025
 
1026
+ if decoder_base_dim is None:
1027
+ decoder_base_dim = base_dim
1028
+
723
1029
  self.encoder = WanEncoder3d(
724
- base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
1030
+ in_channels=in_channels,
1031
+ dim=base_dim,
1032
+ z_dim=z_dim * 2,
1033
+ dim_mult=dim_mult,
1034
+ num_res_blocks=num_res_blocks,
1035
+ attn_scales=attn_scales,
1036
+ temperal_downsample=temperal_downsample,
1037
+ dropout=dropout,
1038
+ is_residual=is_residual,
725
1039
  )
726
1040
  self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
727
1041
  self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
728
1042
 
729
1043
  self.decoder = WanDecoder3d(
730
- base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
1044
+ dim=decoder_base_dim,
1045
+ z_dim=z_dim,
1046
+ dim_mult=dim_mult,
1047
+ num_res_blocks=num_res_blocks,
1048
+ attn_scales=attn_scales,
1049
+ temperal_upsample=self.temperal_upsample,
1050
+ dropout=dropout,
1051
+ out_channels=out_channels,
1052
+ is_residual=is_residual,
731
1053
  )
732
1054
 
1055
+ self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
1056
+
1057
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
1058
+ # to perform decoding of a single video latent at a time.
1059
+ self.use_slicing = False
1060
+
1061
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
1062
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
1063
+ # intermediate tiles together, the memory requirement can be lowered.
1064
+ self.use_tiling = False
1065
+
1066
+ # The minimal tile height and width for spatial tiling to be used
1067
+ self.tile_sample_min_height = 256
1068
+ self.tile_sample_min_width = 256
1069
+
1070
+ # The minimal distance between two spatial tiles
1071
+ self.tile_sample_stride_height = 192
1072
+ self.tile_sample_stride_width = 192
1073
+
1074
+ # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
1075
+ self._cached_conv_counts = {
1076
+ "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
1077
+ if self.decoder is not None
1078
+ else 0,
1079
+ "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
1080
+ if self.encoder is not None
1081
+ else 0,
1082
+ }
1083
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[int] = None,
+        tile_sample_stride_width: Optional[int] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. Keeping it smaller than the tile height creates an
+                overlap that is blended to avoid tiling artifacts across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. Keeping it smaller than the tile width creates an
+                overlap that is blended to avoid tiling artifacts across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
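A minimal usage sketch for tiled decoding; the checkpoint name and subfolder are illustrative assumptions, not taken from this diff:

import torch
from diffusers import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed repo id, for illustration
    subfolder="vae",
    torch_dtype=torch.float32,
)
# Defaults (tile 256, stride 192) are kept when arguments are omitted.
vae.enable_tiling()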
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If tiling was previously enabled via `enable_tiling`, this method goes back to
+        computing decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If slicing was previously enabled via `enable_slicing`, this method goes back to
+        computing decoding in one step.
+        """
+        self.use_slicing = False
+
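Slicing composes with tiling: a sketch of decoding a batched latent one sample at a time (shapes are hypothetical; z_dim defaults to 16):

import torch

vae.enable_slicing()
latents = torch.randn(4, 16, 5, 32, 32)  # (B, z_dim, T, H/8, W/8)
frames = vae.decode(latents).sample      # the 4 latents are decoded one by one
vae.disable_slicing()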
     def clear_cache(self):
-        def _count_conv3d(model):
-            count = 0
-            for m in model.modules():
-                if isinstance(m, WanCausalConv3d):
-                    count += 1
-            return count
-
-        self._conv_num = _count_conv3d(self.decoder)
+        # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
+        self._conv_num = self._cached_conv_counts["decoder"]
         self._conv_idx = [0]
         self._feat_map = [None] * self._conv_num
         # cache encode
-        self._enc_conv_num = _count_conv3d(self.encoder)
+        self._enc_conv_num = self._cached_conv_counts["encoder"]
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num
 
-    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+    def _encode(self, x: torch.Tensor):
+        _, _, num_frame, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
         self.clear_cache()
-        ## cache
-        t = x.shape[2]
-        iter_ = 1 + (t - 1) // 4
+        if self.config.patch_size is not None:
+            x = patchify(x, patch_size=self.config.patch_size)
+        iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
@@ -764,8 +1165,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 out = torch.cat([out, out_], 2)
 
         enc = self.quant_conv(out)
-        mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        enc = torch.cat([mu, logvar], dim=1)
         self.clear_cache()
         return enc
 
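Two things happen in this hunk: the deleted `mu, logvar` split was a no-op (the two halves were immediately re-concatenated, so removing it changes nothing), and inputs are now optionally patchified before encoding. `patchify` is a helper defined elsewhere in this file; a hedged sketch of the usual space-to-channel rearrange such a helper performs (the real helper's channel ordering may differ):

import torch

def patchify_sketch(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Fold each patch_size x patch_size spatial block into channels:
    # (B, C, T, H, W) -> (B, C * p * p, T, H // p, W // p)
    b, c, t, h, w = x.shape
    p = patch_size
    x = x.reshape(b, c, t, h // p, p, w // p, p)
    x = x.permute(0, 1, 4, 6, 2, 3, 5)
    return x.reshape(b, c * p * p, t, h // p, w // p)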
@@ -785,26 +1184,42 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self._encode(x)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
 
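A sketch of the encode path; input sizes are hypothetical, and T = 17 fits the 1 + 4k frame chunking used by `_encode`:

import torch

video = torch.randn(1, 3, 17, 256, 256)    # (B, C, T, H, W)
posterior = vae.encode(video).latent_dist  # DiagonalGaussianDistribution
latents = posterior.sample()               # or posterior.mode() for the mean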
-    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        self.clear_cache()
+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        _, _, num_frame, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
 
-        iter_ = z.shape[2]
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
+
+        self.clear_cache()
         x = self.post_quant_conv(z)
-        for i in range(iter_):
+        for i in range(num_frame):
             self._conv_idx = [0]
             if i == 0:
-                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                out = self.decoder(
+                    x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True
+                )
             else:
                 out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                 out = torch.cat([out, out_], 2)
 
+        if self.config.patch_size is not None:
+            out = unpatchify(out, patch_size=self.config.patch_size)
+
         out = torch.clamp(out, min=-1.0, max=1.0)
+
         self.clear_cache()
         if not return_dict:
             return (out,)
@@ -826,12 +1241,161 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        decoded = self._decode(z).sample
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
         if not return_dict:
             return (decoded,)
-
         return DecoderOutput(sample=decoded)
 
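The matching decode call; note above that the first temporal chunk is now decoded with `first_chunk=True`, which lets the decoder distinguish the cache-priming first step from subsequent chunks. Shapes below are hypothetical:

frames = vae.decode(latents).sample  # (B, 3, T, H, W), clamped to [-1, 1]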
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
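`blend_v` and `blend_h` implement a linear crossfade over the overlap. A worked example with blend_extent = 4:

# Row y of the current tile b is replaced by a mix of both tiles:
#   b[y] = a[-4 + y] * (1 - y / 4) + b[y] * (y / 4)
# so the weight on b ramps 0.00, 0.25, 0.50, 0.75 across the four
# overlapping rows, fading smoothly out of the previous tile a.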
+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of videos using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+        self.clear_cache()
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
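An illustrative tile grid for a 480x832 input under the defaults (tile 256, stride 192); the values follow directly from the loops above:

list(range(0, 480, 192))  # row starts:    [0, 192, 384]
list(range(0, 832, 192))  # column starts: [0, 192, 384, 576, 768]
# Neighbouring tiles overlap by 256 - 192 = 64 pixels (8 latent rows/cols),
# and the concatenated result is cropped to latent_height x latent_width.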
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of video latents using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+        self.clear_cache()
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
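An end-to-end sketch of the tiled path; latent sizes are hypothetical, and 60 x 104 exceeds the 32 x 32 latent tile minimum, so `decode` dispatches to `tiled_decode`:

import torch

vae.enable_tiling()
latents = torch.randn(1, 16, 5, 60, 104)  # roughly a 480x832 video latent
frames = vae.decode(latents).sample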
     def forward(
         self,
         sample: torch.Tensor,