diffsynth-engine 0.6.1.dev22__tar.gz → 0.6.1.dev23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (209)
  1. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/PKG-INFO +1 -1
  2. diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  3. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/configs/pipeline.py +33 -5
  4. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/attention.py +59 -20
  5. diffsynth_engine-0.6.1.dev23/diffsynth_engine/models/basic/video_sparse_attention.py +235 -0
  6. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_controlnet.py +7 -19
  7. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_dit.py +22 -36
  8. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
  9. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  10. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/qwen_image/qwen_image_dit.py +13 -15
  11. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  12. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/wan_dit.py +62 -22
  13. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/flux_image.py +11 -10
  14. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/qwen_image.py +3 -10
  15. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/wan_s2v.py +3 -8
  16. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/wan_video.py +11 -13
  17. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/constants.py +13 -12
  18. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/flag.py +6 -0
  19. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/parallel.py +51 -6
  20. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine.egg-info/PKG-INFO +1 -1
  21. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine.egg-info/SOURCES.txt +13 -11
  22. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/.gitattributes +0 -0
  23. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/.gitignore +0 -0
  24. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/.pre-commit-config.yaml +0 -0
  25. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/LICENSE +0 -0
  26. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/MANIFEST.in +0 -0
  27. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/README.md +0 -0
  28. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/assets/dingtalk.png +0 -0
  29. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/assets/showcase.jpeg +0 -0
  30. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/assets/tongyi.svg +0 -0
  31. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/__init__.py +0 -0
  32. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/__init__.py +0 -0
  33. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  34. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  35. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  36. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  37. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  38. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  39. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  40. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  41. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  42. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  43. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  44. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  45. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  46. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  47. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  48. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  49. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  50. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  51. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  52. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  53. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  54. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  55. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  56. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  57. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  58. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  59. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/components/vae.json +0 -0
  60. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  61. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  62. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  63. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json +0 -0
  64. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_vision_config.json +0 -0
  65. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae.json +0 -0
  66. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae_keymap.json +0 -0
  67. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  68. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  69. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  70. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  71. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  72. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  73. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.1-flf2v-14b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.1_flf2v_14b.json +0 -0
  74. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.1-i2v-14b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.1_i2v_14b.json +0 -0
  75. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-1.3b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_1.3b.json +0 -0
  76. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-14b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_14b.json +0 -0
  77. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.2_i2v_a14b.json +0 -0
  78. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.2-s2v-14b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.2_s2v_14b.json +0 -0
  79. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.2_t2v_a14b.json +0 -0
  80. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.2_ti2v_5b.json +0 -0
  81. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/vae/wan2.1_vae.json +0 -0
  82. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/vae/wan2.2_vae.json +0 -0
  83. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/vae/wan-vae-keymap.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/vae/wan_vae_keymap.json +0 -0
  84. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  85. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  86. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  87. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  88. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  89. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  90. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  91. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  92. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json +0 -0
  93. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/added_tokens.json +0 -0
  94. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/merges.txt +0 -0
  95. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/special_tokens_map.json +0 -0
  96. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer.json +0 -0
  97. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer_config.json +0 -0
  98. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/vocab.json +0 -0
  99. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  100. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  101. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  102. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  103. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  104. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  105. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  106. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  107. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  108. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  109. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  110. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  111. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/configs/__init__.py +0 -0
  112. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/configs/controlnet.py +0 -0
  113. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/kernels/__init__.py +0 -0
  114. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/__init__.py +0 -0
  115. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/base.py +0 -0
  116. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/__init__.py +0 -0
  117. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/lora.py +0 -0
  118. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  119. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/timestep.py +0 -0
  120. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
  121. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  122. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/__init__.py +0 -0
  123. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_redux.py +0 -0
  124. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  125. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  126. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/__init__.py +0 -0
  127. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +0 -0
  128. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/hunyuan3d_dit.py +0 -0
  129. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/hunyuan3d_vae.py +0 -0
  130. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/moe.py +0 -0
  131. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/surface_extractor.py +0 -0
  132. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/volume_decoder.py +0 -0
  133. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/qwen_image/__init__.py +0 -0
  134. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/qwen_image/qwen2_5_vl.py +0 -0
  135. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/qwen_image/qwen_image_vae.py +0 -0
  136. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd/__init__.py +0 -0
  137. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd/sd_controlnet.py +0 -0
  138. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  139. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  140. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  141. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd3/__init__.py +0 -0
  142. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  143. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  144. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  145. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  146. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +0 -0
  147. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  148. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
  149. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  150. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
  151. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/text_encoder/clip.py +0 -0
  152. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
  153. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/text_encoder/t5.py +0 -0
  154. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/vae/__init__.py +0 -0
  155. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/vae/vae.py +0 -0
  156. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/__init__.py +0 -0
  157. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/wan_audio_encoder.py +0 -0
  158. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  159. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/wan_s2v_dit.py +0 -0
  160. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  161. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  162. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/__init__.py +0 -0
  163. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/base.py +0 -0
  164. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/hunyuan3d_shape.py +0 -0
  165. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/sd_image.py +0 -0
  166. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/sdxl_image.py +0 -0
  167. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/utils.py +0 -0
  168. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/processor/__init__.py +0 -0
  169. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/processor/canny_processor.py +0 -0
  170. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/processor/depth_processor.py +0 -0
  171. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/__init__.py +0 -0
  172. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/base.py +0 -0
  173. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/clip.py +0 -0
  174. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/qwen2.py +0 -0
  175. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/qwen2_vl_image_processor.py +0 -0
  176. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/qwen2_vl_processor.py +0 -0
  177. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/t5.py +0 -0
  178. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/wan.py +0 -0
  179. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tools/__init__.py +0 -0
  180. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tools/flux_inpainting_tool.py +0 -0
  181. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tools/flux_outpainting_tool.py +0 -0
  182. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tools/flux_reference_tool.py +0 -0
  183. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tools/flux_replace_tool.py +0 -0
  184. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/__init__.py +0 -0
  185. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/cache.py +0 -0
  186. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/download.py +0 -0
  187. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/env.py +0 -0
  188. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/fp8_linear.py +0 -0
  189. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/gguf.py +0 -0
  190. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/image.py +0 -0
  191. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/loader.py +0 -0
  192. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/lock.py +0 -0
  193. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/logging.py +0 -0
  194. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/memory/__init__.py +0 -0
  195. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/memory/linear_regression.py +0 -0
  196. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/memory/memory_predcit_model.py +0 -0
  197. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/offload.py +0 -0
  198. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/onnx.py +0 -0
  199. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/platform.py +0 -0
  200. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/prompt.py +0 -0
  201. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/video.py +0 -0
  202. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  203. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine.egg-info/requires.txt +0 -0
  204. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine.egg-info/top_level.txt +0 -0
  205. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/docs/tutorial.md +0 -0
  206. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/docs/tutorial_zh.md +0 -0
  207. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/pyproject.toml +0 -0
  208. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/setup.cfg +0 -0
  209. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev23}/setup.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.6.1.dev22
+Version: 0.6.1.dev23
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent
@@ -0,0 +1,41 @@
+{
+  "diffusers": {
+    "global_rename_dict": {
+      "patch_embedding": "patch_embedding",
+      "condition_embedder.text_embedder.linear_1": "text_embedding.0",
+      "condition_embedder.text_embedder.linear_2": "text_embedding.2",
+      "condition_embedder.time_embedder.linear_1": "time_embedding.0",
+      "condition_embedder.time_embedder.linear_2": "time_embedding.2",
+      "condition_embedder.time_proj": "time_projection.1",
+      "condition_embedder.image_embedder.norm1": "img_emb.proj.0",
+      "condition_embedder.image_embedder.ff.net.0.proj": "img_emb.proj.1",
+      "condition_embedder.image_embedder.ff.net.2": "img_emb.proj.3",
+      "condition_embedder.image_embedder.norm2": "img_emb.proj.4",
+      "condition_embedder.image_embedder.pos_embed": "img_emb.emb_pos",
+      "proj_out": "head.head",
+      "scale_shift_table": "head.modulation"
+    },
+    "rename_dict": {
+      "attn1.to_q": "self_attn.q",
+      "attn1.to_k": "self_attn.k",
+      "attn1.to_v": "self_attn.v",
+      "attn1.to_out.0": "self_attn.o",
+      "attn1.norm_q": "self_attn.norm_q",
+      "attn1.norm_k": "self_attn.norm_k",
+      "to_gate_compress": "self_attn.gate_compress",
+      "attn2.to_q": "cross_attn.q",
+      "attn2.to_k": "cross_attn.k",
+      "attn2.to_v": "cross_attn.v",
+      "attn2.to_out.0": "cross_attn.o",
+      "attn2.norm_q": "cross_attn.norm_q",
+      "attn2.norm_k": "cross_attn.norm_k",
+      "attn2.add_k_proj": "cross_attn.k_img",
+      "attn2.add_v_proj": "cross_attn.v_img",
+      "attn2.norm_added_k": "cross_attn.norm_k_img",
+      "norm2": "norm3",
+      "ffn.net.0.proj": "ffn.0",
+      "ffn.net.2": "ffn.2",
+      "scale_shift_table": "modulation"
+    }
+  }
+}
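Note on the new keymap: it pairs diffusers-style Wan DiT parameter names (left) with diffsynth-engine names (right), split into a global map and a per-block map. A minimal sketch of how such a map could be applied to a checkpoint, assuming the usual diffusers "blocks.<i>." layout; the convert_wan_state_dict helper below is illustrative, not the package's actual loader:

import json

def convert_wan_state_dict(state_dict: dict, keymap_path: str) -> dict:
    # Illustrative sketch, not the package's actual converter.
    with open(keymap_path) as f:
        keymap = json.load(f)["diffusers"]

    def rename(key: str, mapping: dict) -> str:
        # Direct parameter match first (e.g. "scale_shift_table"),
        # then module-prefix match (e.g. "attn1.to_q" + ".weight").
        if key in mapping:
            return mapping[key]
        prefix, _, param = key.rpartition(".")
        if prefix in mapping:
            return f"{mapping[prefix]}.{param}"
        return key

    out = {}
    for key, tensor in state_dict.items():
        if key.startswith("blocks."):  # assumed per-block layout: "blocks.<i>.<suffix>"
            _, idx, rest = key.split(".", 2)
            out[f"blocks.{idx}.{rename(rest, keymap['rename_dict'])}"] = tensor
        else:
            out[rename(key, keymap["global_rename_dict"])] = tensor
    return out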
@@ -5,6 +5,7 @@ from dataclasses import dataclass, field
 from typing import List, Dict, Tuple, Optional
 
 from diffsynth_engine.configs.controlnet import ControlType
+from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
 
 
 @dataclass
@@ -30,16 +31,43 @@ class AttnImpl(Enum):
     SDPA = "sdpa"  # Scaled Dot Product Attention
     SAGE = "sage"  # Sage Attention
     SPARGE = "sparge"  # Sparge Attention
+    VSA = "vsa"  # Video Sparse Attention
+
+
+@dataclass
+class SpargeAttentionParams:
+    smooth_k: bool = True
+    cdfthreshd: float = 0.6
+    simthreshd1: float = 0.98
+    pvthreshd: float = 50.0
+
+
+@dataclass
+class VideoSparseAttentionParams:
+    sparsity: float = 0.9
 
 
 @dataclass
 class AttentionConfig:
     dit_attn_impl: AttnImpl = AttnImpl.AUTO
-    # Sparge Attention
-    sparge_smooth_k: bool = True
-    sparge_cdfthreshd: float = 0.6
-    sparge_simthreshd1: float = 0.98
-    sparge_pvthreshd: float = 50.0
+    attn_params: Optional[SpargeAttentionParams | VideoSparseAttentionParams] = None
+
+    def get_attn_kwargs(self, latents: torch.Tensor, device: str) -> Dict:
+        attn_kwargs = {"attn_impl": self.dit_attn_impl.value}
+        if isinstance(self.attn_params, SpargeAttentionParams):
+            assert self.dit_attn_impl == AttnImpl.SPARGE
+            attn_kwargs.update(
+                {
+                    "smooth_k": self.attn_params.smooth_k,
+                    "simthreshd1": self.attn_params.simthreshd1,
+                    "cdfthreshd": self.attn_params.cdfthreshd,
+                    "pvthreshd": self.attn_params.pvthreshd,
+                }
+            )
+        elif isinstance(self.attn_params, VideoSparseAttentionParams):
+            assert self.dit_attn_impl == AttnImpl.VSA
+            attn_kwargs.update(get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.attn_params.sparsity, device=device))
+        return attn_kwargs
 
 
 @dataclass
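The rewritten AttentionConfig replaces the flat sparge_* fields with typed parameter objects and derives per-call kwargs from the latent shape. A sketch of how it could be exercised (the latent shape and device are illustrative; the module imports the optional vsa kernel package at load time, so that must be installed):

import torch
from diffsynth_engine.configs.pipeline import AttentionConfig, AttnImpl, VideoSparseAttentionParams

# latents: [batch, channels, frames, height, width]; shape[2:] feeds get_vsa_kwargs
latents = torch.randn(1, 16, 21, 60, 104)
config = AttentionConfig(
    dit_attn_impl=AttnImpl.VSA,
    attn_params=VideoSparseAttentionParams(sparsity=0.9),
)
attn_kwargs = config.get_attn_kwargs(latents, device="cuda:0")
# attn_kwargs now carries "attn_impl": "vsa" plus the precomputed VSA tiling
# indices, ready to be threaded into the DiT's attention calls.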
@@ -12,6 +12,7 @@ from diffsynth_engine.utils.flag import (
     SDPA_AVAILABLE,
     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
+    VIDEO_SPARSE_ATTN_AVAILABLE,
 )
 from diffsynth_engine.utils.platform import DTYPE_FP8
 
@@ -20,12 +21,6 @@ FA3_MAX_HEADDIM = 256
 logger = logging.get_logger(__name__)
 
 
-def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
-    padding_size = (alignment - x.shape[dim] % alignment) % alignment
-    padded_x = F.pad(x, (0, padding_size), "constant", 0)
-    return padded_x[..., : x.shape[dim]]
-
-
 if FLASH_ATTN_3_AVAILABLE:
     from flash_attn_interface import flash_attn_func as flash_attn3
 if FLASH_ATTN_2_AVAILABLE:
@@ -33,6 +28,11 @@ if FLASH_ATTN_2_AVAILABLE:
 if XFORMERS_AVAILABLE:
     from xformers.ops import memory_efficient_attention
 
+    def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
+        padding_size = (alignment - x.shape[dim] % alignment) % alignment
+        padded_x = F.pad(x, (0, padding_size), "constant", 0)
+        return padded_x[..., : x.shape[dim]]
+
     def xformers_attn(q, k, v, attn_mask=None, scale=None):
         if attn_mask is not None:
             if attn_mask.ndim == 2:
@@ -94,6 +94,13 @@ if SPARGE_ATTN_AVAILABLE:
         return out.transpose(1, 2)
 
 
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    from diffsynth_engine.models.basic.video_sparse_attention import (
+        video_sparse_attn,
+        distributed_video_sparse_attn,
+    )
+
+
 def eager_attn(q, k, v, attn_mask=None, scale=None):
     q = q.transpose(1, 2)
     k = k.transpose(1, 2)
@@ -109,9 +116,10 @@ def eager_attn(q, k, v, attn_mask=None, scale=None):
 
 
 def attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = "auto",
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
@@ -133,6 +141,7 @@ def attention(
         "sdpa",
         "sage",
         "sparge",
+        "vsa",
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
@@ -189,10 +198,24 @@ def attention(
             v,
             attn_mask=attn_mask,
             scale=scale,
-            smooth_k=kwargs.get("sparge_smooth_k", True),
-            simthreshd1=kwargs.get("sparge_simthreshd1", 0.6),
-            cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
-            pvthreshd=kwargs.get("sparge_pvthreshd", 50),
+            smooth_k=kwargs.get("smooth_k", True),
+            simthreshd1=kwargs.get("simthreshd1", 0.6),
+            cdfthreshd=kwargs.get("cdfthreshd", 0.98),
+            pvthreshd=kwargs.get("pvthreshd", 50),
+        )
+    if attn_impl == "vsa":
+        return video_sparse_attn(
+            q,
+            k,
+            v,
+            g,
+            sparsity=kwargs.get("sparsity"),
+            num_tiles=kwargs.get("num_tiles"),
+            total_seq_length=kwargs.get("total_seq_length"),
+            tile_partition_indices=kwargs.get("tile_partition_indices"),
+            reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+            variable_block_sizes=kwargs.get("variable_block_sizes"),
+            non_pad_index=kwargs.get("non_pad_index"),
         )
     raise ValueError(f"Invalid attention implementation: {attn_impl}")
 
@@ -242,9 +265,10 @@ class Attention(nn.Module):
 
 
 def long_context_attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = None,
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
@@ -267,6 +291,7 @@ def long_context_attention(
         "sdpa",
         "sage",
         "sparge",
+        "vsa",
     ]
     assert attn_mask is None, "long context attention does not support attention mask"
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
@@ -307,11 +332,25 @@ def long_context_attention(
     if attn_impl == "sparge":
         attn_processor = SparseAttentionMeansim()
         # default args from spas_sage2_attn_meansim_cuda
-        attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
-        attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
-        attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
-        attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
+        attn_processor.smooth_k = torch.tensor(kwargs.get("smooth_k", True))
+        attn_processor.simthreshd1 = torch.tensor(kwargs.get("simthreshd1", 0.6))
+        attn_processor.cdfthreshd = torch.tensor(kwargs.get("cdfthreshd", 0.98))
+        attn_processor.pvthreshd = torch.tensor(kwargs.get("pvthreshd", 50))
         return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
             q, k, v, softmax_scale=scale
         )
+    if attn_impl == "vsa":
+        return distributed_video_sparse_attn(
+            q,
+            k,
+            v,
+            g,
+            sparsity=kwargs.get("sparsity"),
+            num_tiles=kwargs.get("num_tiles"),
+            total_seq_length=kwargs.get("total_seq_length"),
+            tile_partition_indices=kwargs.get("tile_partition_indices"),
+            reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+            variable_block_sizes=kwargs.get("variable_block_sizes"),
+            non_pad_index=kwargs.get("non_pad_index"),
+        )
     raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
@@ -0,0 +1,235 @@
+import torch
+import math
+import functools
+
+from vsa import video_sparse_attn as vsa_core
+from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
+
+VSA_TILE_SIZE = (4, 4, 4)
+
+
+@functools.lru_cache(maxsize=10)
+def get_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    T, H, W = dit_seq_shape
+    ts, hs, ws = tile_size
+    indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
+    ls = []
+    for t in range(math.ceil(T / ts)):
+        for h in range(math.ceil(H / hs)):
+            for w in range(math.ceil(W / ws)):
+                ls.append(
+                    indices[
+                        t * ts : min(t * ts + ts, T), h * hs : min(h * hs + hs, H), w * ws : min(w * ws + ws, W)
+                    ].flatten()
+                )
+    index = torch.cat(ls, dim=0)
+    return index
+
+
+@functools.lru_cache(maxsize=10)
+def get_reverse_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))
+
+
+@functools.lru_cache(maxsize=10)
+def construct_variable_block_sizes(
+    dit_seq_shape: tuple[int, int, int],
+    num_tiles: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    """
+    Compute the number of valid (non-padded) tokens inside every
+    (ts_t x ts_h x ts_w) tile after padding -- flattened in the order
+    (t-tile, h-tile, w-tile) that `rearrange` uses.
+
+    Returns
+    -------
+    torch.LongTensor  # shape: [∏ full_window_size]
+    """
+    # unpack
+    t, h, w = dit_seq_shape
+    ts_t, ts_h, ts_w = VSA_TILE_SIZE
+    n_t, n_h, n_w = num_tiles
+
+    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
+        """Vector with the size of each tile along one dimension."""
+        sizes = torch.full((n_tiles,), tile, dtype=torch.int, device=device)
+        # size of last (possibly partial) tile
+        remainder = dim_len - (n_tiles - 1) * tile
+        sizes[-1] = remainder if remainder > 0 else tile
+        return sizes
+
+    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]
+    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]
+    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]
+
+    # broadcast-multiply to get voxels per tile, then flatten
+    block_sizes = (
+        t_sizes[:, None, None]  # [n_t, 1, 1]
+        * h_sizes[None, :, None]  # [1, n_h, 1]
+        * w_sizes[None, None, :]  # [1, 1, n_w]
+    ).reshape(-1)  # [n_t * n_h * n_w]
+
+    return block_sizes
+
+
+@functools.lru_cache(maxsize=10)
+def get_non_pad_index(
+    variable_block_sizes: torch.LongTensor,
+    max_block_size: int,
+):
+    n_win = variable_block_sizes.shape[0]
+    device = variable_block_sizes.device
+    starts_pad = torch.arange(n_win, device=device) * max_block_size
+    index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]
+    index_mask = torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
+    return index_pad[index_mask]
+
+
+def get_vsa_kwargs(
+    latent_shape: tuple[int, int, int],
+    patch_size: tuple[int, int, int],
+    sparsity: float,
+    device: torch.device,
+):
+    dit_seq_shape = (
+        latent_shape[0] // patch_size[0],
+        latent_shape[1] // patch_size[1],
+        latent_shape[2] // patch_size[2],
+    )
+
+    num_tiles = (
+        math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
+        math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
+        math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]),
+    )
+    total_seq_length = math.prod(dit_seq_shape)
+
+    tile_partition_indices = get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)
+    non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE))
+
+    return {
+        "sparsity": sparsity,
+        "num_tiles": num_tiles,
+        "total_seq_length": total_seq_length,
+        "tile_partition_indices": tile_partition_indices,
+        "reverse_tile_partition_indices": reverse_tile_partition_indices,
+        "variable_block_sizes": variable_block_sizes,
+        "non_pad_index": non_pad_index,
+    }
+
+
+def tile(
+    x: torch.Tensor,
+    num_tiles: tuple[int, int, int],
+    tile_partition_indices: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+) -> torch.Tensor:
+    t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
+    h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
+    w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
+
+    x_padded = torch.zeros(
+        (x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]),
+        device=x.device,
+        dtype=x.dtype,
+    )
+    x_padded[:, non_pad_index] = x[:, tile_partition_indices]
+    return x_padded
+
+
+def untile(
+    x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor, non_pad_index: torch.LongTensor
+) -> torch.Tensor:
+    x = x[:, non_pad_index][:, reverse_tile_partition_indices]
+    return x
+
+
+def video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+):
+    q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
+    k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
+    v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
+    g = tile(g, num_tiles, tile_partition_indices, non_pad_index)
+
+    q = q.transpose(1, 2).contiguous()
+    k = k.transpose(1, 2).contiguous()
+    v = v.transpose(1, 2).contiguous()
+    g = g.transpose(1, 2).contiguous()
+
+    topk = math.ceil((1 - sparsity) * (total_seq_length / math.prod(VSA_TILE_SIZE)))
+    out = vsa_core(
+        q,
+        k,
+        v,
+        variable_block_sizes=variable_block_sizes,
+        topk=topk,
+        block_size=VSA_TILE_SIZE,
+        compress_attn_weight=g,
+    ).transpose(1, 2)
+    out = untile(out, reverse_tile_partition_indices, non_pad_index)
+    return out
+
+
+def distributed_video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
+):
+    from yunchang.comm.all_to_all import SeqAllToAll4D
+
+    assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
+    sp_ulysses_group = get_sp_ulysses_group()
+
+    q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
+    k = SeqAllToAll4D.apply(sp_ulysses_group, k, scatter_idx, gather_idx)
+    v = SeqAllToAll4D.apply(sp_ulysses_group, v, scatter_idx, gather_idx)
+    g = SeqAllToAll4D.apply(sp_ulysses_group, g, scatter_idx, gather_idx)
+
+    out = video_sparse_attn(
+        q,
+        k,
+        v,
+        g,
+        sparsity,
+        num_tiles,
+        total_seq_length,
+        tile_partition_indices,
+        reverse_tile_partition_indices,
+        variable_block_sizes,
+        non_pad_index,
+    )
+
+    out = SeqAllToAll4D.apply(sp_ulysses_group, out, gather_idx, scatter_idx)
+    return out
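Two properties of the new module are worth noting: topk keeps ceil((1 - sparsity) * total_seq_length / 64) of the 4x4x4 key tiles per query, and tile/untile are pure index permutations plus zero padding, so untile inverts tile exactly. A round-trip sketch (shapes are illustrative; importing the module requires the vsa package even though its kernels are not exercised here):

import torch
from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs, tile, untile

vsa_kwargs = get_vsa_kwargs(latent_shape=(21, 60, 104), patch_size=(1, 2, 2), sparsity=0.9, device="cpu")
x = torch.randn(1, vsa_kwargs["total_seq_length"], 12, 128)  # [batch, seq, heads, head_dim]
x_tiled = tile(x, vsa_kwargs["num_tiles"], vsa_kwargs["tile_partition_indices"], vsa_kwargs["non_pad_index"])
x_back = untile(x_tiled, vsa_kwargs["reverse_tile_partition_indices"], vsa_kwargs["non_pad_index"])
assert torch.equal(x_back, x)  # padding stripped, original token order restored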
@@ -86,7 +86,6 @@ class FluxControlNet(PreTrainedModel):
     def __init__(
         self,
         condition_channels: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -103,10 +102,7 @@ class FluxControlNet(PreTrainedModel):
         self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
         self.controlnet_x_embedder = nn.Linear(condition_channels, 3072)
         self.blocks = nn.ModuleList(
-            [
-                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(6)
-            ]
+            [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(6)]
         )
         # controlnet projection
         self.blocks_proj = nn.ModuleList(
@@ -128,6 +124,7 @@ class FluxControlNet(PreTrainedModel):
         image_ids: torch.Tensor,
         text_ids: torch.Tensor,
         guidance: torch.Tensor,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition)
         condition = (
@@ -141,7 +138,9 @@ class FluxControlNet(PreTrainedModel):
         # double block
         double_block_outputs = []
         for i, block in enumerate(self.blocks):
-            hidden_states, prompt_emb = block(hidden_states, prompt_emb, condition, image_rotary_emb)
+            hidden_states, prompt_emb = block(
+                hidden_states, prompt_emb, condition, image_rotary_emb, attn_kwargs=attn_kwargs
+            )
             double_block_outputs.append(self.blocks_proj[i](hidden_states))
 
         # apply control scale
@@ -149,24 +148,13 @@ class FluxControlNet(PreTrainedModel):
         return double_block_outputs, None
 
     @classmethod
-    def from_state_dict(
-        cls,
-        state_dict: Dict[str, torch.Tensor],
-        device: str,
-        dtype: torch.dtype,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
         if "controlnet_x_embedder.weight" in state_dict:
             condition_channels = state_dict["controlnet_x_embedder.weight"].shape[1]
         else:
             condition_channels = 64
 
-        model = cls(
-            condition_channels=condition_channels,
-            attn_kwargs=attn_kwargs,
-            device="meta",
-            dtype=dtype,
-        )
+        model = cls(condition_channels=condition_channels, device="meta", dtype=dtype)
         model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
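The ControlNet change follows the same pattern as the rest of this release: attention kwargs are no longer frozen into the module at construction (the constructor and from_state_dict both drop attn_kwargs) but travel with each forward call, so a pipeline can switch attention implementations or recompute VSA tiling for a new resolution without rebuilding weights. A toy illustration of the pattern, not the actual FluxControlNet wiring:

import torch
import torch.nn as nn
from typing import Any, Dict, Optional

class ToyBlock(nn.Module):
    def forward(self, x: torch.Tensor, attn_kwargs: Optional[Dict[str, Any]] = None) -> torch.Tensor:
        attn_kwargs = attn_kwargs or {"attn_impl": "sdpa"}
        # a real block would dispatch here: attention(q, k, v, **attn_kwargs)
        return x

block = ToyBlock()
y = block(torch.randn(1, 8), attn_kwargs={"attn_impl": "fa2"})  # per-call choice
z = block(torch.randn(1, 8))  # falls back to the default implementation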