diffsynth-engine 0.1.1__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (236) hide show
  1. diffsynth_engine-0.2.1/.github/workflows/python-publish.yml +41 -0
  2. diffsynth_engine-0.2.1/.gitignore +11 -0
  3. diffsynth_engine-0.2.1/.pre-commit-config.yaml +11 -0
  4. diffsynth_engine-0.2.1/PKG-INFO +34 -0
  5. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/README.md +14 -14
  6. diffsynth_engine-0.2.1/assets/dingtalk.png +0 -0
  7. diffsynth_engine-0.2.1/assets/showcase.jpeg +0 -0
  8. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/__init__.py +10 -0
  9. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +16 -14
  10. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -3
  11. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -3
  12. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +1 -1
  13. diffsynth_engine-0.2.1/diffsynth_engine/models/__init__.py +7 -0
  14. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/base.py +22 -13
  15. diffsynth_engine-0.2.1/diffsynth_engine/models/basic/attention.py +233 -0
  16. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/unet_helper.py +2 -2
  17. diffsynth_engine-0.2.1/diffsynth_engine/models/components/siglip.py +169 -0
  18. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/vae.py +0 -1
  19. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/__init__.py +2 -0
  20. diffsynth_engine-0.2.1/diffsynth_engine/models/flux/flux_controlnet.py +160 -0
  21. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/flux_dit.py +67 -96
  22. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/flux_text_encoder.py +1 -3
  23. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/flux_vae.py +1 -1
  24. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/sd3_dit.py +1 -7
  25. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/sdxl_unet.py +1 -7
  26. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_dit.py +146 -79
  27. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_image_encoder.py +2 -3
  28. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_text_encoder.py +46 -13
  29. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/__init__.py +4 -2
  30. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/base.py +66 -31
  31. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/flux_image.py +190 -79
  32. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/sd_image.py +38 -47
  33. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/sdxl_image.py +40 -50
  34. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/wan_video.py +156 -89
  35. diffsynth_engine-0.2.1/diffsynth_engine/tokenizers/__init__.py +6 -0
  36. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/wan.py +17 -22
  37. diffsynth_engine-0.2.1/diffsynth_engine/tools/__init__.py +4 -0
  38. diffsynth_engine-0.2.1/diffsynth_engine/tools/flux_inpainting.py +50 -0
  39. diffsynth_engine-0.2.1/diffsynth_engine/tools/flux_outpainting.py +58 -0
  40. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/download.py +1 -5
  41. diffsynth_engine-0.2.1/diffsynth_engine/utils/env.py +10 -0
  42. diffsynth_engine-0.2.1/diffsynth_engine/utils/flag.py +46 -0
  43. diffsynth_engine-0.2.1/diffsynth_engine/utils/image.py +25 -0
  44. diffsynth_engine-0.2.1/diffsynth_engine/utils/loader.py +32 -0
  45. diffsynth_engine-0.2.1/diffsynth_engine/utils/parallel.py +401 -0
  46. diffsynth_engine-0.2.1/diffsynth_engine.egg-info/PKG-INFO +34 -0
  47. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/SOURCES.txt +98 -2
  48. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/requires.txt +3 -4
  49. diffsynth_engine-0.2.1/docs/tutorial.md +1 -0
  50. diffsynth_engine-0.2.1/docs/tutorial_zh.md +207 -0
  51. diffsynth_engine-0.2.1/examples/flux_lora.py +11 -0
  52. diffsynth_engine-0.2.1/examples/flux_text_to_image.py +8 -0
  53. diffsynth_engine-0.2.1/examples/i2v_input.jpg +0 -0
  54. diffsynth_engine-0.2.1/examples/sdxl_text_to_image.py +14 -0
  55. diffsynth_engine-0.2.1/examples/wan_image_to_video.py +35 -0
  56. diffsynth_engine-0.2.1/examples/wan_lora.py +33 -0
  57. diffsynth_engine-0.2.1/examples/wan_text_to_video.py +28 -0
  58. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/pyproject.toml +8 -8
  59. diffsynth_engine-0.2.1/tests/__init__.py +0 -0
  60. diffsynth_engine-0.2.1/tests/common/__init__.py +0 -0
  61. diffsynth_engine-0.2.1/tests/common/test_case.py +123 -0
  62. diffsynth_engine-0.2.1/tests/common/utils.py +29 -0
  63. diffsynth_engine-0.2.1/tests/data/expect/algorithm/beta_20steps.safetensors +0 -0
  64. diffsynth_engine-0.2.1/tests/data/expect/algorithm/ddim_20steps.safetensors +0 -0
  65. diffsynth_engine-0.2.1/tests/data/expect/algorithm/euler_i10.safetensors +0 -0
  66. diffsynth_engine-0.2.1/tests/data/expect/algorithm/exponential_20steps.safetensors +0 -0
  67. diffsynth_engine-0.2.1/tests/data/expect/algorithm/flow_match_euler_i10.safetensors +0 -0
  68. diffsynth_engine-0.2.1/tests/data/expect/algorithm/karras_20steps.safetensors +0 -0
  69. diffsynth_engine-0.2.1/tests/data/expect/algorithm/output.safetensors +0 -0
  70. diffsynth_engine-0.2.1/tests/data/expect/algorithm/recifited_flow_20steps_flux.safetensors +0 -0
  71. diffsynth_engine-0.2.1/tests/data/expect/algorithm/scaled_linear_20steps.safetensors +0 -0
  72. diffsynth_engine-0.2.1/tests/data/expect/algorithm/sgm_uniform_20steps.safetensors +0 -0
  73. diffsynth_engine-0.2.1/tests/data/expect/flux/flux_dit.safetensors +0 -0
  74. diffsynth_engine-0.2.1/tests/data/expect/flux/flux_inpainting.png +0 -0
  75. diffsynth_engine-0.2.1/tests/data/expect/flux/flux_lora.png +0 -0
  76. diffsynth_engine-0.2.1/tests/data/expect/flux/flux_outpainting.png +0 -0
  77. diffsynth_engine-0.2.1/tests/data/expect/flux/flux_text_encoder_1.safetensors +0 -0
  78. diffsynth_engine-0.2.1/tests/data/expect/flux/flux_text_encoder_2.safetensors +0 -0
  79. diffsynth_engine-0.2.1/tests/data/expect/flux/flux_txt2img.png +0 -0
  80. diffsynth_engine-0.2.1/tests/data/expect/flux/flux_union_pro_canny.png +0 -0
  81. diffsynth_engine-0.2.1/tests/data/expect/flux/flux_vae.safetensors +0 -0
  82. diffsynth_engine-0.2.1/tests/data/expect/sd/sd_inpainting.png +0 -0
  83. diffsynth_engine-0.2.1/tests/data/expect/sd/sd_lora.png +0 -0
  84. diffsynth_engine-0.2.1/tests/data/expect/sd/sd_text_encoder.safetensors +0 -0
  85. diffsynth_engine-0.2.1/tests/data/expect/sd/sd_txt2img.png +0 -0
  86. diffsynth_engine-0.2.1/tests/data/expect/sd/sd_unet.safetensors +0 -0
  87. diffsynth_engine-0.2.1/tests/data/expect/sd/sd_vae.safetensors +0 -0
  88. diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_inpainting.png +0 -0
  89. diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_lora.png +0 -0
  90. diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_text_encoder_1.safetensors +0 -0
  91. diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_text_encoder_2.safetensors +0 -0
  92. diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_txt2img.png +0 -0
  93. diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_unet.safetensors +0 -0
  94. diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_vae.safetensors +0 -0
  95. diffsynth_engine-0.2.1/tests/data/expect/test_siglip_image_encoder.safetensors +0 -0
  96. diffsynth_engine-0.2.1/tests/data/expect/wan/wan_vae.safetensors +0 -0
  97. diffsynth_engine-0.2.1/tests/data/input/astronaut_320_320.mp4 +0 -0
  98. diffsynth_engine-0.2.1/tests/data/input/canny.png +0 -0
  99. diffsynth_engine-0.2.1/tests/data/input/mask_image.png +0 -0
  100. diffsynth_engine-0.2.1/tests/data/input/test_image.png +0 -0
  101. diffsynth_engine-0.2.1/tests/data/input/wukong_1024_1024.png +0 -0
  102. diffsynth_engine-0.2.1/tests/data/input/wukong_480_480.png +0 -0
  103. diffsynth_engine-0.2.1/tests/test_algorithm/__init__.py +0 -0
  104. diffsynth_engine-0.2.1/tests/test_algorithm/test_sampler.py +42 -0
  105. diffsynth_engine-0.2.1/tests/test_algorithm/test_scheduler.py +77 -0
  106. diffsynth_engine-0.2.1/tests/test_models/__init__.py +0 -0
  107. diffsynth_engine-0.2.1/tests/test_models/flux/__init__.py +0 -0
  108. diffsynth_engine-0.2.1/tests/test_models/flux/test_flux_dit.py +208 -0
  109. diffsynth_engine-0.2.1/tests/test_models/flux/test_flux_text_encoder.py +114 -0
  110. diffsynth_engine-0.2.1/tests/test_models/flux/test_flux_vae.py +344 -0
  111. diffsynth_engine-0.2.1/tests/test_models/sd/__init__.py +0 -0
  112. diffsynth_engine-0.2.1/tests/test_models/sd/test_sd_text_encoder.py +72 -0
  113. diffsynth_engine-0.2.1/tests/test_models/sd/test_sd_unet.py +22 -0
  114. diffsynth_engine-0.2.1/tests/test_models/sd/test_sd_vae.py +353 -0
  115. diffsynth_engine-0.2.1/tests/test_models/sdxl/__init__.py +0 -0
  116. diffsynth_engine-0.2.1/tests/test_models/sdxl/test_sdxl_text_encoder.py +163 -0
  117. diffsynth_engine-0.2.1/tests/test_models/sdxl/test_sdxl_unet.py +21 -0
  118. diffsynth_engine-0.2.1/tests/test_models/sdxl/test_sdxl_vae.py +351 -0
  119. diffsynth_engine-0.2.1/tests/test_models/test_siglip.py +17 -0
  120. diffsynth_engine-0.2.1/tests/test_models/wan/test_wan_vae.py +34 -0
  121. diffsynth_engine-0.2.1/tests/test_pipelines/__init__.py +0 -0
  122. diffsynth_engine-0.2.1/tests/test_pipelines/test_flux_controlnet.py +32 -0
  123. diffsynth_engine-0.2.1/tests/test_pipelines/test_flux_image.py +68 -0
  124. diffsynth_engine-0.2.1/tests/test_pipelines/test_sd_image.py +55 -0
  125. diffsynth_engine-0.2.1/tests/test_pipelines/test_sdxl_image.py +59 -0
  126. diffsynth_engine-0.2.1/tests/test_pipelines/test_wan_video.py +24 -0
  127. diffsynth_engine-0.2.1/tests/test_pipelines/test_wan_video_gguf.py +24 -0
  128. diffsynth_engine-0.2.1/tests/test_pipelines/test_wan_video_tp.py +25 -0
  129. diffsynth_engine-0.2.1/tests/test_tokenizers/__init__.py +0 -0
  130. diffsynth_engine-0.2.1/tests/test_tokenizers/test_clip.py +135 -0
  131. diffsynth_engine-0.2.1/tests/test_tokenizers/test_t5.py +138 -0
  132. diffsynth_engine-0.2.1/tests/test_tools/__init__.py +0 -0
  133. diffsynth_engine-0.2.1/tests/test_tools/test_flux_tools.py +31 -0
  134. diffsynth_engine-0.1.1/PKG-INFO +0 -213
  135. diffsynth_engine-0.1.1/diffsynth_engine/models/basic/attention.py +0 -137
  136. diffsynth_engine-0.1.1/diffsynth_engine/models/wan/attention.py +0 -200
  137. diffsynth_engine-0.1.1/diffsynth_engine/tokenizers/__init__.py +0 -4
  138. diffsynth_engine-0.1.1/diffsynth_engine/utils/env.py +0 -7
  139. diffsynth_engine-0.1.1/diffsynth_engine/utils/loader.py +0 -14
  140. diffsynth_engine-0.1.1/diffsynth_engine/utils/parallel.py +0 -191
  141. diffsynth_engine-0.1.1/diffsynth_engine.egg-info/PKG-INFO +0 -213
  142. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/LICENSE +0 -0
  143. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/__init__.py +0 -0
  144. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  145. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  146. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  147. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  148. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  149. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  150. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  151. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  152. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  153. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  154. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  155. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  156. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  157. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  158. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  159. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  160. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  161. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  162. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  163. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  164. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  165. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  166. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/components/vae.json +0 -0
  167. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  168. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  169. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  170. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  171. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  172. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  173. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  174. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  175. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  176. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +0 -0
  177. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/wan/dit/14b-i2v.json +0 -0
  178. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/wan/dit/14b-t2v.json +0 -0
  179. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  180. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  181. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  182. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  183. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  184. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  185. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  186. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  187. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  188. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  189. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  190. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  191. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  192. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  193. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  194. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  195. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  196. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  197. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  198. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  199. {diffsynth_engine-0.1.1/diffsynth_engine/models → diffsynth_engine-0.2.1/diffsynth_engine/kernels}/__init__.py +0 -0
  200. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/__init__.py +0 -0
  201. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/lora.py +0 -0
  202. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  203. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/timestep.py +0 -0
  204. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
  205. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/__init__.py +0 -0
  206. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/clip.py +0 -0
  207. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/t5.py +0 -0
  208. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/__init__.py +0 -0
  209. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  210. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  211. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  212. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/__init__.py +0 -0
  213. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  214. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  215. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  216. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  217. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  218. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/utils.py +0 -0
  219. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/__init__.py +0 -0
  220. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  221. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/base.py +0 -0
  222. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/clip.py +0 -0
  223. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/t5.py +0 -0
  224. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/__init__.py +0 -0
  225. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/constants.py +0 -0
  226. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/fp8_linear.py +0 -0
  227. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/gguf.py +0 -0
  228. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/lock.py +0 -0
  229. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/logging.py +0 -0
  230. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/offload.py +0 -0
  231. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/prompt.py +0 -0
  232. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/video.py +0 -0
  233. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  234. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/top_level.txt +0 -0
  235. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/setup.cfg +0 -0
  236. {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/setup.py +0 -0
@@ -0,0 +1,41 @@
1
+ name: release
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - 'v**'
7
+
8
+ workflow_dispatch:
9
+ inputs:
10
+ branch:
11
+ required: true
12
+ default: 'main'
13
+
14
+ permissions:
15
+ contents: read
16
+
17
+ concurrency:
18
+ group: ${{ github.workflow }}-${{ github.ref }}
19
+ cancel-in-progress: true
20
+
21
+ jobs:
22
+ build-and-publish:
23
+ runs-on: ubuntu-latest
24
+
25
+ steps:
26
+ - uses: actions/checkout@v4
27
+
28
+ - uses: actions/setup-python@v5
29
+ with:
30
+ python-version: "3.10"
31
+
32
+ - name: Install build
33
+ run: pip install build
34
+
35
+ - name: Build dist
36
+ run: python -m build
37
+
38
+ - name: Publish to PyPI
39
+ run: |
40
+ pip install twine
41
+ twine upload dist/* --skip-existing -p ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,11 @@
1
+ *.pyc
2
+ .idea/
3
+ .vscode/
4
+ __pycache__/
5
+ tmp/
6
+ build/
7
+ dist/
8
+ *.egg-info/
9
+ .DS_Store
10
+ .pytest_cache/
11
+ .ruff_cache/
@@ -0,0 +1,11 @@
1
+ repos:
2
+ - repo: https://github.com/astral-sh/ruff-pre-commit
3
+ # Ruff version.
4
+ rev: v0.11.5
5
+ hooks:
6
+ # Run the linter.
7
+ - id: ruff
8
+ types_or: [ python, pyi ]
9
+ # Run the formatter.
10
+ - id: ruff-format
11
+ types_or: [ python, pyi ]
@@ -0,0 +1,34 @@
1
+ Metadata-Version: 2.4
2
+ Name: diffsynth_engine
3
+ Version: 0.2.1
4
+ Author: MuseAI x ModelScope
5
+ Classifier: Programming Language :: Python :: 3
6
+ Classifier: Operating System :: OS Independent
7
+ Requires-Python: >=3.10
8
+ License-File: LICENSE
9
+ Requires-Dist: torch>=2.6
10
+ Requires-Dist: torchvision
11
+ Requires-Dist: xformers; sys_platform == "linux"
12
+ Requires-Dist: safetensors
13
+ Requires-Dist: gguf
14
+ Requires-Dist: einops
15
+ Requires-Dist: ftfy
16
+ Requires-Dist: regex
17
+ Requires-Dist: sentencepiece
18
+ Requires-Dist: tokenizers
19
+ Requires-Dist: modelscope
20
+ Requires-Dist: flufl.lock
21
+ Requires-Dist: scipy
22
+ Requires-Dist: torchsde
23
+ Requires-Dist: pillow
24
+ Requires-Dist: imageio[ffmpeg]
25
+ Requires-Dist: yunchang; sys_platform == "linux"
26
+ Provides-Extra: dev
27
+ Requires-Dist: diffusers==0.31.0; extra == "dev"
28
+ Requires-Dist: transformers==4.45.2; extra == "dev"
29
+ Requires-Dist: build; extra == "dev"
30
+ Requires-Dist: ruff; extra == "dev"
31
+ Requires-Dist: scikit-image; extra == "dev"
32
+ Requires-Dist: pytest; extra == "dev"
33
+ Requires-Dist: pre-commit; extra == "dev"
34
+ Dynamic: license-file
@@ -6,20 +6,20 @@
6
6
  [![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Engine.svg)](https://GitHub.com/modelscope/DiffSynth-Engine/pull/)
7
7
  [![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Engine)](https://GitHub.com/modelscope/DiffSynth-Engine/commit/)
8
8
 
9
- Diffsynth Engine is a high-performance diffusion inference engine designed for developers.
9
+ DiffSynth-Engine is a high-performance engine geared towards building efficient inference pipelines for diffusion models.
10
10
 
11
11
  **Key Features:**
12
12
 
13
- - **Clean and Readable Code:** Fully re-implements the Diffusion sampler and scheduler without relying on third-party libraries like k-diffusion, ldm, or sgm.
13
+ - **Thoughtfully-Designed Implementation:** We carefully re-implemented key components in Diffusion pipelines, such as sampler and scheduler, without introducing external dependencies on libraries like k-diffusion, ldm, or sgm.
14
14
 
15
- - **Extensive Model Support:** Compatible with multiple formats (e.g., CivitAI format) of base models and LoRA models , catering to diverse use cases.
15
+ - **Extensive Model Support:** Compatible with popular formats (e.g., CivitAI) of base models and LoRA models, catering to diverse use cases.
16
16
 
17
- - **Flexible Memory Management:** Supports various levels of model quantization (e.g., FP8, INT8)
18
- and offload strategies, enabling users to run large models (e.g., Flux.1 Dev) on limited GPU memory.
17
+ - **Versatile Resource Management:** Comprehensive support for various model quantization (e.g., FP8, INT8)
18
+ and offloading strategies, enabling loading of larger diffusion models (e.g., Flux.1 Dev) on a limited GPU memory budget.
19
19
 
20
- - **High-Performance Inference:** Optimizes the inference pipeline to achieve fast generation across various hardware environments.
20
+ - **Optimized Performance:** Carefully-crafted inference pipeline to achieve fast generation across various hardware environments.
21
21
 
22
- - **Platform Compatibility:** Supports Windows, macOS (Apple Silicon), and Linux, ensuring a smooth experience across different operating systems.
22
+ - **Cross-Platform Support:** Runnable on Windows, macOS (Apple Silicon), and Linux, ensuring a smooth experience across different operating systems.
23
23
 
24
24
  ## Quick Start
25
25
  ### Requirements
@@ -29,13 +29,13 @@ and offload strategies, enabling users to run large models (e.g., Flux.1 Dev) on
29
29
 
30
30
  ### Installation
31
31
 
32
- Install for PyPI (stable version)
33
- ```python
32
+ Install released version (from PyPI):
33
+ ```shell
34
34
  pip3 install diffsynth-engine
35
35
  ```
36
36
 
37
- Install for source (preview version)
38
- ```python
37
+ Install from source:
38
+ ```shell
39
39
  git clone https://github.com/modelscope/diffsynth-engine.git && cd diffsynth-engine
40
40
  pip3 install -e .
41
41
  ```
@@ -71,10 +71,10 @@ For more details, please refer to our tutorials ([English](./docs/tutorial.md),
71
71
 
72
72
  ## Contact
73
73
 
74
- If you have any questions or feedback, please scan the QR code or send email to muse@alibaba-inc.com.
74
+ If you have any questions or feedback, please scan the QR code below, or send an email to muse@alibaba-inc.com.
75
75
 
76
76
  <div style="display: flex; justify-content: space-between;">
77
- <img src="assets/dingtalk.png" alt="dingtalk" style="zoom: 60%;" />
77
+ <img src="assets/dingtalk.png" alt="dingtalk" width="400" />
78
78
  </div>
79
79
 
80
80
  ## License
@@ -82,7 +82,7 @@ This project is licensed under the Apache License 2.0. See the LICENSE file for
82
82
 
83
83
  ## Citation
84
84
 
85
- If you use this codebase, or otherwise found our work valuable, please cite:
85
+ If you use this codebase, or otherwise find our work helpful, please cite:
86
86
 
87
87
  ```bibtex
88
88
  @misc{diffsynth-engine2025,
@@ -7,11 +7,16 @@ from .pipelines import (
7
7
  SDXLModelConfig,
8
8
  SDModelConfig,
9
9
  WanModelConfig,
10
+ ControlNetParams,
10
11
  )
12
+ from .models.flux import FluxControlNet
11
13
  from .utils.download import fetch_model, fetch_modelscope_model, fetch_civitai_model
12
14
  from .utils.video import load_video, save_video
15
+ from .tools import FluxInpaintingTool, FluxOutpaintingTool
16
+
13
17
  __all__ = [
14
18
  "FluxImagePipeline",
19
+ "FluxControlNet",
15
20
  "SDXLImagePipeline",
16
21
  "SDImagePipeline",
17
22
  "WanVideoPipeline",
@@ -19,7 +24,12 @@ __all__ = [
19
24
  "SDXLModelConfig",
20
25
  "SDModelConfig",
21
26
  "WanModelConfig",
27
+ "FluxInpaintingTool",
28
+ "FluxOutpaintingTool",
29
+ "ControlNetParams",
22
30
  "fetch_model",
23
31
  "fetch_modelscope_model",
24
32
  "fetch_civitai_model",
33
+ "load_video",
34
+ "save_video",
25
35
  ]
@@ -5,18 +5,19 @@ from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zer
5
5
 
6
6
 
7
7
  class RecifitedFlowScheduler(BaseScheduler):
8
- def __init__(self,
9
- shift=1.0,
10
- sigma_min=0.001,
8
+ def __init__(
9
+ self,
10
+ shift=1.0,
11
+ sigma_min=0.001,
11
12
  sigma_max=1.0,
12
- num_train_timesteps=1000,
13
+ num_train_timesteps=1000,
13
14
  use_dynamic_shifting=False,
14
15
  ):
15
16
  self.shift = shift
16
17
  self.sigma_min = sigma_min
17
18
  self.sigma_max = sigma_max
18
- self.num_train_timesteps = num_train_timesteps
19
- self.use_dynamic_shifting = use_dynamic_shifting
19
+ self.num_train_timesteps = num_train_timesteps
20
+ self.use_dynamic_shifting = use_dynamic_shifting
20
21
 
21
22
  def _sigma_to_t(self, sigma):
22
23
  return sigma * self.num_train_timesteps
@@ -30,19 +31,20 @@ class RecifitedFlowScheduler(BaseScheduler):
30
31
  def _shift_sigma(self, sigma: torch.Tensor, shift: float):
31
32
  return shift * sigma / (1 + (shift - 1) * sigma)
32
33
 
33
- def schedule(self,
34
- num_inference_steps: int,
35
- mu: float | None = None,
36
- sigma_min: float | None = None,
37
- sigma_max: float | None = None
34
+ def schedule(
35
+ self,
36
+ num_inference_steps: int,
37
+ mu: float | None = None,
38
+ sigma_min: float | None = None,
39
+ sigma_max: float | None = None,
38
40
  ):
39
41
  sigma_min = self.sigma_min if sigma_min is None else sigma_min
40
- sigma_max = self.sigma_max if sigma_max is None else sigma_max
42
+ sigma_max = self.sigma_max if sigma_max is None else sigma_max
41
43
  sigmas = torch.linspace(sigma_max, sigma_min, num_inference_steps)
42
44
  if self.use_dynamic_shifting:
43
- sigmas = self._time_shift(mu, 1.0, sigmas) # FLUX
45
+ sigmas = self._time_shift(mu, 1.0, sigmas) # FLUX
44
46
  else:
45
47
  sigmas = self._shift_sigma(sigmas, self.shift)
46
48
  timesteps = sigmas * self.num_train_timesteps
47
49
  sigmas = append_zero(sigmas)
48
- return sigmas, timesteps
50
+ return sigmas, timesteps
@@ -1,7 +1,4 @@
1
1
  import torch
2
- from .linear import ScaledLinearScheduler
3
- from ..base_scheduler import append_zero
4
- import numpy as np
5
2
 
6
3
  from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
7
4
  from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
@@ -1,7 +1,4 @@
1
1
  import torch
2
- from .linear import ScaledLinearScheduler
3
- from ..base_scheduler import append_zero
4
- import numpy as np
5
2
 
6
3
  from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
7
4
  from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
@@ -2,7 +2,7 @@ import torch
2
2
 
3
3
 
4
4
  class FlowMatchEulerSampler:
5
- def initialize(self, init_latents, timesteps, sigmas, mask=None):
5
+ def initialize(self, init_latents, timesteps, sigmas, mask=None):
6
6
  self.init_latents = init_latents
7
7
  self.timesteps = timesteps
8
8
  self.sigmas = sigmas
@@ -0,0 +1,7 @@
1
+ from .base import PreTrainedModel, StateDictConverter
2
+
3
+
4
+ __all__ = [
5
+ "PreTrainedModel",
6
+ "StateDictConverter",
7
+ ]
@@ -1,22 +1,14 @@
1
1
  import os
2
2
  import torch
3
3
  import torch.nn as nn
4
- from typing import Dict, Union
5
- from safetensors.torch import load_file
6
-
4
+ from typing import Dict, Union, List, Any
5
+ from diffsynth_engine.utils.loader import load_file
6
+ from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
7
7
  from diffsynth_engine.models.utils import no_init_weights
8
8
 
9
9
 
10
- class LoRAStateDictConverter:
11
- def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
12
- return {"lora": lora_state_dict}
13
-
14
-
15
- StateDictType = Dict[str, torch.Tensor]
16
-
17
-
18
10
  class StateDictConverter:
19
- def convert(self, state_dict: StateDictType) -> StateDictType:
11
+ def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
20
12
  return state_dict
21
13
 
22
14
 
@@ -29,17 +21,34 @@ class PreTrainedModel(nn.Module):
29
21
 
30
22
  @classmethod
31
23
  def from_pretrained(cls, pretrained_model_path: Union[str, os.PathLike], device: str, dtype: torch.dtype, **kwargs):
32
- state_dict = load_file(pretrained_model_path, device=device)
24
+ state_dict = load_file(pretrained_model_path)
33
25
  return cls.from_state_dict(state_dict, device=device, dtype=dtype, **kwargs)
34
26
 
35
27
  @classmethod
36
28
  def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, **kwargs):
37
29
  with no_init_weights():
38
30
  model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, **kwargs)
31
+ model.to_empty(device=device)
39
32
  model.load_state_dict(state_dict)
40
33
  model.to(device=device, dtype=dtype, non_blocking=True)
41
34
  return model
42
35
 
36
+ def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
37
+ for args in lora_args:
38
+ key = args["name"]
39
+ module = self.get_submodule(key)
40
+ if not isinstance(module, (LoRALinear, LoRAConv2d)):
41
+ raise ValueError(f"Unsupported lora key: {key}")
42
+ if fused:
43
+ module.add_frozen_lora(**args)
44
+ else:
45
+ module.add_lora(**args)
46
+
47
+ def unload_loras(self):
48
+ for module in self.modules():
49
+ if isinstance(module, (LoRALinear, LoRAConv2d)):
50
+ module.clear()
51
+
43
52
 
44
53
  def split_suffix(name: str):
45
54
  suffix_list = [
@@ -0,0 +1,233 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops import rearrange, repeat
4
+ from typing import Optional
5
+
6
+ import torch.nn.functional as F
7
+ from diffsynth_engine.utils import logging
8
+ from diffsynth_engine.utils.flag import (
9
+ FLASH_ATTN_3_AVAILABLE,
10
+ FLASH_ATTN_2_AVAILABLE,
11
+ XFORMERS_AVAILABLE,
12
+ SDPA_AVAILABLE,
13
+ SAGE_ATTN_AVAILABLE,
14
+ SPARGE_ATTN_AVAILABLE,
15
+ )
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
+ def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
21
+ padding_size = (alignment - x.shape[dim] % alignment) % alignment
22
+ padded_x = F.pad(x, (0, padding_size), "constant", 0)
23
+ return padded_x[..., : x.shape[dim]]
24
+
25
+
26
+ if FLASH_ATTN_3_AVAILABLE:
27
+ from flash_attn_interface import flash_attn_func as flash_attn3
28
+ if FLASH_ATTN_2_AVAILABLE:
29
+ from flash_attn import flash_attn_func as flash_attn2
30
+ if XFORMERS_AVAILABLE:
31
+ from xformers.ops import memory_efficient_attention
32
+
33
+ def xformers_attn(q, k, v, attn_mask=None, scale=None):
34
+ if attn_mask is not None:
35
+ attn_mask = repeat(attn_mask, "S L -> B H S L", B=q.shape[0], H=q.shape[2])
36
+ attn_mask = memory_align(attn_mask)
37
+ return memory_efficient_attention(q, k, v, attn_bias=attn_mask, scale=scale)
38
+
39
+
40
+ if SDPA_AVAILABLE:
41
+
42
+ def sdpa_attn(q, k, v, attn_mask=None, scale=None):
43
+ q = q.transpose(1, 2)
44
+ k = k.transpose(1, 2)
45
+ v = v.transpose(1, 2)
46
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
47
+ return out.transpose(1, 2)
48
+
49
+
50
+ if SAGE_ATTN_AVAILABLE:
51
+ from sageattention import sageattn
52
+
53
+ def sage_attn(q, k, v, attn_mask=None, scale=None):
54
+ q = q.transpose(1, 2)
55
+ k = k.transpose(1, 2)
56
+ v = v.transpose(1, 2)
57
+ out = sageattn(q, k, v, attn_mask=attn_mask, sm_scale=scale)
58
+ return out.transpose(1, 2)
59
+
60
+
61
+ if SPARGE_ATTN_AVAILABLE:
62
+ from spas_sage_attn import spas_sage2_attn_meansim_cuda
63
+
64
+ def sparge_attn(self, q, k, v, attn_mask=None, scale=None):
65
+ q = q.transpose(1, 2)
66
+ k = k.transpose(1, 2)
67
+ v = v.transpose(1, 2)
68
+ out = spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=attn_mask, scale=scale)
69
+ return out.transpose(1, 2)
70
+
71
+
72
+ def eager_attn(q, k, v, attn_mask=None, scale=None):
73
+ q = q.transpose(1, 2)
74
+ k = k.transpose(1, 2)
75
+ v = v.transpose(1, 2)
76
+ scale = 1 / q.shape[-1] ** 0.5 if scale is None else scale
77
+ q = q * scale
78
+ attn = torch.matmul(q, k.transpose(-2, -1))
79
+ if attn_mask is not None:
80
+ attn = attn + attn_mask
81
+ attn = attn.softmax(-1)
82
+ out = attn @ v
83
+ return out.transpose(1, 2)
84
+
85
+
86
+ def attention(
87
+ q,
88
+ k,
89
+ v,
90
+ attn_impl: Optional[str] = None,
91
+ attn_mask: Optional[torch.Tensor] = None,
92
+ scale: Optional[float] = None,
93
+ ):
94
+ """
95
+ q: [B, Lq, Nq, C1]
96
+ k: [B, Lk, Nk, C1]
97
+ v: [B, Lk, Nk, C2]
98
+ """
99
+ assert attn_impl in [
100
+ None,
101
+ "auto",
102
+ "eager",
103
+ "flash_attn_2",
104
+ "flash_attn_3",
105
+ "xformers",
106
+ "sdpa",
107
+ "sage_attn",
108
+ "sparge_attn",
109
+ ]
110
+ if attn_impl is None or attn_impl == "auto":
111
+ if FLASH_ATTN_3_AVAILABLE:
112
+ return flash_attn3(q, k, v, softmax_scale=scale)
113
+ elif FLASH_ATTN_2_AVAILABLE:
114
+ return flash_attn2(q, k, v, softmax_scale=scale)
115
+ elif XFORMERS_AVAILABLE:
116
+ return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
117
+ elif SDPA_AVAILABLE:
118
+ return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
119
+ else:
120
+ return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
121
+ else:
122
+ if attn_impl == "eager":
123
+ return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
124
+ elif attn_impl == "flash_attn_3":
125
+ return flash_attn3(q, k, v, softmax_scale=scale)
126
+ elif attn_impl == "flash_attn_2":
127
+ return flash_attn2(q, k, v, softmax_scale=scale)
128
+ elif attn_impl == "xformers":
129
+ return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
130
+ elif attn_impl == "sdpa":
131
+ return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
132
+ elif attn_impl == "sage_attn":
133
+ return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
134
+ elif attn_impl == "sparge_attn":
135
+ return sparge_attn(q, k, v, attn_mask=attn_mask, scale=scale)
136
+ else:
137
+ raise ValueError(f"Invalid attention implementation: {attn_impl}")
138
+
139
+
140
+ class Attention(nn.Module):
141
+ def __init__(
142
+ self,
143
+ q_dim,
144
+ num_heads,
145
+ head_dim,
146
+ kv_dim=None,
147
+ bias_q=False,
148
+ bias_kv=False,
149
+ bias_out=False,
150
+ scale=None,
151
+ attn_impl: Optional[str] = None,
152
+ device: str = "cuda:0",
153
+ dtype: torch.dtype = torch.float16,
154
+ ):
155
+ super().__init__()
156
+ dim_inner = head_dim * num_heads
157
+ kv_dim = kv_dim if kv_dim is not None else q_dim
158
+ self.num_heads = num_heads
159
+ self.head_dim = head_dim
160
+
161
+ self.to_q = nn.Linear(q_dim, dim_inner, bias=bias_q, device=device, dtype=dtype)
162
+ self.to_k = nn.Linear(kv_dim, dim_inner, bias=bias_kv, device=device, dtype=dtype)
163
+ self.to_v = nn.Linear(kv_dim, dim_inner, bias=bias_kv, device=device, dtype=dtype)
164
+ self.to_out = nn.Linear(dim_inner, q_dim, bias=bias_out, device=device, dtype=dtype)
165
+ self.attn_impl = attn_impl
166
+ self.scale = scale
167
+
168
+ def forward(
169
+ self,
170
+ x: torch.Tensor,
171
+ y: Optional[torch.Tensor] = None,
172
+ attn_mask: Optional[torch.Tensor] = None,
173
+ ):
174
+ if y is None:
175
+ y = x
176
+ q = rearrange(self.to_q(x), "b s (n d) -> b s n d", n=self.num_heads)
177
+ k = rearrange(self.to_k(y), "b s (n d) -> b s n d", n=self.num_heads)
178
+ v = rearrange(self.to_v(y), "b s (n d) -> b s n d", n=self.num_heads)
179
+ out = attention(q, k, v, attn_mask=attn_mask, attn_impl=self.attn_impl, scale=self.scale)
180
+ out = rearrange(out, "b s n d -> b s (n d)", n=self.num_heads)
181
+ return self.to_out(out)
182
+
183
+
184
+ def long_context_attention(
185
+ q,
186
+ k,
187
+ v,
188
+ attn_impl: Optional[str] = None,
189
+ attn_mask: Optional[torch.Tensor] = None,
190
+ scale: Optional[float] = None,
191
+ ):
192
+ """
193
+ q: [B, Lq, Nq, C1]
194
+ k: [B, Lk, Nk, C1]
195
+ v: [B, Lk, Nk, C2]
196
+ """
197
+ from yunchang import LongContextAttention
198
+ from yunchang.kernels import AttnType
199
+
200
+ assert attn_impl in [
201
+ None,
202
+ "auto",
203
+ "eager",
204
+ "flash_attn_2",
205
+ "flash_attn_3",
206
+ "xformers",
207
+ "sdpa",
208
+ "sage_attn",
209
+ "sparge_attn",
210
+ ]
211
+ if attn_impl is None or attn_impl == "auto":
212
+ if FLASH_ATTN_3_AVAILABLE:
213
+ attn_func = LongContextAttention(attn_type=AttnType.FA3)
214
+ elif FLASH_ATTN_2_AVAILABLE:
215
+ attn_func = LongContextAttention(attn_type=AttnType.FA)
216
+ elif SDPA_AVAILABLE:
217
+ attn_func = LongContextAttention(attn_type=AttnType.TORCH)
218
+ else:
219
+ raise ValueError("No available long context attention implementation")
220
+ else:
221
+ if attn_impl == "flash_attn_3":
222
+ attn_func = LongContextAttention(attn_type=AttnType.FA3)
223
+ elif attn_impl == "flash_attn_2":
224
+ attn_func = LongContextAttention(attn_type=AttnType.FA)
225
+ elif attn_impl == "sdpa":
226
+ attn_func = LongContextAttention(attn_type=AttnType.TORCH)
227
+ elif attn_impl == "sage_attn":
228
+ attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
229
+ elif attn_impl == "sparge_attn":
230
+ attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE)
231
+ else:
232
+ raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
233
+ return attn_func(q, k, v, softmax_scale=scale)
@@ -51,12 +51,12 @@ class BasicTransformerBlock(nn.Module):
51
51
  def forward(self, hidden_states, encoder_hidden_states):
52
52
  # 1. Self-Attention
53
53
  norm_hidden_states = self.norm1(hidden_states)
54
- attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
54
+ attn_output = self.attn1(norm_hidden_states)
55
55
  hidden_states = attn_output + hidden_states
56
56
 
57
57
  # 2. Cross-Attention
58
58
  norm_hidden_states = self.norm2(hidden_states)
59
- attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
59
+ attn_output = self.attn2(norm_hidden_states, y=encoder_hidden_states)
60
60
  hidden_states = attn_output + hidden_states
61
61
 
62
62
  # 3. Feed-forward