diffsynth-engine 0.6.1.dev23__tar.gz → 0.6.1.dev25__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (210)
  1. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/PKG-INFO +1 -1
  2. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/__init__.py +6 -2
  3. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/configs/__init__.py +10 -6
  4. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/configs/pipeline.py +2 -25
  5. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/basic/transformer_helper.py +36 -2
  6. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/basic/video_sparse_attention.py +4 -1
  7. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/qwen_image/qwen_image_dit.py +13 -17
  8. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/pipelines/base.py +30 -2
  9. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/pipelines/flux_image.py +2 -2
  10. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/pipelines/qwen_image.py +17 -7
  11. diffsynth_engine-0.6.1.dev25/diffsynth_engine/pipelines/utils.py +71 -0
  12. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/pipelines/wan_s2v.py +1 -1
  13. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/pipelines/wan_video.py +8 -4
  14. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tokenizers/base.py +6 -0
  15. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tokenizers/qwen2.py +12 -4
  16. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/parallel.py +6 -7
  17. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine.egg-info/PKG-INFO +1 -1
  18. diffsynth_engine-0.6.1.dev23/diffsynth_engine/pipelines/utils.py +0 -19
  19. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/.gitattributes +0 -0
  20. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/.gitignore +0 -0
  21. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/.pre-commit-config.yaml +0 -0
  22. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/LICENSE +0 -0
  23. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/MANIFEST.in +0 -0
  24. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/README.md +0 -0
  25. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/assets/dingtalk.png +0 -0
  26. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/assets/showcase.jpeg +0 -0
  27. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/assets/tongyi.svg +0 -0
  28. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/__init__.py +0 -0
  29. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  30. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  31. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  32. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  33. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  34. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  35. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  36. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  37. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  38. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  39. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  40. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  41. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  42. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  43. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  44. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  45. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  46. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  47. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  48. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  49. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  50. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  51. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  52. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  53. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  54. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  55. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/components/vae.json +0 -0
  56. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  57. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  58. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  59. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json +0 -0
  60. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_vision_config.json +0 -0
  61. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae.json +0 -0
  62. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae_keymap.json +0 -0
  63. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  64. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  65. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  66. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  67. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  68. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  69. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/wan/dit/wan2.1_flf2v_14b.json +0 -0
  70. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/wan/dit/wan2.1_i2v_14b.json +0 -0
  71. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_1.3b.json +0 -0
  72. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_14b.json +0 -0
  73. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/wan/dit/wan2.2_i2v_a14b.json +0 -0
  74. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/wan/dit/wan2.2_s2v_14b.json +0 -0
  75. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/wan/dit/wan2.2_t2v_a14b.json +0 -0
  76. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/wan/dit/wan2.2_ti2v_5b.json +0 -0
  77. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +0 -0
  78. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/wan/vae/wan2.1_vae.json +0 -0
  79. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/wan/vae/wan2.2_vae.json +0 -0
  80. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/models/wan/vae/wan_vae_keymap.json +0 -0
  81. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  82. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  83. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  84. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  85. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  86. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  87. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  88. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  89. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json +0 -0
  90. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/added_tokens.json +0 -0
  91. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/merges.txt +0 -0
  92. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/special_tokens_map.json +0 -0
  93. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer.json +0 -0
  94. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer_config.json +0 -0
  95. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/vocab.json +0 -0
  96. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  97. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  98. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  99. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  100. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  101. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  102. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  103. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  104. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  105. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  106. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  107. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  108. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/configs/controlnet.py +0 -0
  109. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/kernels/__init__.py +0 -0
  110. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/__init__.py +0 -0
  111. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/base.py +0 -0
  112. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/basic/__init__.py +0 -0
  113. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/basic/attention.py +0 -0
  114. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/basic/lora.py +0 -0
  115. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  116. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/basic/timestep.py +0 -0
  117. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  118. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/flux/__init__.py +0 -0
  119. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/flux/flux_controlnet.py +0 -0
  120. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/flux/flux_dit.py +0 -0
  121. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/flux/flux_dit_fbcache.py +0 -0
  122. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/flux/flux_ipadapter.py +0 -0
  123. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/flux/flux_redux.py +0 -0
  124. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  125. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  126. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/hunyuan3d/__init__.py +0 -0
  127. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +0 -0
  128. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/hunyuan3d/hunyuan3d_dit.py +0 -0
  129. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/hunyuan3d/hunyuan3d_vae.py +0 -0
  130. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/hunyuan3d/moe.py +0 -0
  131. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/hunyuan3d/surface_extractor.py +0 -0
  132. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/hunyuan3d/volume_decoder.py +0 -0
  133. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/qwen_image/__init__.py +0 -0
  134. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/qwen_image/qwen2_5_vl.py +0 -0
  135. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +0 -0
  136. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/qwen_image/qwen_image_vae.py +0 -0
  137. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sd/__init__.py +0 -0
  138. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sd/sd_controlnet.py +0 -0
  139. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  140. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  141. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  142. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sd3/__init__.py +0 -0
  143. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  144. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  145. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  146. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  147. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +0 -0
  148. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  149. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
  150. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  151. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
  152. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/text_encoder/clip.py +0 -0
  153. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
  154. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/text_encoder/t5.py +0 -0
  155. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/vae/__init__.py +0 -0
  156. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/vae/vae.py +0 -0
  157. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/wan/__init__.py +0 -0
  158. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/wan/wan_audio_encoder.py +0 -0
  159. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/wan/wan_dit.py +0 -0
  160. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  161. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/wan/wan_s2v_dit.py +0 -0
  162. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  163. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  164. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/pipelines/__init__.py +0 -0
  165. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/pipelines/hunyuan3d_shape.py +0 -0
  166. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/pipelines/sd_image.py +0 -0
  167. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/pipelines/sdxl_image.py +0 -0
  168. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/processor/__init__.py +0 -0
  169. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/processor/canny_processor.py +0 -0
  170. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/processor/depth_processor.py +0 -0
  171. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tokenizers/__init__.py +0 -0
  172. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tokenizers/clip.py +0 -0
  173. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tokenizers/qwen2_vl_image_processor.py +0 -0
  174. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tokenizers/qwen2_vl_processor.py +0 -0
  175. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tokenizers/t5.py +0 -0
  176. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tokenizers/wan.py +0 -0
  177. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tools/__init__.py +0 -0
  178. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tools/flux_inpainting_tool.py +0 -0
  179. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tools/flux_outpainting_tool.py +0 -0
  180. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tools/flux_reference_tool.py +0 -0
  181. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/tools/flux_replace_tool.py +0 -0
  182. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/__init__.py +0 -0
  183. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/cache.py +0 -0
  184. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/constants.py +0 -0
  185. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/download.py +0 -0
  186. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/env.py +0 -0
  187. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/flag.py +0 -0
  188. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/fp8_linear.py +0 -0
  189. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/gguf.py +0 -0
  190. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/image.py +0 -0
  191. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/loader.py +0 -0
  192. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/lock.py +0 -0
  193. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/logging.py +0 -0
  194. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/memory/__init__.py +0 -0
  195. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/memory/linear_regression.py +0 -0
  196. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/memory/memory_predcit_model.py +0 -0
  197. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/offload.py +0 -0
  198. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/onnx.py +0 -0
  199. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/platform.py +0 -0
  200. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/prompt.py +0 -0
  201. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine/utils/video.py +0 -0
  202. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine.egg-info/SOURCES.txt +0 -0
  203. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  204. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine.egg-info/requires.txt +0 -0
  205. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/diffsynth_engine.egg-info/top_level.txt +0 -0
  206. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/docs/tutorial.md +0 -0
  207. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/docs/tutorial_zh.md +0 -0
  208. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/pyproject.toml +0 -0
  209. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/setup.cfg +0 -0
  210. {diffsynth_engine-0.6.1.dev23 → diffsynth_engine-0.6.1.dev25}/setup.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.6.1.dev23
+Version: 0.6.1.dev25
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent

diffsynth_engine/__init__.py
@@ -12,11 +12,13 @@ from .configs import (
     WanStateDicts,
     QwenImageStateDicts,
     AttnImpl,
+    SpargeAttentionParams,
+    VideoSparseAttentionParams,
+    LoraConfig,
     ControlNetParams,
     ControlType,
     QwenImageControlNetParams,
     QwenImageControlType,
-    LoraConfig,
 )
 from .pipelines import (
     SDImagePipeline,
@@ -59,6 +61,9 @@ __all__ = [
     "WanStateDicts",
     "QwenImageStateDicts",
     "AttnImpl",
+    "SpargeAttentionParams",
+    "VideoSparseAttentionParams",
+    "LoraConfig",
     "ControlNetParams",
     "ControlType",
     "QwenImageControlNetParams",
@@ -79,7 +84,6 @@ __all__ = [
     "FluxIPAdapterRefTool",
     "FluxReplaceByControlTool",
     "FluxReduxRefTool",
-    "LoraConfig",
     "fetch_model",
     "fetch_modelscope_model",
     "register_fetch_modelscope_model",

diffsynth_engine/configs/__init__.py
@@ -17,14 +17,16 @@ from .pipeline import (
     WanStateDicts,
     WanS2VStateDicts,
     QwenImageStateDicts,
-    LoraConfig,
     AttnImpl,
+    SpargeAttentionParams,
+    VideoSparseAttentionParams,
+    LoraConfig,
 )
 from .controlnet import (
     ControlType,
     ControlNetParams,
-    QwenImageControlNetParams,
     QwenImageControlType,
+    QwenImageControlNetParams,
 )

 __all__ = [
@@ -46,10 +48,12 @@ __all__ = [
     "WanStateDicts",
     "WanS2VStateDicts",
     "QwenImageStateDicts",
-    "QwenImageControlType",
-    "QwenImageControlNetParams",
+    "AttnImpl",
+    "SpargeAttentionParams",
+    "VideoSparseAttentionParams",
+    "LoraConfig",
     "ControlType",
     "ControlNetParams",
-    "LoraConfig",
-    "AttnImpl",
+    "QwenImageControlType",
+    "QwenImageControlNetParams",
 ]

diffsynth_engine/configs/pipeline.py
@@ -5,7 +5,6 @@ from dataclasses import dataclass, field
 from typing import List, Dict, Tuple, Optional

 from diffsynth_engine.configs.controlnet import ControlType
-from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs


 @dataclass
@@ -52,23 +51,6 @@ class AttentionConfig:
     dit_attn_impl: AttnImpl = AttnImpl.AUTO
     attn_params: Optional[SpargeAttentionParams | VideoSparseAttentionParams] = None

-    def get_attn_kwargs(self, latents: torch.Tensor, device: str) -> Dict:
-        attn_kwargs = {"attn_impl": self.dit_attn_impl.value}
-        if isinstance(self.attn_params, SpargeAttentionParams):
-            assert self.dit_attn_impl == AttnImpl.SPARGE
-            attn_kwargs.update(
-                {
-                    "smooth_k": self.attn_params.smooth_k,
-                    "simthreshd1": self.attn_params.simthreshd1,
-                    "cdfthreshd": self.attn_params.cdfthreshd,
-                    "pvthreshd": self.attn_params.pvthreshd,
-                }
-            )
-        elif isinstance(self.attn_params, VideoSparseAttentionParams):
-            assert self.dit_attn_impl == AttnImpl.VSA
-            attn_kwargs.update(get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.attn_params.sparsity, device=device))
-        return attn_kwargs
-

 @dataclass
 class OptimizationConfig:
@@ -262,16 +244,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
     encoder_dtype: torch.dtype = torch.bfloat16
     vae_dtype: torch.dtype = torch.float32

+    load_encoder: bool = True
+
     # override OptimizationConfig
     fbcache_relative_l1_threshold = 0.009

-    # override BaseConfig
-    vae_tiled: bool = True
-    vae_tile_size: Tuple[int, int] = (34, 34)
-    vae_tile_stride: Tuple[int, int] = (18, 16)
-
-    load_encoder: bool = True
-
     @classmethod
     def basic_config(
         cls,

diffsynth_engine/models/basic/transformer_helper.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import math


@@ -91,8 +92,8 @@ class NewGELUActivation(nn.Module):
     the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
     """

-    def forward(self, input: "torch.Tensor") -> "torch.Tensor":
-        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


 class ApproximateGELU(nn.Module):
@@ -115,3 +116,36 @@ class ApproximateGELU(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         return x * torch.sigmoid(1.702 * x)
+
+
+class GELU(nn.Module):
+    r"""
+    GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(
+        self,
+        dim_in: int,
+        dim_out: int,
+        approximate: str = "none",
+        bias: bool = True,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias, device=device, dtype=dtype)
+        self.approximate = approximate
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        return F.gelu(gate, approximate=self.approximate)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        x = self.gelu(x)
+        return x
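
Note: the new GELU module with approximate="tanh" delegates to PyTorch's built-in tanh approximation, which is the same formula NewGELUActivation writes out by hand. A quick stand-alone check (illustrative only, not part of the package):

import math
import torch
import torch.nn.functional as F

# The tanh-approximated GELU used by GELU(..., approximate="tanh") matches the
# explicit formula from NewGELUActivation.
x = torch.randn(4, 8)
manual = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
builtin = F.gelu(x, approximate="tanh")
assert torch.allclose(manual, builtin, atol=1e-5)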

diffsynth_engine/models/basic/video_sparse_attention.py
@@ -2,9 +2,12 @@ import torch
 import math
 import functools

-from vsa import video_sparse_attn as vsa_core
+from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
 from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size

+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    from vsa import video_sparse_attn as vsa_core
+
 VSA_TILE_SIZE = (4, 4, 4)

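
Note: guarding the `vsa` import behind VIDEO_SPARSE_ATTN_AVAILABLE makes the module importable on machines without the optional `vsa` package. The flag's definition is not part of this diff; a minimal sketch of how such a flag is commonly derived (an assumption, not the package's actual code):

import importlib.util

# Hypothetical availability flag; the real definition lives in
# diffsynth_engine/utils/flag.py and may differ.
VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None

if VIDEO_SPARSE_ATTN_AVAILABLE:
    from vsa import video_sparse_attn as vsa_core  # imported only when vsa is installed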

diffsynth_engine/models/qwen_image/qwen_image_dit.py
@@ -6,7 +6,7 @@ from einops import rearrange
 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
 from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
-from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, ApproximateGELU, RMSNorm
+from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, GELU, RMSNorm
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.parallel import (
@@ -144,7 +144,7 @@ class QwenFeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * 4)
         self.net = nn.ModuleList([])
-        self.net.append(ApproximateGELU(dim, inner_dim, device=device, dtype=dtype))
+        self.net.append(GELU(dim, inner_dim, approximate="tanh", device=device, dtype=dtype))
         self.net.append(nn.Dropout(dropout))
         self.net.append(nn.Linear(inner_dim, dim_out, device=device, dtype=dtype))

@@ -155,8 +155,8 @@


 def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
-    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # (b, s, h, d) -> (b, s, h, d/2, 2)
+    x_out = torch.view_as_real(x_rotated * freqs_cis.unsqueeze(1)).flatten(3)  # (b, s, h, d/2, 2) -> (b, s, h, d)
     return x_out.type_as(x)


@@ -200,13 +200,13 @@ class QwenDoubleStreamAttention(nn.Module):
         img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
         txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)

-        img_q = rearrange(img_q, "b s (h d) -> b h s d", h=self.num_heads)
-        img_k = rearrange(img_k, "b s (h d) -> b h s d", h=self.num_heads)
-        img_v = rearrange(img_v, "b s (h d) -> b h s d", h=self.num_heads)
+        img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads)
+        img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads)
+        img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads)

-        txt_q = rearrange(txt_q, "b s (h d) -> b h s d", h=self.num_heads)
-        txt_k = rearrange(txt_k, "b s (h d) -> b h s d", h=self.num_heads)
-        txt_v = rearrange(txt_v, "b s (h d) -> b h s d", h=self.num_heads)
+        txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads)

         img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
         txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
@@ -218,13 +218,9 @@
         txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
         txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)

-        joint_q = torch.cat([txt_q, img_q], dim=2)
-        joint_k = torch.cat([txt_k, img_k], dim=2)
-        joint_v = torch.cat([txt_v, img_v], dim=2)
-
-        joint_q = joint_q.transpose(1, 2)
-        joint_k = joint_k.transpose(1, 2)
-        joint_v = joint_v.transpose(1, 2)
+        joint_q = torch.cat([txt_q, img_q], dim=1)
+        joint_k = torch.cat([txt_k, img_k], dim=1)
+        joint_v = torch.cat([txt_v, img_v], dim=1)

         attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
         joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)
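
Note: after this refactor QwenDoubleStreamAttention keeps q/k/v in (batch, sequence, heads, head_dim) layout end to end, so text and image tokens are concatenated along the sequence axis (dim=1) and the extra transposes before the attention call are gone; the rotary frequencies are broadcast over the head axis via unsqueeze(1). A small shape check with made-up sizes (illustrative only):

import torch
from einops import rearrange

b, s_img, s_txt, h, d = 2, 16, 8, 4, 32
img_q = rearrange(torch.randn(b, s_img, h * d), "b s (h d) -> b s h d", h=h)
txt_q = rearrange(torch.randn(b, s_txt, h * d), "b s (h d) -> b s h d", h=h)
# Text and image tokens now share the sequence axis, so concatenation happens on dim=1.
joint_q = torch.cat([txt_q, img_q], dim=1)
assert joint_q.shape == (b, s_txt + s_img, h, d)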

diffsynth_engine/pipelines/base.py
@@ -5,7 +5,15 @@ from einops import rearrange
 from typing import Dict, List, Tuple, Union, Optional
 from PIL import Image

-from diffsynth_engine.configs import BaseConfig, BaseStateDicts, LoraConfig
+from diffsynth_engine.configs import (
+    BaseConfig,
+    BaseStateDicts,
+    LoraConfig,
+    AttnImpl,
+    SpargeAttentionParams,
+    VideoSparseAttentionParams,
+)
+from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
 from diffsynth_engine.utils.offload import enable_sequential_cpu_offload, offload_model_to_dict, restore_model_from_dict
 from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
 from diffsynth_engine.utils.gguf import load_gguf_checkpoint
@@ -33,6 +41,7 @@ class BasePipeline:
         dtype=torch.float16,
     ):
         super().__init__()
+        self.config = None
         self.vae_tiled = vae_tiled
         self.vae_tile_size = vae_tile_size
         self.vae_tile_stride = vae_tile_stride
@@ -48,7 +57,7 @@ class BasePipeline:
         raise NotImplementedError()

     @classmethod
-    def from_state_dict(cls, state_dicts: BaseStateDicts, pipeline_config: BaseConfig) -> "BasePipeline":
+    def from_state_dict(cls, state_dicts: BaseStateDicts, config: BaseConfig) -> "BasePipeline":
         raise NotImplementedError()

     def update_weights(self, state_dicts: BaseStateDicts) -> None:
@@ -260,6 +269,25 @@ class BasePipeline:
         )
         return init_latents, latents, sigmas, timesteps

+    def get_attn_kwargs(self, latents: torch.Tensor) -> Dict:
+        attn_kwargs = {"attn_impl": self.config.dit_attn_impl.value}
+        if isinstance(self.config.attn_params, SpargeAttentionParams):
+            assert self.config.dit_attn_impl == AttnImpl.SPARGE
+            attn_kwargs.update(
+                {
+                    "smooth_k": self.config.attn_params.smooth_k,
+                    "simthreshd1": self.config.attn_params.simthreshd1,
+                    "cdfthreshd": self.config.attn_params.cdfthreshd,
+                    "pvthreshd": self.config.attn_params.pvthreshd,
+                }
+            )
+        elif isinstance(self.config.attn_params, VideoSparseAttentionParams):
+            assert self.config.dit_attn_impl == AttnImpl.VSA
+            attn_kwargs.update(
+                get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.config.attn_params.sparsity, device=self.device)
+            )
+        return attn_kwargs
+
     def eval(self):
         for model_name in self.model_names:
             model = getattr(self, model_name)

diffsynth_engine/pipelines/flux_image.py
@@ -751,7 +751,7 @@ class FluxImagePipeline(BasePipeline):
         latents = latents.to(self.dtype)
         self.load_models_to_device(["dit"])

-        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)
         noise_pred = self.dit(
             hidden_states=latents,
             timestep=timestep,
@@ -886,7 +886,7 @@
             empty_cache()
             param.model.to(self.device)

-            attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
+            attn_kwargs = self.get_attn_kwargs(latents)
             double_block_output, single_block_output = param.model(
                 hidden_states=latents,
                 control_condition=control_condition,

diffsynth_engine/pipelines/qwen_image.py
@@ -24,7 +24,7 @@ from diffsynth_engine.models.qwen_image import (
 from diffsynth_engine.models.qwen_image import QwenImageVAE
 from diffsynth_engine.tokenizers import Qwen2TokenizerFast, Qwen2VLProcessor
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
-from diffsynth_engine.pipelines.utils import calculate_shift
+from diffsynth_engine.pipelines.utils import calculate_shift, pad_and_concat
 from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
 from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
 from diffsynth_engine.utils.constants import (
@@ -148,9 +148,17 @@ class QwenImagePipeline(BasePipeline):
         self.prompt_template_encode_start_idx = 34
         # qwen image edit
         self.edit_system_prompt = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
-        self.edit_prompt_template_encode = "<|im_start|>system\n" + self.edit_system_prompt + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        self.edit_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        )
         # qwen image edit plus
-        self.edit_plus_prompt_template_encode = "<|im_start|>system\n" + self.edit_system_prompt + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        self.edit_plus_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        )

         self.edit_prompt_template_encode_start_idx = 64

@@ -200,7 +208,9 @@
         )
         if config.load_encoder:
             logger.info(f"loading state dict from {config.encoder_path} ...")
-            encoder_state_dict = cls.load_model_checkpoint(config.encoder_path, device="cpu", dtype=config.encoder_dtype)
+            encoder_state_dict = cls.load_model_checkpoint(
+                config.encoder_path, device="cpu", dtype=config.encoder_dtype
+            )

         state_dicts = QwenImageStateDicts(
             model=model_state_dict,
@@ -490,8 +500,8 @@
         else:
             # cfg by predict noise in one batch
             bs, _, h, w = latents.shape
-            prompt_emb = torch.cat([prompt_emb, negative_prompt_emb], dim=0)
-            prompt_emb_mask = torch.cat([prompt_emb_mask, negative_prompt_emb_mask], dim=0)
+            prompt_emb = pad_and_concat(prompt_emb, negative_prompt_emb)
+            prompt_emb_mask = pad_and_concat(prompt_emb_mask, negative_prompt_emb_mask)
             if entity_prompt_embs is not None:
                 entity_prompt_embs = [
                     torch.cat([x, y], dim=0) for x, y in zip(entity_prompt_embs, negative_entity_prompt_embs)
@@ -539,7 +549,7 @@
         entity_masks: Optional[List[torch.Tensor]] = None,
     ):
         self.load_models_to_device(["dit"])
-        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)
         noise_pred = self.dit(
             image=latents,
             edit=image_latents,

diffsynth_engine-0.6.1.dev25/diffsynth_engine/pipelines/utils.py (new file)
@@ -0,0 +1,71 @@
+import torch
+import torch.nn.functional as F
+
+
+def accumulate(result, new_item):
+    if result is None:
+        return new_item
+    for i, item in enumerate(new_item):
+        result[i] += item
+    return result
+
+
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+
+def pad_and_concat(
+    tensor1: torch.Tensor,
+    tensor2: torch.Tensor,
+    concat_dim: int = 0,
+    pad_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Concatenate two tensors along a specified dimension after padding along another dimension.
+
+    Assumes input tensors have shape (b, s, d), where:
+    - b: batch dimension
+    - s: sequence dimension (may differ)
+    - d: feature dimension
+
+    Args:
+        tensor1: First tensor with shape (b1, s1, d)
+        tensor2: Second tensor with shape (b2, s2, d)
+        concat_dim: Dimension to concatenate along, default is 0 (batch dimension)
+        pad_dim: Dimension to pad along, default is 1 (sequence dimension)
+
+    Returns:
+        Concatenated tensor, shape depends on concat_dim and pad_dim choices
+    """
+    assert tensor1.dim() == tensor2.dim(), "Both tensors must have the same number of dimensions"
+    assert concat_dim != pad_dim, "concat_dim and pad_dim cannot be the same"
+
+    len1, len2 = tensor1.shape[pad_dim], tensor2.shape[pad_dim]
+    max_len = max(len1, len2)
+
+    # Calculate the position of pad_dim in the padding list
+    # Padding format: from the last dimension, each pair represents (dim_n_left, dim_n_right, ..., dim_0_left, dim_0_right)
+    ndim = tensor1.dim()
+    padding = [0] * (2 * ndim)
+    pad_right_idx = -2 * pad_dim - 1
+
+    if len1 < max_len:
+        pad_len = max_len - len1
+        padding[pad_right_idx] = pad_len
+        tensor1 = F.pad(tensor1, padding, mode="constant", value=0)
+    elif len2 < max_len:
+        pad_len = max_len - len2
+        padding[pad_right_idx] = pad_len
+        tensor2 = F.pad(tensor2, padding, mode="constant", value=0)
+
+    # Concatenate along the specified dimension
+    return torch.cat([tensor1, tensor2], dim=concat_dim)
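
Note: pad_and_concat is what lets the QwenImagePipeline CFG branch batch positive and negative prompt embeddings whose sequence lengths differ: the shorter tensor is zero-padded on the right along dim=1, then the two are stacked along dim=0. Illustrative usage with made-up shapes:

import torch
from diffsynth_engine.pipelines.utils import pad_and_concat

prompt_emb = torch.randn(1, 20, 3584)           # hypothetical (batch, seq, dim)
negative_prompt_emb = torch.randn(1, 13, 3584)  # shorter sequence
merged = pad_and_concat(prompt_emb, negative_prompt_emb)  # default pad_dim=1, concat_dim=0
assert merged.shape == (2, 20, 3584)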

diffsynth_engine/pipelines/wan_s2v.py
@@ -394,7 +394,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         void_audio_input: torch.Tensor | None = None,
     ):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
-        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)

         noise_pred = model(
             x=latents,

diffsynth_engine/pipelines/wan_video.py
@@ -144,7 +144,7 @@ class WanVideoPipeline(BasePipeline):
         lora_list: List[Tuple[str, float]],
         fused: bool = True,
         save_original_weight: bool = False,
-        lora_converter: Optional[WanLoRAConverter] = None
+        lora_converter: Optional[WanLoRAConverter] = None,
     ):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
             "load LoRA is not allowed when tensor parallel is enabled; "
@@ -156,11 +156,15 @@
         )
         super().load_loras(lora_list, fused, save_original_weight, lora_converter)

-    def load_loras_low_noise(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
+    def load_loras_low_noise(
+        self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False
+    ):
         assert self.dit2 is not None, "low noise LoRA can only be applied to Wan2.2"
         self.load_loras(lora_list, fused, save_original_weight, self.low_noise_lora_converter)

-    def load_loras_high_noise(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
+    def load_loras_high_noise(
+        self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False
+    ):
         assert self.dit2 is not None, "high noise LoRA can only be applied to Wan2.2"
         self.load_loras(lora_list, fused, save_original_weight)

@@ -323,7 +327,7 @@

     def predict_noise(self, model, latents, image_clip_feature, image_y, timestep, context):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
-        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)

         noise_pred = model(
             x=latents,

diffsynth_engine/tokenizers/base.py
@@ -1,10 +1,16 @@
 # Modified from transformers.tokenization_utils_base
 from typing import Dict, List, Union, overload
+from enum import Enum


 TOKENIZER_CONFIG_FILE = "tokenizer_config.json"


+class PaddingStrategy(str, Enum):
+    LONGEST = "longest"
+    MAX_LENGTH = "max_length"
+
+
 class BaseTokenizer:
     SPECIAL_TOKENS_ATTRIBUTES = [
         "bos_token",

diffsynth_engine/tokenizers/qwen2.py
@@ -4,7 +4,7 @@ import torch
 from typing import Dict, List, Union, Optional
 from tokenizers import Tokenizer as TokenizerFast, AddedToken

-from diffsynth_engine.tokenizers.base import BaseTokenizer, TOKENIZER_CONFIG_FILE
+from diffsynth_engine.tokenizers.base import BaseTokenizer, PaddingStrategy, TOKENIZER_CONFIG_FILE


 VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
@@ -165,22 +165,28 @@ class Qwen2TokenizerFast(BaseTokenizer):
         texts: Union[str, List[str]],
         max_length: Optional[int] = None,
         padding_side: Optional[str] = None,
+        padding_strategy: Union[PaddingStrategy, str] = "longest",
         **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
         """
         Tokenize text and prepare for model inputs.

         Args:
-            text (`str`, `List[str]`, *optional*):
+            texts (`str`, `List[str]`):
                 The sequence or batch of sequences to be encoded.

             max_length (`int`, *optional*):
-                Each encoded sequence will be truncated or padded to max_length.
+                Maximum length of the encoded sequences.

             padding_side (`str`, *optional*):
                 The side on which the padding should be applied. Should be selected between `"right"` and `"left"`.
                 Defaults to `"right"`.

+            padding_strategy (`PaddingStrategy`, `str`, *optional*):
+                If `"longest"`, will pad the sequences to the longest sequence in the batch.
+                If `"max_length"`, will pad the sequences to the `max_length` argument.
+                Defaults to `"longest"`.
+
         Returns:
             `Dict[str, "torch.Tensor"]`: tensor dict compatible with model_input_names.
         """
@@ -190,7 +196,9 @@ class Qwen2TokenizerFast(BaseTokenizer):

         batch_ids = self.batch_encode(texts)
         ids_lens = [len(ids_) for ids_ in batch_ids]
-        max_length = max_length if max_length is not None else min(max(ids_lens), self.model_max_length)
+        max_length = max_length if max_length is not None else self.model_max_length
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = min(max(ids_lens), max_length)
         padding_side = padding_side if padding_side is not None else self.padding_side

         encoded = torch.zeros(len(texts), max_length, dtype=torch.long)
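
Note: with the new padding_strategy argument, "longest" (the default) keeps the old behavior of padding to the longest sequence in the batch, capped by model_max_length, while "max_length" pads every batch to the same fixed length. A stand-alone sketch of the length resolution, mirroring the lines above rather than importing the tokenizer:

def resolve_pad_length(ids_lens, model_max_length, max_length=None, padding_strategy="longest"):
    # max_length falls back to the tokenizer's model_max_length
    max_length = max_length if max_length is not None else model_max_length
    if padding_strategy == "longest":
        # pad only as far as the longest sequence in this batch
        max_length = min(max(ids_lens), max_length)
    return max_length

assert resolve_pad_length([7, 12, 9], model_max_length=32768) == 12
assert resolve_pad_length([7, 12, 9], model_max_length=32768, padding_strategy="max_length") == 32768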

diffsynth_engine/utils/parallel.py
@@ -19,8 +19,6 @@ from typing import Dict, List, Set, Type, Union, Optional
 from queue import Empty

 import diffsynth_engine.models.basic.attention as attention_ops
-from diffsynth_engine.models import PreTrainedModel
-from diffsynth_engine.pipelines import BasePipeline
 from diffsynth_engine.utils.platform import empty_cache
 from diffsynth_engine.utils import logging

@@ -300,14 +298,15 @@ def _worker_loop(
         world_size=world_size,
     )

-    def wrap_for_parallel(module: Union[PreTrainedModel, BasePipeline]):
-        if isinstance(module, BasePipeline):
-            for model_name in module.model_names:
-                if isinstance(submodule := getattr(module, model_name), PreTrainedModel):
+    def wrap_for_parallel(module):
+        if hasattr(module, "model_names"):
+            for model_name in getattr(module, "model_names"):
+                submodule = getattr(module, model_name)
+                if getattr(submodule, "_supports_parallelization", False):
                     setattr(module, model_name, wrap_for_parallel(submodule))
             return module

-        if not module._supports_parallelization:
+        if not getattr(module, "_supports_parallelization", False):
             return module

         if tp_degree > 1:

diffsynth_engine.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.6.1.dev23
+Version: 0.6.1.dev25
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent

diffsynth_engine-0.6.1.dev23/diffsynth_engine/pipelines/utils.py (old file, removed)
@@ -1,19 +0,0 @@
-def accumulate(result, new_item):
-    if result is None:
-        return new_item
-    for i, item in enumerate(new_item):
-        result[i] += item
-    return result
-
-
-def calculate_shift(
-    image_seq_len,
-    base_seq_len: int = 256,
-    max_seq_len: int = 4096,
-    base_shift: float = 0.5,
-    max_shift: float = 1.15,
-):
-    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
-    b = base_shift - m * base_seq_len
-    mu = image_seq_len * m + b
-    return mu