diffsynth-engine 0.3.6.dev8__tar.gz → 0.3.6.dev10__tar.gz

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (166)
  1. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/PKG-INFO +1 -1
  2. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/__init__.py +10 -8
  3. diffsynth_engine-0.3.6.dev10/diffsynth_engine/configs/__init__.py +23 -0
  4. diffsynth_engine-0.3.6.dev10/diffsynth_engine/configs/controlnet.py +17 -0
  5. diffsynth_engine-0.3.6.dev10/diffsynth_engine/configs/pipeline.py +206 -0
  6. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/attention.py +43 -4
  7. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_controlnet.py +8 -5
  8. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_dit.py +22 -16
  9. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_dit_fbcache.py +5 -5
  10. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  11. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd/sd_controlnet.py +2 -4
  12. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +1 -2
  13. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/wan/wan_dit.py +15 -15
  14. diffsynth_engine-0.3.6.dev10/diffsynth_engine/pipelines/__init__.py +17 -0
  15. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/pipelines/base.py +14 -65
  16. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/pipelines/flux_image.py +85 -158
  17. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/pipelines/sd_image.py +30 -64
  18. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/pipelines/sdxl_image.py +39 -71
  19. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/pipelines/wan_video.py +66 -105
  20. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tools/flux_inpainting_tool.py +7 -3
  21. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tools/flux_outpainting_tool.py +7 -3
  22. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tools/flux_reference_tool.py +21 -5
  23. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tools/flux_replace_tool.py +15 -3
  24. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/fp8_linear.py +14 -5
  25. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/parallel.py +1 -1
  26. diffsynth_engine-0.3.6.dev10/diffsynth_engine/utils/platform.py +20 -0
  27. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine.egg-info/PKG-INFO +1 -1
  28. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine.egg-info/SOURCES.txt +3 -0
  29. diffsynth_engine-0.3.6.dev8/diffsynth_engine/pipelines/__init__.py +0 -20
  30. diffsynth_engine-0.3.6.dev8/diffsynth_engine/utils/platform.py +0 -12
  31. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/.gitignore +0 -0
  32. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/.pre-commit-config.yaml +0 -0
  33. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/LICENSE +0 -0
  34. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/MANIFEST.in +0 -0
  35. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/README.md +0 -0
  36. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/assets/dingtalk.png +0 -0
  37. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/assets/showcase.jpeg +0 -0
  38. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/__init__.py +0 -0
  39. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  40. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  41. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  42. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  43. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  44. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  45. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  46. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  47. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  48. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  49. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  50. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  51. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  52. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  53. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  54. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  55. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  56. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  57. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  58. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  59. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  60. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  61. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  62. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  63. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  64. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  65. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/components/vae.json +0 -0
  66. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  67. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  68. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  69. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  70. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  71. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  72. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  73. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  74. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  75. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +0 -0
  76. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/wan/dit/14b-flf2v.json +0 -0
  77. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/wan/dit/14b-i2v.json +0 -0
  78. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/wan/dit/14b-t2v.json +0 -0
  79. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  80. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  81. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  82. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  83. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  84. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  85. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  86. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  87. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  88. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  89. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  90. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  91. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  92. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  93. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  94. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  95. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  96. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  97. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  98. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  99. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/kernels/__init__.py +0 -0
  100. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/__init__.py +0 -0
  101. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/base.py +0 -0
  102. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/__init__.py +0 -0
  103. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/lora.py +0 -0
  104. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  105. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/timestep.py +0 -0
  106. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
  107. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  108. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/__init__.py +0 -0
  109. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_redux.py +0 -0
  110. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  111. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  112. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd/__init__.py +0 -0
  113. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  114. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  115. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  116. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd3/__init__.py +0 -0
  117. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  118. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  119. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  120. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  121. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  122. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
  123. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  124. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
  125. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/text_encoder/clip.py +0 -0
  126. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
  127. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/text_encoder/t5.py +0 -0
  128. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/utils.py +0 -0
  129. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/vae/__init__.py +0 -0
  130. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/vae/vae.py +0 -0
  131. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/wan/__init__.py +0 -0
  132. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  133. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  134. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  135. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/pipelines/controlnet_helper.py +0 -0
  136. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/processor/__init__.py +0 -0
  137. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/processor/canny_processor.py +0 -0
  138. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/processor/depth_processor.py +0 -0
  139. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tokenizers/__init__.py +0 -0
  140. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tokenizers/base.py +0 -0
  141. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tokenizers/clip.py +0 -0
  142. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tokenizers/t5.py +0 -0
  143. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tokenizers/wan.py +0 -0
  144. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tools/__init__.py +0 -0
  145. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/__init__.py +0 -0
  146. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/constants.py +0 -0
  147. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/download.py +0 -0
  148. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/env.py +0 -0
  149. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/flag.py +0 -0
  150. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/gguf.py +0 -0
  151. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/image.py +0 -0
  152. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/loader.py +0 -0
  153. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/lock.py +0 -0
  154. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/logging.py +0 -0
  155. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/offload.py +0 -0
  156. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/onnx.py +0 -0
  157. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/prompt.py +0 -0
  158. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/video.py +0 -0
  159. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  160. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine.egg-info/requires.txt +0 -0
  161. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine.egg-info/top_level.txt +0 -0
  162. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/docs/tutorial.md +0 -0
  163. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/docs/tutorial_zh.md +0 -0
  164. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/pyproject.toml +0 -0
  165. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/setup.cfg +0 -0
  166. {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/setup.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.3.6.dev8
+Version: 0.3.6.dev10
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent

diffsynth_engine/__init__.py
@@ -1,12 +1,14 @@
+from .configs import (
+    SDPipelineConfig,
+    SDXLPipelineConfig,
+    FluxPipelineConfig,
+    WanPipelineConfig,
+)
 from .pipelines import (
     FluxImagePipeline,
     SDXLImagePipeline,
     SDImagePipeline,
     WanVideoPipeline,
-    FluxModelConfig,
-    SDXLModelConfig,
-    SDModelConfig,
-    WanModelConfig,
     ControlNetParams,
 )
 from .models.flux import FluxControlNet, FluxIPAdapter, FluxRedux
@@ -23,6 +25,10 @@ from .tools import (
 )

 __all__ = [
+    "SDPipelineConfig",
+    "SDXLPipelineConfig",
+    "FluxPipelineConfig",
+    "WanPipelineConfig",
     "FluxImagePipeline",
     "FluxControlNet",
     "FluxIPAdapter",
@@ -32,10 +38,6 @@ __all__ = [
     "SDXLImagePipeline",
     "SDImagePipeline",
     "WanVideoPipeline",
-    "FluxModelConfig",
-    "SDXLModelConfig",
-    "SDModelConfig",
-    "WanModelConfig",
    "FluxInpaintingTool",
     "FluxOutpaintingTool",
     "FluxIPAdapterRefTool",

diffsynth_engine/configs/__init__.py (new file)
@@ -0,0 +1,23 @@
+from .pipeline import (
+    BaseConfig,
+    AttentionConfig,
+    OptimizationConfig,
+    ParallelConfig,
+    SDPipelineConfig,
+    SDXLPipelineConfig,
+    FluxPipelineConfig,
+    WanPipelineConfig,
+)
+from .controlnet import ControlType
+
+__all__ = [
+    "BaseConfig",
+    "AttentionConfig",
+    "OptimizationConfig",
+    "ParallelConfig",
+    "SDPipelineConfig",
+    "SDXLPipelineConfig",
+    "FluxPipelineConfig",
+    "WanPipelineConfig",
+    "ControlType",
+]

diffsynth_engine/configs/controlnet.py (new file)
@@ -0,0 +1,17 @@
+from enum import Enum
+
+
+# FLUX ControlType
+class ControlType(Enum):
+    normal = "normal"
+    bfl_control = "bfl_control"
+    bfl_fill = "bfl_fill"
+    bfl_kontext = "bfl_kontext"
+
+    def get_in_channel(self):
+        if self in [ControlType.normal, ControlType.bfl_kontext]:
+            return 64
+        elif self == ControlType.bfl_control:
+            return 128
+        elif self == ControlType.bfl_fill:
+            return 384
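
ControlType maps each FLUX control variant to its DiT input channel width; get_in_channel() covers all four members and would implicitly return None only for a future, unhandled member. A short usage sketch:

from diffsynth_engine.configs.controlnet import ControlType

assert ControlType.normal.get_in_channel() == 64
assert ControlType.bfl_control.get_in_channel() == 128
assert ControlType.bfl_fill.get_in_channel() == 384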

diffsynth_engine/configs/pipeline.py (new file)
@@ -0,0 +1,206 @@
+import os
+import torch
+from dataclasses import dataclass, field
+from typing import List, Tuple, Optional
+
+from diffsynth_engine.configs.controlnet import ControlType
+
+
+@dataclass
+class BaseConfig:
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    model_dtype: torch.dtype
+    batch_cfg: bool = False
+    vae_tiled: bool = False
+    vae_tile_size: int | Tuple[int, int] = 256
+    vae_tile_stride: int | Tuple[int, int] = 256
+    device: str = "cuda"
+    offload_mode: Optional[str] = None
+
+
+@dataclass
+class AttentionConfig:
+    dit_attn_impl: str = "auto"
+    # Sparge Attention
+    sparge_smooth_k: bool = True
+    sparge_cdfthreshd: float = 0.6
+    sparge_simthreshd1: float = 0.98
+    sparge_pvthreshd: float = 50.0
+
+
+@dataclass
+class OptimizationConfig:
+    use_fp8_linear: bool = False
+    use_fbcache: bool = False
+    fbcache_relative_l1_threshold: float = 0.05
+
+
+@dataclass
+class ParallelConfig:
+    parallelism: int = 1
+    use_cfg_parallel: bool = False
+    cfg_degree: Optional[int] = None
+    sp_ulysses_degree: Optional[int] = None
+    sp_ring_degree: Optional[int] = None
+    tp_degree: Optional[int] = None
+    use_fsdp: bool = False
+
+
+@dataclass
+class SDPipelineConfig(BaseConfig):
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    clip_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    model_dtype: torch.dtype = torch.float16
+    clip_dtype: torch.dtype = torch.float16
+    vae_dtype: torch.dtype = torch.float32
+
+    @classmethod
+    def basic_config(
+        cls,
+        model_path: str | os.PathLike | List[str | os.PathLike],
+        device: str = "cuda",
+        offload_mode: Optional[str] = None,
+    ) -> "SDPipelineConfig":
+        return cls(
+            model_path=model_path,
+            device=device,
+            offload_mode=offload_mode,
+        )
+
+
+@dataclass
+class SDXLPipelineConfig(BaseConfig):
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    clip_l_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    clip_g_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    model_dtype: torch.dtype = torch.float16
+    clip_l_dtype: torch.dtype = torch.float16
+    clip_g_dtype: torch.dtype = torch.float16
+    vae_dtype: torch.dtype = torch.float32
+
+    @classmethod
+    def basic_config(
+        cls,
+        model_path: str | os.PathLike | List[str | os.PathLike],
+        device: str = "cuda",
+        offload_mode: Optional[str] = None,
+    ) -> "SDXLPipelineConfig":
+        return cls(
+            model_path=model_path,
+            device=device,
+            offload_mode=offload_mode,
+        )
+
+
+@dataclass
+class FluxPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    clip_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    t5_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    model_dtype: torch.dtype = torch.bfloat16
+    clip_dtype: torch.dtype = torch.bfloat16
+    t5_dtype: torch.dtype = torch.bfloat16
+    vae_dtype: torch.dtype = torch.bfloat16
+
+    load_text_encoder: bool = True
+    control_type: ControlType = ControlType.normal
+
+    @classmethod
+    def basic_config(
+        cls,
+        model_path: str | os.PathLike | List[str | os.PathLike],
+        device: str = "cuda",
+        parallelism: int = 1,
+        offload_mode: Optional[str] = None,
+    ) -> "FluxPipelineConfig":
+        return cls(
+            model_path=model_path,
+            device=device,
+            parallelism=parallelism,
+            use_fsdp=True,
+            offload_mode=offload_mode,
+        )
+
+    def __post_init__(self):
+        init_parallel_config(self)
+
+
+@dataclass
+class WanPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    t5_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    model_dtype: torch.dtype = torch.bfloat16
+    t5_dtype: torch.dtype = torch.bfloat16
+    vae_dtype: torch.dtype = torch.bfloat16
+    image_encoder_dtype: torch.dtype = torch.bfloat16
+
+    shift: Optional[float] = field(default=None, init=False)  # RecifitedFlowScheduler shift factor, set by model type
+
+    # override BaseConfig
+    vae_tiled: bool = True
+    vae_tile_size: Tuple[int, int] = (34, 34)
+    vae_tile_stride: Tuple[int, int] = (18, 16)
+
+    @classmethod
+    def basic_config(
+        cls,
+        model_path: str | os.PathLike | List[str | os.PathLike],
+        image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        device: str = "cuda",
+        parallelism: int = 1,
+        offload_mode: Optional[str] = None,
+    ) -> "WanPipelineConfig":
+        return cls(
+            model_path=model_path,
+            image_encoder_path=image_encoder_path,
+            device=device,
+            parallelism=parallelism,
+            use_cfg_parallel=True,
+            use_fsdp=True,
+            offload_mode=offload_mode,
+        )
+
+    def __post_init__(self):
+        init_parallel_config(self)
+
+
+def init_parallel_config(config: FluxPipelineConfig | WanPipelineConfig):
+    assert config.parallelism in (1, 2, 4, 8), "parallelism must be 1, 2, 4 or 8"
+    config.batch_cfg = True if config.parallelism > 1 and config.use_cfg_parallel else config.batch_cfg
+
+    if config.use_cfg_parallel is True and config.cfg_degree is not None:
+        raise ValueError("use_cfg_parallel and cfg_degree should not be specified together")
+    config.cfg_degree = (2 if config.use_cfg_parallel else 1) if config.cfg_degree is None else config.cfg_degree
+
+    if config.tp_degree is not None:
+        assert config.sp_ulysses_degree is None and config.sp_ring_degree is None, (
+            "not allowed to enable sequence parallel and tensor parallel together; "
+            "either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
+        )
+        assert config.use_fsdp is False, (
+            "not allowed to enable fully sharded data parallel and tensor parallel together; "
+            "either set use_fsdp=False or set tp_degree=None during pipeline initialization"
+        )
+        assert config.parallelism == config.cfg_degree * config.tp_degree, (
+            f"parallelism ({config.parallelism}) must be equal to cfg_degree ({config.cfg_degree}) * tp_degree ({config.tp_degree})"
+        )
+        config.sp_ulysses_degree = 1
+        config.sp_ring_degree = 1
+    elif config.sp_ulysses_degree is None and config.sp_ring_degree is None:
+        # use ulysses if not specified
+        config.sp_ulysses_degree = config.parallelism // config.cfg_degree
+        config.sp_ring_degree = 1
+        config.tp_degree = 1
+    elif config.sp_ulysses_degree is not None and config.sp_ring_degree is not None:
+        assert config.parallelism == config.cfg_degree * config.sp_ulysses_degree * config.sp_ring_degree, (
+            f"parallelism ({config.parallelism}) must be equal to cfg_degree ({config.cfg_degree}) * "
+            f"sp_ulysses_degree ({config.sp_ulysses_degree}) * sp_ring_degree ({config.sp_ring_degree})"
+        )
+        config.tp_degree = 1
+    else:
+        raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
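
The basic_config constructors cover the common setup, and __post_init__ runs init_parallel_config to derive the parallel layout. A sketch of the derived fields (the checkpoint path is hypothetical):

from diffsynth_engine import FluxPipelineConfig

config = FluxPipelineConfig.basic_config("flux1-dev.safetensors", parallelism=4)
assert config.cfg_degree == 1  # use_cfg_parallel defaults to False
# Ulysses sequence parallel is chosen when no SP/TP degrees are specified
assert config.sp_ulysses_degree == 4 and config.sp_ring_degree == 1 and config.tp_degree == 1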

diffsynth_engine/models/basic/attention.py
@@ -61,12 +61,33 @@ if SAGE_ATTN_AVAILABLE:

 if SPARGE_ATTN_AVAILABLE:
     from spas_sage_attn import spas_sage2_attn_meansim_cuda
+    from spas_sage_attn.autotune import SparseAttentionMeansim

-    def sparge_attn(q, k, v, attn_mask=None, scale=None):
+    def sparge_attn(
+        q,
+        k,
+        v,
+        attn_mask=None,
+        scale=None,
+        smooth_k=True,
+        simthreshd1=0.6,
+        cdfthreshd=0.98,
+        pvthreshd=50,
+    ):
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
-        out = spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=attn_mask, scale=scale)
+        out = spas_sage2_attn_meansim_cuda(
+            q,
+            k,
+            v,
+            attn_mask=attn_mask,
+            scale=scale,
+            smooth_k=smooth_k,
+            simthreshd1=simthreshd1,
+            cdfthreshd=cdfthreshd,
+            pvthreshd=pvthreshd,
+        )
         return out.transpose(1, 2)


@@ -91,6 +112,7 @@ def attention(
     attn_impl: Optional[str] = None,
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
+    **kwargs,
 ):
     """
     q: [B, Lq, Nq, C1]
@@ -133,7 +155,17 @@
     elif attn_impl == "sage_attn":
         return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
     elif attn_impl == "sparge_attn":
-        return sparge_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        return sparge_attn(
+            q,
+            k,
+            v,
+            attn_mask=attn_mask,
+            scale=scale,
+            smooth_k=kwargs.get("sparge_smooth_k", True),
+            simthreshd1=kwargs.get("sparge_simthreshd1", 0.6),
+            cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
+            pvthreshd=kwargs.get("sparge_pvthreshd", 50),
+        )
     else:
         raise ValueError(f"Invalid attention implementation: {attn_impl}")

@@ -189,6 +221,7 @@ def long_context_attention(
     attn_impl: Optional[str] = None,
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
+    **kwargs,
 ):
     """
     q: [B, Lq, Nq, C1]
@@ -226,7 +259,13 @@
     elif attn_impl == "sage_attn":
         attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
     elif attn_impl == "sparge_attn":
-        attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE)
+        attn_processor = SparseAttentionMeansim()
+        # default args from spas_sage2_attn_meansim_cuda
+        attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
+        attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
+        attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
+        attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
+        attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)
     else:
         raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
     return attn_func(q, k, v, softmax_scale=scale)
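
With the new **kwargs path, Sparge Attention thresholds can be tuned per call instead of being fixed at module construction. A hedged sketch (requires spas_sage_attn and a CUDA device; the tensor shapes and import alias are illustrative):

import torch
from diffsynth_engine.models.basic import attention as attention_ops

q = torch.randn(1, 1024, 24, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = attention_ops.attention(
    q, k, v,
    attn_impl="sparge_attn",
    # sparge_* keys are read via kwargs.get(...) with the defaults shown above
    sparge_simthreshd1=0.6,
    sparge_cdfthreshd=0.98,
    sparge_pvthreshd=50,
)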

diffsynth_engine/models/flux/flux_controlnet.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from typing import Optional, Dict
+from typing import Any, Dict, Optional
 from einops import rearrange
 from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
 from diffsynth_engine.models.flux.flux_dit import (
@@ -87,7 +87,7 @@ class FluxControlNet(PreTrainedModel):
     def __init__(
         self,
         condition_channels: int = 64,
-        attn_impl: Optional[str] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -104,7 +104,10 @@ class FluxControlNet(PreTrainedModel):
         self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
         self.controlnet_x_embedder = nn.Linear(condition_channels, 3072)
         self.blocks = nn.ModuleList(
-            [FluxDoubleTransformerBlock(3072, 24, attn_impl=attn_impl, device=device, dtype=dtype) for _ in range(6)]
+            [
+                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+                for _ in range(6)
+            ]
         )
         # controlnet projection
         self.blocks_proj = nn.ModuleList(
@@ -154,7 +157,7 @@ class FluxControlNet(PreTrainedModel):
         state_dict: Dict[str, torch.Tensor],
         device: str,
         dtype: torch.dtype,
-        attn_impl: Optional[str] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         if "controlnet_x_embedder.weight" in state_dict:
             condition_channels = state_dict["controlnet_x_embedder.weight"].shape[1]
@@ -163,7 +166,7 @@ class FluxControlNet(PreTrainedModel):

         with no_init_weights():
             model = torch.nn.utils.skip_init(
-                cls, condition_channels=condition_channels, attn_impl=attn_impl, device=device, dtype=dtype
+                cls, condition_channels=condition_channels, attn_kwargs=attn_kwargs, device=device, dtype=dtype
             )
         model.load_state_dict(state_dict)
         model.to(device=device, dtype=dtype, non_blocking=True)

diffsynth_engine/models/flux/flux_dit.py
@@ -2,7 +2,7 @@ import json
 import torch
 import torch.nn as nn
 import numpy as np
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 from einops import rearrange

 from diffsynth_engine.models.basic.transformer_helper import (
@@ -177,7 +177,7 @@ class FluxDoubleAttention(nn.Module):
         dim_b,
         num_heads,
         head_dim,
-        attn_impl: Optional[str] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -195,7 +195,7 @@ class FluxDoubleAttention(nn.Module):

         self.a_to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
         self.b_to_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
-        self.attn_impl = attn_impl
+        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

     def attention_callback(self, attn_out_a, attn_out_b, x_a, x_b, q_a, q_b, k_a, k_b, v_a, v_b, rope_emb, image_emb):
         return attn_out_a, attn_out_b
@@ -207,7 +207,7 @@ class FluxDoubleAttention(nn.Module):
         k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
         v = torch.cat([v_b, v_a], dim=1)
         q, k = apply_rope(q, k, rope_emb)
-        attn_out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl)
+        attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
         image_out, text_out = self.attention_callback(
@@ -232,13 +232,13 @@ class FluxDoubleTransformerBlock(nn.Module):
         self,
         dim,
         num_heads,
-        attn_impl: Optional[str] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
         self.attn = FluxDoubleAttention(
-            dim, dim, num_heads, dim // num_heads, attn_impl=attn_impl, device=device, dtype=dtype
+            dim, dim, num_heads, dim // num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype
         )
         # Image
         self.norm_msa_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
@@ -278,7 +278,7 @@ class FluxSingleAttention(nn.Module):
         self,
         dim,
         num_heads,
-        attn_impl: Optional[str] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -287,7 +287,7 @@ class FluxSingleAttention(nn.Module):
         self.to_qkv = nn.Linear(dim, dim * 3, device=device, dtype=dtype)
         self.norm_q_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
         self.norm_k_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
-        self.attn_impl = attn_impl
+        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

     def attention_callback(self, attn_out, x, q, k, v, rope_emb, image_emb):
         return attn_out
@@ -295,7 +295,7 @@ class FluxSingleAttention(nn.Module):
     def forward(self, x, rope_emb, image_emb):
         q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
-        attn_out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl)
+        attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)

@@ -305,14 +305,14 @@ class FluxSingleTransformerBlock(nn.Module):
         self,
         dim,
         num_heads,
-        attn_impl: Optional[str] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
         self.dim = dim
         self.norm = AdaLayerNormZero(dim, device=device, dtype=dtype)
-        self.attn = FluxSingleAttention(dim, num_heads, attn_impl=attn_impl, device=device, dtype=dtype)
+        self.attn = FluxSingleAttention(dim, num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
         self.mlp = nn.Sequential(
             nn.Linear(dim, dim * 4),
             nn.GELU(approximate="tanh"),
@@ -333,7 +333,7 @@ class FluxDiT(PreTrainedModel):
     def __init__(
         self,
         in_channel: int = 64,
-        attn_impl: Optional[str] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -351,10 +351,16 @@ class FluxDiT(PreTrainedModel):
         self.x_embedder = nn.Linear(in_channel, 3072, device=device, dtype=dtype)

         self.blocks = nn.ModuleList(
-            [FluxDoubleTransformerBlock(3072, 24, attn_impl=attn_impl, device=device, dtype=dtype) for _ in range(19)]
+            [
+                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+                for _ in range(19)
+            ]
         )
         self.single_blocks = nn.ModuleList(
-            [FluxSingleTransformerBlock(3072, 24, attn_impl=attn_impl, device=device, dtype=dtype) for _ in range(38)]
+            [
+                FluxSingleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+                for _ in range(38)
+            ]
         )
         self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
         self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
@@ -495,7 +501,7 @@ class FluxDiT(PreTrainedModel):
         device: str,
         dtype: torch.dtype,
         in_channel: int = 64,
-        attn_impl: Optional[str] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -503,7 +509,7 @@ class FluxDiT(PreTrainedModel):
                 device=device,
                 dtype=dtype,
                 in_channel=in_channel,
-                attn_impl=attn_impl,
+                attn_kwargs=attn_kwargs,
             )
         model = model.requires_grad_(False) # for loading gguf
         model.load_state_dict(state_dict, assign=True)
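
Because attention() still accepts attn_impl as a keyword, the implementation choice now travels inside attn_kwargs together with any sparge_* thresholds. A hedged construction sketch, assuming state_dict holds a loaded FLUX checkpoint:

import torch
from diffsynth_engine.models.flux.flux_dit import FluxDiT

dit = FluxDiT.from_state_dict(
    state_dict,
    device="cuda",
    dtype=torch.bfloat16,
    attn_kwargs={"attn_impl": "sage_attn"},  # expanded as attention(q, k, v, **attn_kwargs)
)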

diffsynth_engine/models/flux/flux_dit_fbcache.py
@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from typing import Dict, Optional
+from typing import Any, Dict, Optional

 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.utils.gguf import gguf_inference
@@ -21,12 +21,12 @@ class FluxDiTFBCache(FluxDiT):
     def __init__(
         self,
         in_channel: int = 64,
-        attn_impl: Optional[str] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
         relative_l1_threshold: float = 0.05,
     ):
-        super().__init__(in_channel=in_channel, attn_impl=attn_impl, device=device, dtype=dtype)
+        super().__init__(in_channel=in_channel, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
         self.relative_l1_threshold = relative_l1_threshold
         self.step_count = 0
         self.num_inference_steps = 0
@@ -187,7 +187,7 @@ class FluxDiTFBCache(FluxDiT):
         device: str,
         dtype: torch.dtype,
         in_channel: int = 64,
-        attn_impl: Optional[str] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         fb_cache_relative_l1_threshold: float = 0.05,
     ):
         with no_init_weights():
@@ -196,7 +196,7 @@ class FluxDiTFBCache(FluxDiT):
                 device=device,
                 dtype=dtype,
                 in_channel=in_channel,
-                attn_impl=attn_impl,
+                attn_kwargs=attn_kwargs,
                 fb_cache_relative_l1_threshold=fb_cache_relative_l1_threshold,
             )
         model = model.requires_grad_(False) # for loading gguf

diffsynth_engine/models/flux/flux_ipadapter.py
@@ -2,7 +2,7 @@ import torch
 from einops import rearrange
 from torch import nn
 from PIL import Image
-from typing import Dict, List
+from typing import Any, Dict, List, Optional
 from functools import partial
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.models.text_encoder.siglip import SiglipImageEncoder
@@ -19,7 +19,7 @@ class FluxIPAdapterAttention(nn.Module):
         dim: int = 3072,
         head_num: int = 24,
         scale: float = 1.0,
-        attn_impl="auto",
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -29,12 +29,12 @@ class FluxIPAdapterAttention(nn.Module):
         self.to_v_ip = nn.Linear(image_emb_dim, dim, device=device, dtype=dtype, bias=False)
         self.head_num = head_num
         self.scale = scale
-        self.attn_impl = attn_impl
+        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

     def forward(self, query: torch.Tensor, image_emb: torch.Tensor):
         key = rearrange(self.norm_k(self.to_k_ip(image_emb)), "b s (h d) -> b s h d", h=self.head_num)
         value = rearrange(self.to_v_ip(image_emb), "b s (h d) -> b s h d", h=self.head_num)
-        attn_out = attention(query, key, value)
+        attn_out = attention(query, key, value, **self.attn_kwargs)
         return self.scale * rearrange(attn_out, "b s h d -> b s (h d)")

     @classmethod
@@ -142,7 +142,7 @@ class FluxIPAdapter(PreTrainedModel):
             single_attention_callback, self=dit.single_blocks[i].attn
         )

-    def image_encode(self, image: Image.Image) -> torch.Tensor:
+    def encode_image(self, image: Image.Image) -> torch.Tensor:
         image_emb = self.image_encoder(image)
         return self.image_proj(image_emb)

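Note the rename image_encode -> encode_image on FluxIPAdapter, a breaking change for direct callers. Sketch, assuming ip_adapter and image already exist:

image_emb = ip_adapter.encode_image(image)  # dev8 spelling: ip_adapter.image_encode(image)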

diffsynth_engine/models/sd/sd_controlnet.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from typing import Optional, Dict
+from typing import Dict

 from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
@@ -570,7 +570,6 @@ class SDControlNet(PreTrainedModel):

     def __init__(
         self,
-        attn_impl: Optional[str] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -666,10 +665,9 @@ class SDControlNet(PreTrainedModel):
         state_dict: Dict[str, torch.Tensor],
         device: str,
         dtype: torch.dtype,
-        attn_impl: Optional[str] = None,
     ):
         with no_init_weights():
-            model = torch.nn.utils.skip_init(cls, attn_impl=attn_impl, device=device, dtype=dtype)
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
         model.load_state_dict(state_dict)
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model