diffsynth-engine 0.6.1.dev22__tar.gz → 0.6.1.dev24__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (210)
  1. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/PKG-INFO +1 -1
  2. diffsynth_engine-0.6.1.dev24/diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  3. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/configs/pipeline.py +35 -12
  4. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/basic/attention.py +59 -20
  5. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/basic/transformer_helper.py +36 -2
  6. diffsynth_engine-0.6.1.dev24/diffsynth_engine/models/basic/video_sparse_attention.py +235 -0
  7. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/flux/flux_controlnet.py +7 -19
  8. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/flux/flux_dit.py +22 -36
  9. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
  10. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  11. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/qwen_image/qwen_image_dit.py +26 -32
  12. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  13. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/wan/wan_dit.py +62 -22
  14. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/pipelines/flux_image.py +11 -10
  15. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/pipelines/qwen_image.py +16 -15
  16. diffsynth_engine-0.6.1.dev24/diffsynth_engine/pipelines/utils.py +71 -0
  17. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/pipelines/wan_s2v.py +3 -8
  18. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/pipelines/wan_video.py +11 -13
  19. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tokenizers/base.py +6 -0
  20. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tokenizers/qwen2.py +12 -4
  21. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/constants.py +13 -12
  22. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/flag.py +6 -0
  23. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/parallel.py +51 -6
  24. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine.egg-info/PKG-INFO +1 -1
  25. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine.egg-info/SOURCES.txt +13 -11
  26. diffsynth_engine-0.6.1.dev22/diffsynth_engine/pipelines/utils.py +0 -19
  27. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/.gitattributes +0 -0
  28. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/.gitignore +0 -0
  29. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/.pre-commit-config.yaml +0 -0
  30. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/LICENSE +0 -0
  31. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/MANIFEST.in +0 -0
  32. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/README.md +0 -0
  33. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/assets/dingtalk.png +0 -0
  34. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/assets/showcase.jpeg +0 -0
  35. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/assets/tongyi.svg +0 -0
  36. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/__init__.py +0 -0
  37. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/__init__.py +0 -0
  38. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  39. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  40. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  41. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  42. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  43. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  44. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  45. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  46. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  47. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  48. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  49. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  50. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  51. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  52. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  53. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  54. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  55. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  56. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  57. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  58. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  59. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  60. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  61. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  62. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  63. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  64. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/components/vae.json +0 -0
  65. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  66. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  67. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  68. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json +0 -0
  69. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_vision_config.json +0 -0
  70. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae.json +0 -0
  71. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae_keymap.json +0 -0
  72. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  73. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  74. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  75. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  76. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  77. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  78. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.1-flf2v-14b.json → /diffsynth_engine-0.6.1.dev24/diffsynth_engine/conf/models/wan/dit/wan2.1_flf2v_14b.json +0 -0
  79. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.1-i2v-14b.json → /diffsynth_engine-0.6.1.dev24/diffsynth_engine/conf/models/wan/dit/wan2.1_i2v_14b.json +0 -0
  80. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-1.3b.json → /diffsynth_engine-0.6.1.dev24/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_1.3b.json +0 -0
  81. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-14b.json → /diffsynth_engine-0.6.1.dev24/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_14b.json +0 -0
  82. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json → /diffsynth_engine-0.6.1.dev24/diffsynth_engine/conf/models/wan/dit/wan2.2_i2v_a14b.json +0 -0
  83. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.2-s2v-14b.json → /diffsynth_engine-0.6.1.dev24/diffsynth_engine/conf/models/wan/dit/wan2.2_s2v_14b.json +0 -0
  84. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json → /diffsynth_engine-0.6.1.dev24/diffsynth_engine/conf/models/wan/dit/wan2.2_t2v_a14b.json +0 -0
  85. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json → /diffsynth_engine-0.6.1.dev24/diffsynth_engine/conf/models/wan/dit/wan2.2_ti2v_5b.json +0 -0
  86. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json → /diffsynth_engine-0.6.1.dev24/diffsynth_engine/conf/models/wan/vae/wan2.1_vae.json +0 -0
  87. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json → /diffsynth_engine-0.6.1.dev24/diffsynth_engine/conf/models/wan/vae/wan2.2_vae.json +0 -0
  88. /diffsynth_engine-0.6.1.dev22/diffsynth_engine/conf/models/wan/vae/wan-vae-keymap.json → /diffsynth_engine-0.6.1.dev24/diffsynth_engine/conf/models/wan/vae/wan_vae_keymap.json +0 -0
  89. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  90. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  91. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  92. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  93. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  94. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  95. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  96. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  97. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json +0 -0
  98. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/added_tokens.json +0 -0
  99. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/merges.txt +0 -0
  100. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/special_tokens_map.json +0 -0
  101. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer.json +0 -0
  102. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer_config.json +0 -0
  103. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/vocab.json +0 -0
  104. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  105. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  106. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  107. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  108. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  109. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  110. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  111. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  112. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  113. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  114. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  115. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  116. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/configs/__init__.py +0 -0
  117. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/configs/controlnet.py +0 -0
  118. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/kernels/__init__.py +0 -0
  119. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/__init__.py +0 -0
  120. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/base.py +0 -0
  121. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/basic/__init__.py +0 -0
  122. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/basic/lora.py +0 -0
  123. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  124. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/basic/timestep.py +0 -0
  125. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  126. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/flux/__init__.py +0 -0
  127. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/flux/flux_redux.py +0 -0
  128. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  129. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  130. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/hunyuan3d/__init__.py +0 -0
  131. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +0 -0
  132. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/hunyuan3d/hunyuan3d_dit.py +0 -0
  133. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/hunyuan3d/hunyuan3d_vae.py +0 -0
  134. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/hunyuan3d/moe.py +0 -0
  135. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/hunyuan3d/surface_extractor.py +0 -0
  136. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/hunyuan3d/volume_decoder.py +0 -0
  137. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/qwen_image/__init__.py +0 -0
  138. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/qwen_image/qwen2_5_vl.py +0 -0
  139. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/qwen_image/qwen_image_vae.py +0 -0
  140. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sd/__init__.py +0 -0
  141. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sd/sd_controlnet.py +0 -0
  142. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  143. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  144. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  145. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sd3/__init__.py +0 -0
  146. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  147. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  148. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  149. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  150. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +0 -0
  151. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  152. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
  153. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  154. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
  155. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/text_encoder/clip.py +0 -0
  156. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
  157. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/text_encoder/t5.py +0 -0
  158. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/vae/__init__.py +0 -0
  159. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/vae/vae.py +0 -0
  160. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/wan/__init__.py +0 -0
  161. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/wan/wan_audio_encoder.py +0 -0
  162. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  163. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/wan/wan_s2v_dit.py +0 -0
  164. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  165. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  166. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/pipelines/__init__.py +0 -0
  167. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/pipelines/base.py +0 -0
  168. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/pipelines/hunyuan3d_shape.py +0 -0
  169. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/pipelines/sd_image.py +0 -0
  170. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/pipelines/sdxl_image.py +0 -0
  171. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/processor/__init__.py +0 -0
  172. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/processor/canny_processor.py +0 -0
  173. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/processor/depth_processor.py +0 -0
  174. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tokenizers/__init__.py +0 -0
  175. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tokenizers/clip.py +0 -0
  176. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tokenizers/qwen2_vl_image_processor.py +0 -0
  177. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tokenizers/qwen2_vl_processor.py +0 -0
  178. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tokenizers/t5.py +0 -0
  179. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tokenizers/wan.py +0 -0
  180. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tools/__init__.py +0 -0
  181. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tools/flux_inpainting_tool.py +0 -0
  182. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tools/flux_outpainting_tool.py +0 -0
  183. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tools/flux_reference_tool.py +0 -0
  184. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/tools/flux_replace_tool.py +0 -0
  185. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/__init__.py +0 -0
  186. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/cache.py +0 -0
  187. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/download.py +0 -0
  188. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/env.py +0 -0
  189. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/fp8_linear.py +0 -0
  190. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/gguf.py +0 -0
  191. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/image.py +0 -0
  192. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/loader.py +0 -0
  193. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/lock.py +0 -0
  194. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/logging.py +0 -0
  195. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/memory/__init__.py +0 -0
  196. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/memory/linear_regression.py +0 -0
  197. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/memory/memory_predcit_model.py +0 -0
  198. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/offload.py +0 -0
  199. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/onnx.py +0 -0
  200. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/platform.py +0 -0
  201. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/prompt.py +0 -0
  202. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine/utils/video.py +0 -0
  203. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  204. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine.egg-info/requires.txt +0 -0
  205. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/diffsynth_engine.egg-info/top_level.txt +0 -0
  206. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/docs/tutorial.md +0 -0
  207. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/docs/tutorial_zh.md +0 -0
  208. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/pyproject.toml +0 -0
  209. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/setup.cfg +0 -0
  210. {diffsynth_engine-0.6.1.dev22 → diffsynth_engine-0.6.1.dev24}/setup.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.6.1.dev22
+Version: 0.6.1.dev24
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent
diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json (new file)
@@ -0,0 +1,41 @@
+{
+    "diffusers": {
+        "global_rename_dict": {
+            "patch_embedding": "patch_embedding",
+            "condition_embedder.text_embedder.linear_1": "text_embedding.0",
+            "condition_embedder.text_embedder.linear_2": "text_embedding.2",
+            "condition_embedder.time_embedder.linear_1": "time_embedding.0",
+            "condition_embedder.time_embedder.linear_2": "time_embedding.2",
+            "condition_embedder.time_proj": "time_projection.1",
+            "condition_embedder.image_embedder.norm1": "img_emb.proj.0",
+            "condition_embedder.image_embedder.ff.net.0.proj": "img_emb.proj.1",
+            "condition_embedder.image_embedder.ff.net.2": "img_emb.proj.3",
+            "condition_embedder.image_embedder.norm2": "img_emb.proj.4",
+            "condition_embedder.image_embedder.pos_embed": "img_emb.emb_pos",
+            "proj_out": "head.head",
+            "scale_shift_table": "head.modulation"
+        },
+        "rename_dict": {
+            "attn1.to_q": "self_attn.q",
+            "attn1.to_k": "self_attn.k",
+            "attn1.to_v": "self_attn.v",
+            "attn1.to_out.0": "self_attn.o",
+            "attn1.norm_q": "self_attn.norm_q",
+            "attn1.norm_k": "self_attn.norm_k",
+            "to_gate_compress": "self_attn.gate_compress",
+            "attn2.to_q": "cross_attn.q",
+            "attn2.to_k": "cross_attn.k",
+            "attn2.to_v": "cross_attn.v",
+            "attn2.to_out.0": "cross_attn.o",
+            "attn2.norm_q": "cross_attn.norm_q",
+            "attn2.norm_k": "cross_attn.norm_k",
+            "attn2.add_k_proj": "cross_attn.k_img",
+            "attn2.add_v_proj": "cross_attn.v_img",
+            "attn2.norm_added_k": "cross_attn.norm_k_img",
+            "norm2": "norm3",
+            "ffn.net.0.proj": "ffn.0",
+            "ffn.net.2": "ffn.2",
+            "scale_shift_table": "modulation"
+        }
+    }
+}
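
Note: this keymap translates Diffusers-style Wan DiT parameter names (left) into diffsynth_engine's native module names (right): "global_rename_dict" covers top-level modules, "rename_dict" covers per-transformer-block sub-modules. A minimal sketch of how such a keymap could be applied, assuming Diffusers block keys of the form "blocks.<i>.<sub>.weight"; the convert_state_dict helper below is illustrative, not the package's actual loader:

    import re

    def convert_state_dict(state_dict: dict, keymap: dict) -> dict:
        # hypothetical sketch: rename Diffusers keys to diffsynth_engine keys
        rules = keymap["diffusers"]
        out = {}
        for key, tensor in state_dict.items():
            new_key = key
            for src, dst in rules["global_rename_dict"].items():
                if new_key == src or new_key.startswith(src + "."):
                    new_key = dst + new_key[len(src):]
            for src, dst in rules["rename_dict"].items():
                # rewrite only the sub-module part after a "blocks.<i>." prefix
                new_key = re.sub(rf"(blocks\.\d+\.){re.escape(src)}(?=\.|$)", rf"\g<1>{dst}", new_key)
            out[new_key] = tensor
        return out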
diffsynth_engine/configs/pipeline.py
@@ -5,6 +5,7 @@ from dataclasses import dataclass, field
 from typing import List, Dict, Tuple, Optional
 
 from diffsynth_engine.configs.controlnet import ControlType
+from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
 
 
 @dataclass
@@ -30,16 +31,43 @@ class AttnImpl(Enum):
     SDPA = "sdpa"  # Scaled Dot Product Attention
     SAGE = "sage"  # Sage Attention
     SPARGE = "sparge"  # Sparge Attention
+    VSA = "vsa"  # Video Sparse Attention
+
+
+@dataclass
+class SpargeAttentionParams:
+    smooth_k: bool = True
+    cdfthreshd: float = 0.6
+    simthreshd1: float = 0.98
+    pvthreshd: float = 50.0
+
+
+@dataclass
+class VideoSparseAttentionParams:
+    sparsity: float = 0.9
 
 
 @dataclass
 class AttentionConfig:
     dit_attn_impl: AttnImpl = AttnImpl.AUTO
-    # Sparge Attention
-    sparge_smooth_k: bool = True
-    sparge_cdfthreshd: float = 0.6
-    sparge_simthreshd1: float = 0.98
-    sparge_pvthreshd: float = 50.0
+    attn_params: Optional[SpargeAttentionParams | VideoSparseAttentionParams] = None
+
+    def get_attn_kwargs(self, latents: torch.Tensor, device: str) -> Dict:
+        attn_kwargs = {"attn_impl": self.dit_attn_impl.value}
+        if isinstance(self.attn_params, SpargeAttentionParams):
+            assert self.dit_attn_impl == AttnImpl.SPARGE
+            attn_kwargs.update(
+                {
+                    "smooth_k": self.attn_params.smooth_k,
+                    "simthreshd1": self.attn_params.simthreshd1,
+                    "cdfthreshd": self.attn_params.cdfthreshd,
+                    "pvthreshd": self.attn_params.pvthreshd,
+                }
+            )
+        elif isinstance(self.attn_params, VideoSparseAttentionParams):
+            assert self.dit_attn_impl == AttnImpl.VSA
+            attn_kwargs.update(get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.attn_params.sparsity, device=device))
+        return attn_kwargs
 
 
 @dataclass
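
Note: the flat sparge_* fields are replaced by typed parameter objects, and get_attn_kwargs converts a config plus the current latent tensor into the kwargs dict the attention layers consume; for VSA it derives the (T, H, W) token grid from latents.shape[2:] with patch size (1, 2, 2). A usage sketch, assuming a [batch, channels, frames, height, width] video latent and that the optional vsa kernel package is installed (the pipeline call site is not part of this hunk):

    import torch

    config = AttentionConfig(
        dit_attn_impl=AttnImpl.VSA,
        attn_params=VideoSparseAttentionParams(sparsity=0.9),
    )
    latents = torch.randn(1, 16, 21, 60, 104)  # [B, C, T, H, W]
    attn_kwargs = config.get_attn_kwargs(latents, device="cuda:0")
    # attn_kwargs now carries attn_impl="vsa" plus the precomputed VSA tile
    # indices, block sizes, and padding index for latents of this shape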
@@ -234,16 +262,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
     encoder_dtype: torch.dtype = torch.bfloat16
     vae_dtype: torch.dtype = torch.float32
 
+    load_encoder: bool = True
+
     # override OptimizationConfig
     fbcache_relative_l1_threshold = 0.009
 
-    # override BaseConfig
-    vae_tiled: bool = True
-    vae_tile_size: Tuple[int, int] = (34, 34)
-    vae_tile_stride: Tuple[int, int] = (18, 16)
-
-    load_encoder: bool = True
-
     @classmethod
     def basic_config(
         cls,
diffsynth_engine/models/basic/attention.py
@@ -12,6 +12,7 @@ from diffsynth_engine.utils.flag import (
     SDPA_AVAILABLE,
     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
+    VIDEO_SPARSE_ATTN_AVAILABLE,
 )
 from diffsynth_engine.utils.platform import DTYPE_FP8
 
@@ -20,12 +21,6 @@ FA3_MAX_HEADDIM = 256
 logger = logging.get_logger(__name__)
 
 
-def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
-    padding_size = (alignment - x.shape[dim] % alignment) % alignment
-    padded_x = F.pad(x, (0, padding_size), "constant", 0)
-    return padded_x[..., : x.shape[dim]]
-
-
 if FLASH_ATTN_3_AVAILABLE:
     from flash_attn_interface import flash_attn_func as flash_attn3
 if FLASH_ATTN_2_AVAILABLE:
@@ -33,6 +28,11 @@ if FLASH_ATTN_2_AVAILABLE:
 if XFORMERS_AVAILABLE:
     from xformers.ops import memory_efficient_attention
 
+    def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
+        padding_size = (alignment - x.shape[dim] % alignment) % alignment
+        padded_x = F.pad(x, (0, padding_size), "constant", 0)
+        return padded_x[..., : x.shape[dim]]
+
     def xformers_attn(q, k, v, attn_mask=None, scale=None):
         if attn_mask is not None:
             if attn_mask.ndim == 2:
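
Note: memory_align was only used by the xFormers path, so it moves inside the XFORMERS_AVAILABLE guard. It pads the last dimension up to a multiple of alignment and then slices back to the original extent, so the caller gets a tensor of the original shape whose underlying storage is padded to the alignment boundary. A quick self-contained illustration (the function body is copied from the hunk above; shapes are arbitrary):

    import torch
    import torch.nn.functional as F

    def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
        padding_size = (alignment - x.shape[dim] % alignment) % alignment
        padded_x = F.pad(x, (0, padding_size), "constant", 0)
        return padded_x[..., : x.shape[dim]]

    x = torch.randn(2, 30)  # last dim 30 is not a multiple of 8
    y = memory_align(x)
    print(y.shape)  # torch.Size([2, 30]): the visible shape is unchanged
    print(y.untyped_storage().nbytes() // y.element_size())  # 64: each row padded to 32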
@@ -94,6 +94,13 @@ if SPARGE_ATTN_AVAILABLE:
         return out.transpose(1, 2)
 
 
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    from diffsynth_engine.models.basic.video_sparse_attention import (
+        video_sparse_attn,
+        distributed_video_sparse_attn,
+    )
+
+
 def eager_attn(q, k, v, attn_mask=None, scale=None):
     q = q.transpose(1, 2)
     k = k.transpose(1, 2)
@@ -109,9 +116,10 @@
 
 
 def attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = "auto",
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
@@ -133,6 +141,7 @@
         "sdpa",
         "sage",
         "sparge",
+        "vsa",
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
@@ -189,10 +198,24 @@
             v,
             attn_mask=attn_mask,
             scale=scale,
-            smooth_k=kwargs.get("sparge_smooth_k", True),
-            simthreshd1=kwargs.get("sparge_simthreshd1", 0.6),
-            cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
-            pvthreshd=kwargs.get("sparge_pvthreshd", 50),
+            smooth_k=kwargs.get("smooth_k", True),
+            simthreshd1=kwargs.get("simthreshd1", 0.6),
+            cdfthreshd=kwargs.get("cdfthreshd", 0.98),
+            pvthreshd=kwargs.get("pvthreshd", 50),
+        )
+    if attn_impl == "vsa":
+        return video_sparse_attn(
+            q,
+            k,
+            v,
+            g,
+            sparsity=kwargs.get("sparsity"),
+            num_tiles=kwargs.get("num_tiles"),
+            total_seq_length=kwargs.get("total_seq_length"),
+            tile_partition_indices=kwargs.get("tile_partition_indices"),
+            reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+            variable_block_sizes=kwargs.get("variable_block_sizes"),
+            non_pad_index=kwargs.get("non_pad_index"),
         )
     raise ValueError(f"Invalid attention implementation: {attn_impl}")
 
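Note: with "vsa" accepted as an implementation, attention now threads the gate tensor g plus the precomputed tiling metadata through to video_sparse_attn, and the Sparge kwargs drop their sparge_ prefix to match the new SpargeAttentionParams field names. A call could look like the sketch below, assuming the vsa kernel package is installed and using the [batch, seq_len, num_heads, head_dim] layout the tiling helpers index on (the DiT call site itself is not shown in this diff):

    import torch

    T, H, W, heads, dim = 8, 16, 16, 12, 128
    vsa_kwargs = get_vsa_kwargs((T, H, W), (1, 1, 1), sparsity=0.9, device="cuda:0")
    q = torch.randn(1, T * H * W, heads, dim, device="cuda:0", dtype=torch.bfloat16)
    k, v, g = torch.randn_like(q), torch.randn_like(q), torch.randn_like(q)
    out = attention(q, k, v, g, attn_impl="vsa", **vsa_kwargs)
    # out has the same [1, T*H*W, heads, dim] shape as q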
@@ -242,9 +265,10 @@ class Attention(nn.Module):
 
 
 def long_context_attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = None,
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
@@ -267,6 +291,7 @@
         "sdpa",
         "sage",
         "sparge",
+        "vsa",
     ]
     assert attn_mask is None, "long context attention does not support attention mask"
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
@@ -307,11 +332,25 @@
     if attn_impl == "sparge":
         attn_processor = SparseAttentionMeansim()
         # default args from spas_sage2_attn_meansim_cuda
-        attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
-        attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
-        attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
-        attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
+        attn_processor.smooth_k = torch.tensor(kwargs.get("smooth_k", True))
+        attn_processor.simthreshd1 = torch.tensor(kwargs.get("simthreshd1", 0.6))
+        attn_processor.cdfthreshd = torch.tensor(kwargs.get("cdfthreshd", 0.98))
+        attn_processor.pvthreshd = torch.tensor(kwargs.get("pvthreshd", 50))
         return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
             q, k, v, softmax_scale=scale
         )
+    if attn_impl == "vsa":
+        return distributed_video_sparse_attn(
+            q,
+            k,
+            v,
+            g,
+            sparsity=kwargs.get("sparsity"),
+            num_tiles=kwargs.get("num_tiles"),
+            total_seq_length=kwargs.get("total_seq_length"),
+            tile_partition_indices=kwargs.get("tile_partition_indices"),
+            reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+            variable_block_sizes=kwargs.get("variable_block_sizes"),
+            non_pad_index=kwargs.get("non_pad_index"),
+        )
     raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
diffsynth_engine/models/basic/transformer_helper.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import math
 
 
@@ -91,8 +92,8 @@ class NewGELUActivation(nn.Module):
     the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
     """
 
-    def forward(self, input: "torch.Tensor") -> "torch.Tensor":
-        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
 
 
 class ApproximateGELU(nn.Module):
@@ -115,3 +116,36 @@ class ApproximateGELU(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         return x * torch.sigmoid(1.702 * x)
+
+
+class GELU(nn.Module):
+    r"""
+    GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(
+        self,
+        dim_in: int,
+        dim_out: int,
+        approximate: str = "none",
+        bias: bool = True,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias, device=device, dtype=dtype)
+        self.approximate = approximate
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        return F.gelu(gate, approximate=self.approximate)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        x = self.gelu(x)
+        return x
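
Note: unlike NewGELUActivation (a bare activation) and ApproximateGELU (sigmoid-scaled), the new GELU module bundles a linear projection with F.gelu, in the style of Diffusers' feed-forward input layer. A small usage sketch with arbitrary dimensions:

    import torch

    ff_in = GELU(dim_in=3072, dim_out=12288, approximate="tanh", device="cpu", dtype=torch.float32)
    x = torch.randn(1, 77, 3072)
    y = ff_in(x)    # linear projection to 12288 channels, then tanh-approximated GELU
    print(y.shape)  # torch.Size([1, 77, 12288])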
diffsynth_engine/models/basic/video_sparse_attention.py (new file)
@@ -0,0 +1,235 @@
+import torch
+import math
+import functools
+
+from vsa import video_sparse_attn as vsa_core
+from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
+
+VSA_TILE_SIZE = (4, 4, 4)
+
+
+@functools.lru_cache(maxsize=10)
+def get_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    T, H, W = dit_seq_shape
+    ts, hs, ws = tile_size
+    indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
+    ls = []
+    for t in range(math.ceil(T / ts)):
+        for h in range(math.ceil(H / hs)):
+            for w in range(math.ceil(W / ws)):
+                ls.append(
+                    indices[
+                        t * ts : min(t * ts + ts, T), h * hs : min(h * hs + hs, H), w * ws : min(w * ws + ws, W)
+                    ].flatten()
+                )
+    index = torch.cat(ls, dim=0)
+    return index
+
+
+@functools.lru_cache(maxsize=10)
+def get_reverse_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))
+
+
+@functools.lru_cache(maxsize=10)
+def construct_variable_block_sizes(
+    dit_seq_shape: tuple[int, int, int],
+    num_tiles: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    """
+    Compute the number of valid (non-padded) tokens inside every
+    (ts_t x ts_h x ts_w) tile after padding -- flattened in the order
+    (t-tile, h-tile, w-tile) that `rearrange` uses.
+
+    Returns
+    -------
+    torch.LongTensor  # shape: [∏ full_window_size]
+    """
+    # unpack
+    t, h, w = dit_seq_shape
+    ts_t, ts_h, ts_w = VSA_TILE_SIZE
+    n_t, n_h, n_w = num_tiles
+
+    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
+        """Vector with the size of each tile along one dimension."""
+        sizes = torch.full((n_tiles,), tile, dtype=torch.int, device=device)
+        # size of last (possibly partial) tile
+        remainder = dim_len - (n_tiles - 1) * tile
+        sizes[-1] = remainder if remainder > 0 else tile
+        return sizes
+
+    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]
+    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]
+    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]
+
+    # broadcast-multiply to get voxels per tile, then flatten
+    block_sizes = (
+        t_sizes[:, None, None]  # [n_t, 1, 1]
+        * h_sizes[None, :, None]  # [1, n_h, 1]
+        * w_sizes[None, None, :]  # [1, 1, n_w]
+    ).reshape(-1)  # [n_t * n_h * n_w]
+
+    return block_sizes
+
+
+@functools.lru_cache(maxsize=10)
+def get_non_pad_index(
+    variable_block_sizes: torch.LongTensor,
+    max_block_size: int,
+):
+    n_win = variable_block_sizes.shape[0]
+    device = variable_block_sizes.device
+    starts_pad = torch.arange(n_win, device=device) * max_block_size
+    index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]
+    index_mask = torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
+    return index_pad[index_mask]
+
+
+def get_vsa_kwargs(
+    latent_shape: tuple[int, int, int],
+    patch_size: tuple[int, int, int],
+    sparsity: float,
+    device: torch.device,
+):
+    dit_seq_shape = (
+        latent_shape[0] // patch_size[0],
+        latent_shape[1] // patch_size[1],
+        latent_shape[2] // patch_size[2],
+    )
+
+    num_tiles = (
+        math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
+        math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
+        math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]),
+    )
+    total_seq_length = math.prod(dit_seq_shape)
+
+    tile_partition_indices = get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)
+    non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE))
+
+    return {
+        "sparsity": sparsity,
+        "num_tiles": num_tiles,
+        "total_seq_length": total_seq_length,
+        "tile_partition_indices": tile_partition_indices,
+        "reverse_tile_partition_indices": reverse_tile_partition_indices,
+        "variable_block_sizes": variable_block_sizes,
+        "non_pad_index": non_pad_index,
+    }
+
+
+def tile(
+    x: torch.Tensor,
+    num_tiles: tuple[int, int, int],
+    tile_partition_indices: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+) -> torch.Tensor:
+    t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
+    h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
+    w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
+
+    x_padded = torch.zeros(
+        (x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]),
+        device=x.device,
+        dtype=x.dtype,
+    )
+    x_padded[:, non_pad_index] = x[:, tile_partition_indices]
+    return x_padded
+
+
+def untile(
+    x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor, non_pad_index: torch.LongTensor
+) -> torch.Tensor:
+    x = x[:, non_pad_index][:, reverse_tile_partition_indices]
+    return x
+
+
+def video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+):
+    q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
+    k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
+    v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
+    g = tile(g, num_tiles, tile_partition_indices, non_pad_index)
+
+    q = q.transpose(1, 2).contiguous()
+    k = k.transpose(1, 2).contiguous()
+    v = v.transpose(1, 2).contiguous()
+    g = g.transpose(1, 2).contiguous()
+
+    topk = math.ceil((1 - sparsity) * (total_seq_length / math.prod(VSA_TILE_SIZE)))
+    out = vsa_core(
+        q,
+        k,
+        v,
+        variable_block_sizes=variable_block_sizes,
+        topk=topk,
+        block_size=VSA_TILE_SIZE,
+        compress_attn_weight=g,
+    ).transpose(1, 2)
+    out = untile(out, reverse_tile_partition_indices, non_pad_index)
+    return out
+
+
+def distributed_video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
+):
+    from yunchang.comm.all_to_all import SeqAllToAll4D
+
+    assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
+    sp_ulysses_group = get_sp_ulysses_group()
+
+    q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
+    k = SeqAllToAll4D.apply(sp_ulysses_group, k, scatter_idx, gather_idx)
+    v = SeqAllToAll4D.apply(sp_ulysses_group, v, scatter_idx, gather_idx)
+    g = SeqAllToAll4D.apply(sp_ulysses_group, g, scatter_idx, gather_idx)
+
+    out = video_sparse_attn(
+        q,
+        k,
+        v,
+        g,
+        sparsity,
+        num_tiles,
+        total_seq_length,
+        tile_partition_indices,
+        reverse_tile_partition_indices,
+        variable_block_sizes,
+        non_pad_index,
+    )
+
+    out = SeqAllToAll4D.apply(sp_ulysses_group, out, gather_idx, scatter_idx)
+    return out
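
Note: the module reorders the flattened (T, H, W) token sequence into contiguous 4x4x4 tiles, zero-pads partial tiles at the grid edges to the full 64-token block size, runs the vsa kernel with a top-k block budget of ceil((1 - sparsity) * total_seq_length / 64), and scatters tokens back into their original order. The tile/untile round trip is exact, which is easy to verify with the sketch below (it assumes the vsa package is importable, since the module imports it at load time, and only exercises the indexing helpers, not the kernel):

    import torch

    dit_seq_shape = (6, 10, 10)            # deliberately not multiples of 4
    kwargs = get_vsa_kwargs(dit_seq_shape, (1, 1, 1), sparsity=0.9, device="cpu")
    x = torch.randn(2, 6 * 10 * 10, 4, 8)  # [batch, seq_len, heads, head_dim]
    x_tiled = tile(x, kwargs["num_tiles"], kwargs["tile_partition_indices"], kwargs["non_pad_index"])
    x_back = untile(x_tiled, kwargs["reverse_tile_partition_indices"], kwargs["non_pad_index"])
    assert torch.equal(x, x_back)          # tiling then untiling restores token order
    # sparsity 0.9 over 600 tokens gives topk = ceil(0.1 * 600 / 64) = 1 block per query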
diffsynth_engine/models/flux/flux_controlnet.py
@@ -86,7 +86,6 @@ class FluxControlNet(PreTrainedModel):
     def __init__(
         self,
         condition_channels: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -103,10 +102,7 @@ class FluxControlNet(PreTrainedModel):
         self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
         self.controlnet_x_embedder = nn.Linear(condition_channels, 3072)
         self.blocks = nn.ModuleList(
-            [
-                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(6)
-            ]
+            [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(6)]
         )
         # controlnet projection
         self.blocks_proj = nn.ModuleList(
@@ -128,6 +124,7 @@
         image_ids: torch.Tensor,
         text_ids: torch.Tensor,
         guidance: torch.Tensor,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition)
         condition = (
@@ -141,7 +138,9 @@
         # double block
         double_block_outputs = []
         for i, block in enumerate(self.blocks):
-            hidden_states, prompt_emb = block(hidden_states, prompt_emb, condition, image_rotary_emb)
+            hidden_states, prompt_emb = block(
+                hidden_states, prompt_emb, condition, image_rotary_emb, attn_kwargs=attn_kwargs
+            )
             double_block_outputs.append(self.blocks_proj[i](hidden_states))
 
         # apply control scale
@@ -149,24 +148,13 @@
         return double_block_outputs, None
 
     @classmethod
-    def from_state_dict(
-        cls,
-        state_dict: Dict[str, torch.Tensor],
-        device: str,
-        dtype: torch.dtype,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
         if "controlnet_x_embedder.weight" in state_dict:
             condition_channels = state_dict["controlnet_x_embedder.weight"].shape[1]
         else:
             condition_channels = 64
 
-        model = cls(
-            condition_channels=condition_channels,
-            attn_kwargs=attn_kwargs,
-            device="meta",
-            dtype=dtype,
-        )
+        model = cls(condition_channels=condition_channels, device="meta", dtype=dtype)
         model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
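
Note: together with the matching flux_dit.py changes, this removes attn_kwargs from model construction (__init__ and from_state_dict) and makes it a per-call forward argument, so one loaded ControlNet can be driven with different attention settings on each denoising step. A toy sketch of the calling pattern (the Block class below is illustrative only, not the package's API):

    import torch
    import torch.nn as nn
    from typing import Any, Dict, Optional

    class Block(nn.Module):
        # after the refactor: attention settings arrive per call, not at __init__
        def forward(self, x: torch.Tensor, attn_kwargs: Optional[Dict[str, Any]] = None) -> torch.Tensor:
            attn_kwargs = attn_kwargs or {"attn_impl": "auto"}
            # a real block would pass attn_kwargs down into attention(...)
            return x

    block = Block()
    x = torch.randn(1, 16, 64)
    y1 = block(x, attn_kwargs={"attn_impl": "sdpa"})
    y2 = block(x, attn_kwargs={"attn_impl": "sage"})  # same weights, different backend per call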