diffsynth-engine 0.2.0__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (232) hide show
  1. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/.gitignore +3 -1
  2. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/PKG-INFO +3 -3
  3. diffsynth_engine-0.2.1/assets/dingtalk.png +0 -0
  4. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/__init__.py +7 -0
  5. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/base.py +5 -5
  6. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/attention.py +22 -6
  7. diffsynth_engine-0.2.1/diffsynth_engine/models/components/siglip.py +169 -0
  8. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/__init__.py +2 -0
  9. diffsynth_engine-0.2.1/diffsynth_engine/models/flux/flux_controlnet.py +160 -0
  10. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/flux_dit.py +16 -17
  11. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/sd3_dit.py +1 -7
  12. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/sdxl_unet.py +1 -7
  13. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_dit.py +1 -0
  14. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/__init__.py +2 -1
  15. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/base.py +26 -28
  16. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/flux_image.py +179 -32
  17. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/sd_image.py +32 -7
  18. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/sdxl_image.py +32 -7
  19. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/wan_video.py +51 -27
  20. diffsynth_engine-0.2.1/diffsynth_engine/tools/__init__.py +4 -0
  21. diffsynth_engine-0.2.1/diffsynth_engine/tools/flux_inpainting.py +50 -0
  22. diffsynth_engine-0.2.1/diffsynth_engine/tools/flux_outpainting.py +58 -0
  23. diffsynth_engine-0.2.1/diffsynth_engine/utils/env.py +10 -0
  24. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/flag.py +2 -2
  25. diffsynth_engine-0.2.1/diffsynth_engine/utils/image.py +25 -0
  26. diffsynth_engine-0.2.1/diffsynth_engine/utils/loader.py +32 -0
  27. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/parallel.py +15 -4
  28. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/PKG-INFO +3 -3
  29. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/SOURCES.txt +17 -1
  30. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/requires.txt +2 -2
  31. diffsynth_engine-0.2.1/examples/i2v_input.jpg +0 -0
  32. diffsynth_engine-0.2.1/examples/wan_image_to_video.py +35 -0
  33. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/examples/wan_text_to_video.py +1 -1
  34. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/pyproject.toml +2 -2
  35. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/common/test_case.py +1 -1
  36. diffsynth_engine-0.2.1/tests/data/expect/flux/flux_inpainting.png +0 -0
  37. diffsynth_engine-0.2.1/tests/data/expect/flux/flux_outpainting.png +0 -0
  38. diffsynth_engine-0.2.1/tests/data/expect/flux/flux_union_pro_canny.png +0 -0
  39. diffsynth_engine-0.2.1/tests/data/expect/test_siglip_image_encoder.safetensors +0 -0
  40. diffsynth_engine-0.2.1/tests/data/input/canny.png +0 -0
  41. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/flux/test_flux_dit.py +1 -1
  42. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/flux/test_flux_text_encoder.py +1 -2
  43. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/flux/test_flux_vae.py +1 -2
  44. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sd/test_sd_text_encoder.py +1 -2
  45. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sd/test_sd_vae.py +1 -2
  46. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sdxl/test_sdxl_text_encoder.py +1 -1
  47. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sdxl/test_sdxl_vae.py +1 -2
  48. diffsynth_engine-0.2.1/tests/test_models/test_siglip.py +17 -0
  49. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/wan/test_wan_vae.py +1 -2
  50. diffsynth_engine-0.2.1/tests/test_pipelines/test_flux_controlnet.py +32 -0
  51. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/test_flux_image.py +0 -13
  52. diffsynth_engine-0.2.1/tests/test_tools/__init__.py +0 -0
  53. diffsynth_engine-0.2.1/tests/test_tools/test_flux_tools.py +31 -0
  54. diffsynth_engine-0.2.0/assets/dingtalk.png +0 -0
  55. diffsynth_engine-0.2.0/diffsynth_engine/utils/env.py +0 -7
  56. diffsynth_engine-0.2.0/diffsynth_engine/utils/loader.py +0 -17
  57. diffsynth_engine-0.2.0/tests/data/expect/flux/flux_inpainting.png +0 -0
  58. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/.github/workflows/python-publish.yml +0 -0
  59. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/.pre-commit-config.yaml +0 -0
  60. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/LICENSE +0 -0
  61. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/README.md +0 -0
  62. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/assets/showcase.jpeg +0 -0
  63. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/__init__.py +0 -0
  64. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  65. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  66. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  67. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  68. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  69. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  70. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  71. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  72. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  73. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  74. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  75. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  76. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  77. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  78. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  79. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  80. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  81. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  82. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  83. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  84. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  85. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  86. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  87. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  88. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  89. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  90. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/components/vae.json +0 -0
  91. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  92. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  93. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  94. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  95. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  96. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  97. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  98. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  99. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  100. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +0 -0
  101. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/wan/dit/14b-i2v.json +0 -0
  102. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/wan/dit/14b-t2v.json +0 -0
  103. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  104. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  105. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  106. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  107. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  108. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  109. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  110. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  111. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  112. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  113. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  114. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  115. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  116. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  117. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  118. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  119. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  120. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  121. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  122. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  123. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/kernels/__init__.py +0 -0
  124. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/__init__.py +0 -0
  125. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/__init__.py +0 -0
  126. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/lora.py +0 -0
  127. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  128. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/timestep.py +0 -0
  129. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
  130. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  131. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/__init__.py +0 -0
  132. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/clip.py +0 -0
  133. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/t5.py +0 -0
  134. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/vae.py +0 -0
  135. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  136. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  137. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/__init__.py +0 -0
  138. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  139. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  140. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  141. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/__init__.py +0 -0
  142. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  143. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  144. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  145. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  146. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  147. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/utils.py +0 -0
  148. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/__init__.py +0 -0
  149. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  150. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  151. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  152. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/__init__.py +0 -0
  153. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/base.py +0 -0
  154. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/clip.py +0 -0
  155. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/t5.py +0 -0
  156. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/wan.py +0 -0
  157. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/__init__.py +0 -0
  158. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/constants.py +0 -0
  159. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/download.py +0 -0
  160. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/fp8_linear.py +0 -0
  161. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/gguf.py +0 -0
  162. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/lock.py +0 -0
  163. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/logging.py +0 -0
  164. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/offload.py +0 -0
  165. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/prompt.py +0 -0
  166. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/video.py +0 -0
  167. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  168. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/top_level.txt +0 -0
  169. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/docs/tutorial.md +0 -0
  170. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/docs/tutorial_zh.md +0 -0
  171. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/examples/flux_lora.py +0 -0
  172. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/examples/flux_text_to_image.py +0 -0
  173. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/examples/sdxl_text_to_image.py +0 -0
  174. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/examples/wan_lora.py +0 -0
  175. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/setup.cfg +0 -0
  176. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/setup.py +0 -0
  177. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/__init__.py +0 -0
  178. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/common/__init__.py +0 -0
  179. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/common/utils.py +0 -0
  180. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/beta_20steps.safetensors +0 -0
  181. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/ddim_20steps.safetensors +0 -0
  182. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/euler_i10.safetensors +0 -0
  183. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/exponential_20steps.safetensors +0 -0
  184. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/flow_match_euler_i10.safetensors +0 -0
  185. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/karras_20steps.safetensors +0 -0
  186. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/output.safetensors +0 -0
  187. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/recifited_flow_20steps_flux.safetensors +0 -0
  188. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/scaled_linear_20steps.safetensors +0 -0
  189. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/sgm_uniform_20steps.safetensors +0 -0
  190. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/flux/flux_dit.safetensors +0 -0
  191. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/flux/flux_lora.png +0 -0
  192. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/flux/flux_text_encoder_1.safetensors +0 -0
  193. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/flux/flux_text_encoder_2.safetensors +0 -0
  194. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/flux/flux_txt2img.png +0 -0
  195. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/flux/flux_vae.safetensors +0 -0
  196. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sd/sd_inpainting.png +0 -0
  197. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sd/sd_lora.png +0 -0
  198. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sd/sd_text_encoder.safetensors +0 -0
  199. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sd/sd_txt2img.png +0 -0
  200. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sd/sd_unet.safetensors +0 -0
  201. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sd/sd_vae.safetensors +0 -0
  202. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_inpainting.png +0 -0
  203. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_lora.png +0 -0
  204. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_text_encoder_1.safetensors +0 -0
  205. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_text_encoder_2.safetensors +0 -0
  206. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_txt2img.png +0 -0
  207. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_unet.safetensors +0 -0
  208. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_vae.safetensors +0 -0
  209. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/wan/wan_vae.safetensors +0 -0
  210. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/input/astronaut_320_320.mp4 +0 -0
  211. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/input/mask_image.png +0 -0
  212. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/input/test_image.png +0 -0
  213. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/input/wukong_1024_1024.png +0 -0
  214. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/input/wukong_480_480.png +0 -0
  215. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_algorithm/__init__.py +0 -0
  216. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_algorithm/test_sampler.py +0 -0
  217. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_algorithm/test_scheduler.py +0 -0
  218. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/__init__.py +0 -0
  219. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/flux/__init__.py +0 -0
  220. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sd/__init__.py +0 -0
  221. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sd/test_sd_unet.py +0 -0
  222. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sdxl/__init__.py +0 -0
  223. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sdxl/test_sdxl_unet.py +0 -0
  224. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/__init__.py +0 -0
  225. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/test_sd_image.py +0 -0
  226. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/test_sdxl_image.py +0 -0
  227. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/test_wan_video.py +0 -0
  228. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/test_wan_video_gguf.py +0 -0
  229. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/test_wan_video_tp.py +0 -0
  230. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_tokenizers/__init__.py +0 -0
  231. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_tokenizers/test_clip.py +0 -0
  232. {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_tokenizers/test_t5.py +0 -0
@@ -6,4 +6,6 @@ tmp/
6
6
  build/
7
7
  dist/
8
8
  *.egg-info/
9
- .DS_Store/
9
+ .DS_Store/
10
+ .pytest_cache/
11
+ .ruff_cache/
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
7
7
  Requires-Python: >=3.10
8
8
  License-File: LICENSE
9
- Requires-Dist: torch>=2.4.1
9
+ Requires-Dist: torch>=2.6
10
10
  Requires-Dist: torchvision
11
11
  Requires-Dist: xformers; sys_platform == "linux"
12
12
  Requires-Dist: safetensors
@@ -22,7 +22,7 @@ Requires-Dist: scipy
22
22
  Requires-Dist: torchsde
23
23
  Requires-Dist: pillow
24
24
  Requires-Dist: imageio[ffmpeg]
25
- Requires-Dist: yunchang
25
+ Requires-Dist: yunchang; sys_platform == "linux"
26
26
  Provides-Extra: dev
27
27
  Requires-Dist: diffusers==0.31.0; extra == "dev"
28
28
  Requires-Dist: transformers==4.45.2; extra == "dev"
@@ -7,12 +7,16 @@ from .pipelines import (
7
7
  SDXLModelConfig,
8
8
  SDModelConfig,
9
9
  WanModelConfig,
10
+ ControlNetParams,
10
11
  )
12
+ from .models.flux import FluxControlNet
11
13
  from .utils.download import fetch_model, fetch_modelscope_model, fetch_civitai_model
12
14
  from .utils.video import load_video, save_video
15
+ from .tools import FluxInpaintingTool, FluxOutpaintingTool
13
16
 
14
17
  __all__ = [
15
18
  "FluxImagePipeline",
19
+ "FluxControlNet",
16
20
  "SDXLImagePipeline",
17
21
  "SDImagePipeline",
18
22
  "WanVideoPipeline",
@@ -20,6 +24,9 @@ __all__ = [
20
24
  "SDXLModelConfig",
21
25
  "SDModelConfig",
22
26
  "WanModelConfig",
27
+ "FluxInpaintingTool",
28
+ "FluxOutpaintingTool",
29
+ "ControlNetParams",
23
30
  "fetch_model",
24
31
  "fetch_modelscope_model",
25
32
  "fetch_civitai_model",
@@ -1,9 +1,8 @@
1
1
  import os
2
2
  import torch
3
3
  import torch.nn as nn
4
- from typing import Dict, List, Union
5
- from safetensors.torch import load_file
6
-
4
+ from typing import Dict, Union, List, Any
5
+ from diffsynth_engine.utils.loader import load_file
7
6
  from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
8
7
  from diffsynth_engine.models.utils import no_init_weights
9
8
 
@@ -22,18 +21,19 @@ class PreTrainedModel(nn.Module):
22
21
 
23
22
  @classmethod
24
23
  def from_pretrained(cls, pretrained_model_path: Union[str, os.PathLike], device: str, dtype: torch.dtype, **kwargs):
25
- state_dict = load_file(pretrained_model_path, device=device)
24
+ state_dict = load_file(pretrained_model_path)
26
25
  return cls.from_state_dict(state_dict, device=device, dtype=dtype, **kwargs)
27
26
 
28
27
  @classmethod
29
28
  def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, **kwargs):
30
29
  with no_init_weights():
31
30
  model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, **kwargs)
31
+ model.to_empty(device=device)
32
32
  model.load_state_dict(state_dict)
33
33
  model.to(device=device, dtype=dtype, non_blocking=True)
34
34
  return model
35
35
 
36
- def load_loras(self, lora_args: List[Dict[str, any]], fused: bool = True):
36
+ def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
37
37
  for args in lora_args:
38
38
  key = args["name"]
39
39
  module = self.get_submodule(key)
@@ -1,10 +1,9 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
- from einops import rearrange
3
+ from einops import rearrange, repeat
4
4
  from typing import Optional
5
- from yunchang import LongContextAttention
6
- from yunchang.kernels import AttnType
7
5
 
6
+ import torch.nn.functional as F
8
7
  from diffsynth_engine.utils import logging
9
8
  from diffsynth_engine.utils.flag import (
10
9
  FLASH_ATTN_3_AVAILABLE,
@@ -18,12 +17,26 @@ from diffsynth_engine.utils.flag import (
18
17
  logger = logging.get_logger(__name__)
19
18
 
20
19
 
20
+ def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
21
+ padding_size = (alignment - x.shape[dim] % alignment) % alignment
22
+ padded_x = F.pad(x, (0, padding_size), "constant", 0)
23
+ return padded_x[..., : x.shape[dim]]
24
+
25
+
21
26
  if FLASH_ATTN_3_AVAILABLE:
22
27
  from flash_attn_interface import flash_attn_func as flash_attn3
23
28
  if FLASH_ATTN_2_AVAILABLE:
24
29
  from flash_attn import flash_attn_func as flash_attn2
25
30
  if XFORMERS_AVAILABLE:
26
- from xformers.ops import memory_efficient_attention as xformers_attn
31
+ from xformers.ops import memory_efficient_attention
32
+
33
+ def xformers_attn(q, k, v, attn_mask=None, scale=None):
34
+ if attn_mask is not None:
35
+ attn_mask = repeat(attn_mask, "S L -> B H S L", B=q.shape[0], H=q.shape[2])
36
+ attn_mask = memory_align(attn_mask)
37
+ return memory_efficient_attention(q, k, v, attn_bias=attn_mask, scale=scale)
38
+
39
+
27
40
  if SDPA_AVAILABLE:
28
41
 
29
42
  def sdpa_attn(q, k, v, attn_mask=None, scale=None):
@@ -100,7 +113,7 @@ def attention(
100
113
  elif FLASH_ATTN_2_AVAILABLE:
101
114
  return flash_attn2(q, k, v, softmax_scale=scale)
102
115
  elif XFORMERS_AVAILABLE:
103
- return xformers_attn(q, k, v, attn_bias=attn_mask, scale=scale)
116
+ return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
104
117
  elif SDPA_AVAILABLE:
105
118
  return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
106
119
  else:
@@ -113,7 +126,7 @@ def attention(
113
126
  elif attn_impl == "flash_attn_2":
114
127
  return flash_attn2(q, k, v, softmax_scale=scale)
115
128
  elif attn_impl == "xformers":
116
- return xformers_attn(q, k, v, attn_bias=attn_mask, scale=scale)
129
+ return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
117
130
  elif attn_impl == "sdpa":
118
131
  return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
119
132
  elif attn_impl == "sage_attn":
@@ -181,6 +194,9 @@ def long_context_attention(
181
194
  k: [B, Lk, Nk, C1]
182
195
  v: [B, Lk, Nk, C2]
183
196
  """
197
+ from yunchang import LongContextAttention
198
+ from yunchang.kernels import AttnType
199
+
184
200
  assert attn_impl in [
185
201
  None,
186
202
  "auto",
@@ -0,0 +1,169 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from typing import Union, List
8
+ from PIL import Image
9
+ from diffsynth_engine.models.basic.attention import Attention
10
+ from diffsynth_engine.utils.loader import load_file
11
+ from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
12
+
13
+
14
class SiglipVisionEmbeddings(nn.Module):
    """Convolutional patch embedding plus learned position embeddings.

    Splits the image into non-overlapping ``patch_size`` x ``patch_size``
    patches with a strided conv, flattens them into a token sequence, and
    adds a learned per-position embedding.
    """

    def __init__(
        self, num_channels: int, num_positions: int, hidden_size: int, patch_size: int, device: str, dtype: torch.dtype
    ):
        super().__init__()
        self.patch_embedding = nn.Conv2d(
            in_channels=num_channels,
            out_channels=hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
            padding="valid",
            device=device,
            dtype=dtype,
        )
        self.position_embedding = nn.Embedding(num_positions, hidden_size, device=device, dtype=dtype)
        # Constant index tensor; kept as a plain attribute (not a buffer) and
        # moved to the weight's device lazily in forward.
        self.position_ids = torch.arange(num_positions).expand((1, -1))

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        """Return patch tokens of shape [B, num_positions, hidden_size].

        NOTE(review): ``interpolate_pos_encoding`` is accepted but ignored;
        inputs are assumed to match the configured image size — confirm.
        """
        weight = self.patch_embedding.weight
        self.position_ids = self.position_ids.to(weight.device)
        patches = self.patch_embedding(pixel_values.to(dtype=weight.dtype))  # [B, hidden, grid, grid]
        batch, channels = patches.shape[:2]
        tokens = patches.reshape(batch, channels, -1).permute(0, 2, 1)  # [B, grid*grid, hidden]
        return tokens + self.position_embedding(self.position_ids)
39
+
40
+
41
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block with tanh-approximated GELU activation."""

    def __init__(self, hidden_size, inner_dim, device, dtype):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, inner_dim, device=device, dtype=dtype)
        self.fc2 = nn.Linear(inner_dim, hidden_size, device=device, dtype=dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand to inner_dim, apply GELU (tanh approximation), project back.
        expanded = self.fc1(hidden_states)
        activated = F.gelu(expanded, approximate="tanh")
        return self.fc2(activated)
52
+
53
+
54
class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: residual self-attention + residual MLP."""

    def __init__(self, hidden_size: int, inner_dim: int, num_heads: int, eps: float, device: str, dtype: torch.dtype):
        super().__init__()
        # Pass device/dtype to the LayerNorms as well — every sibling module
        # places its parameters explicitly; without this the norms were built
        # in default fp32 on CPU regardless of the requested placement.
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=eps, device=device, dtype=dtype)
        # NOTE(review): Attention does not receive device/dtype here — confirm
        # it handles parameter placement internally.
        self.self_attn = Attention(
            q_dim=hidden_size,
            num_heads=num_heads,
            head_dim=hidden_size // num_heads,
            bias_q=True,
            bias_kv=True,
            bias_out=True,
        )
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=eps, device=device, dtype=dtype)
        self.mlp = SiglipMLP(hidden_size=hidden_size, inner_dim=inner_dim, device=device, dtype=dtype)

    def forward(self, x):
        # Pre-norm ordering: normalize, transform, then add the residual.
        x = self.self_attn(self.layer_norm1(x)) + x
        x = self.mlp(self.layer_norm2(x)) + x
        return x
73
+
74
+
75
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling.

    A learnable probe query attends over the encoder tokens; a residual MLP
    block follows, and the single probe position is returned as the pooled
    representation.
    """

    def __init__(self, hidden_size, inner_dim, num_heads, eps, device, dtype) -> None:
        super().__init__()
        # device/dtype were missing on the probe — every sibling parameter is
        # placed explicitly, so create the probe on the requested device too.
        self.probe = nn.Parameter(data=torch.randn(1, 1, hidden_size, device=device, dtype=dtype))
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, device=device, dtype=dtype
        )
        self.layernorm = nn.LayerNorm(normalized_shape=hidden_size, eps=eps, device=device, dtype=dtype)
        self.mlp = SiglipMLP(hidden_size=hidden_size, inner_dim=inner_dim, device=device, dtype=dtype)

    def forward(self, hidden_state) -> torch.Tensor:
        """Pool [B, L, hidden] tokens into a [B, hidden] vector."""
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)
        # Probe attends over the token sequence (query=probe, key=value=tokens).
        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)
        # Only one probe token exists; return it.
        return hidden_state[:, 0]
96
+
97
+
98
class SiglipVisionTransformer(nn.Module):
    """SigLIP vision tower: patch embedding, encoder stack, final norm, pooled head."""

    def __init__(
        self,
        hidden_size: int = 1152,
        num_channels: int = 3,
        image_size: int = 384,
        patch_size: int = 14,
        layer_num: int = 27,
        inner_dim: int = 4304,
        num_heads: int = 16,
        eps: float = 1e-06,
        device: str = "cpu",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        grid = image_size // patch_size
        self.embeddings = SiglipVisionEmbeddings(
            num_channels=num_channels,
            num_positions=grid * grid,
            hidden_size=hidden_size,
            patch_size=patch_size,
            device=device,
            dtype=dtype,
        )
        encoder_layers = [
            SiglipEncoderLayer(hidden_size, inner_dim, num_heads, eps, device, dtype) for _ in range(layer_num)
        ]
        self.layers = nn.ModuleList(encoder_layers)
        self.post_layernorm = nn.LayerNorm(hidden_size, eps=eps, device=device, dtype=dtype)
        self.head = SiglipMultiheadAttentionPoolingHead(
            hidden_size, inner_dim=inner_dim, num_heads=num_heads, eps=eps, device=device, dtype=dtype
        )

    def forward(self, x):
        """Encode [B, C, H, W] pixels into a pooled [B, hidden_size] embedding."""
        hidden = self.embeddings(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
        hidden = self.post_layernorm(hidden)
        return self.head(hidden)
136
+
137
+
138
class SiglipImageEncoderConverter(StateDictConverter):
    """State-dict converter for the SigLIP image encoder.

    Checkpoint keys already match this module's parameter names, so the
    conversion is the identity.
    """

    def convert(self, state_dict: dict) -> dict:
        # No key remapping needed; return the state dict unchanged.
        return state_dict
141
+
142
+
143
class SiglipImageEncoder(PreTrainedModel):
    """Wraps SiglipVisionTransformer with PIL preprocessing.

    Accepts one PIL image or a list of them, resizes to 384x384, rescales to
    [-1, 1], and returns the pooled SigLIP embedding.
    """

    converter = SiglipImageEncoderConverter()

    def __init__(self, device: str, dtype: torch.dtype) -> None:
        super().__init__()
        self.image_encoder = SiglipVisionTransformer(device=device, dtype=dtype)

    def image_preprocess(self, images: List[Image.Image]) -> torch.Tensor:
        """Convert PIL images to a normalized [B, C, H, W] tensor on the model's device/dtype."""
        # Force 3-channel RGB first, so grayscale / RGBA / palette inputs do
        # not change the channel count expected by the conv patch embedding.
        images = [image.convert("RGB").resize(size=(384, 384), resample=3) for image in images]  # 3 == BICUBIC
        rescaled_images = [np.array(image) / 255 for image in images]
        # Normalize to [-1, 1] (mean 0.5, std 0.5 per channel).
        normalized_images = [(image - 0.5) / 0.5 for image in rescaled_images]
        image_tensor = torch.stack([torch.tensor(image) for image in normalized_images])
        param = next(self.parameters())
        image_tensor = image_tensor.to(param.device, param.dtype)
        return rearrange(image_tensor, "b h w c -> b c h w")

    @torch.no_grad()
    def forward(self, images: Union[List[Image.Image], Image.Image]):
        """Encode image(s); returns a [B, hidden_size] pooled embedding.

        Uses typing.Union instead of `X | Y` so the annotation also evaluates
        on Python < 3.10.
        """
        if isinstance(images, Image.Image):
            images = [images]
        image_input = self.image_preprocess(images)
        return self.image_encoder(image_input)

    @classmethod
    def from_pretrained(cls, pretrained_model_path: Union[str, os.PathLike], device: str, dtype: torch.dtype, **kwargs):
        """Load a checkpoint file from disk and build the encoder from it."""
        state_dict = load_file(str(pretrained_model_path))
        return cls.from_state_dict(state_dict, device=device, dtype=dtype, **kwargs)
@@ -1,9 +1,11 @@
1
1
  from .flux_dit import FluxDiT, config as flux_dit_config
2
2
  from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2, config as flux_text_encoder_config
3
3
  from .flux_vae import FluxVAEDecoder, FluxVAEEncoder, config as flux_vae_config
4
+ from .flux_controlnet import FluxControlNet
4
5
 
5
6
  __all__ = [
6
7
  "FluxDiT",
8
+ "FluxControlNet",
7
9
  "FluxTextEncoder1",
8
10
  "FluxTextEncoder2",
9
11
  "FluxVAEDecoder",
@@ -0,0 +1,160 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, Dict
4
+ from einops import rearrange
5
+ from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
6
+ from diffsynth_engine.models.flux.flux_dit import (
7
+ FluxJointTransformerBlock,
8
+ RoPEEmbedding,
9
+ TimestepEmbeddings,
10
+ )
11
+ from diffsynth_engine.models.utils import no_init_weights
12
+
13
+
14
class FluxControlNetStateDictConverter(StateDictConverter):
    """Converts diffusers-format Flux ControlNet checkpoints to this repo's layout.

    Fuses the separate q/k/v projection tensors into single qkv weights and
    renames blocks/embedders to the local module names.
    """

    def __init__(self):
        super().__init__()

    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        new_state_dict = {}
        for key, value in state_dict.items():
            new_key = key
            if "attn.to_q" in new_key:
                # Fuse the image-stream q/k/v projections into one qkv tensor.
                q = state_dict[new_key]
                k = state_dict[new_key.replace("attn.to_q", "attn.to_k")]
                v = state_dict[new_key.replace("attn.to_q", "attn.to_v")]
                new_key = new_key.replace("transformer_blocks", "blocks")
                new_key = new_key.replace("attn.to_q", "attn.a_to_qkv")
                new_state_dict[new_key] = torch.cat((q, k, v), dim=0)
            elif "attn.add_q_proj" in new_key:
                # Fuse the text-stream q/k/v projections into one qkv tensor.
                # (The original applied the add_q_proj replace twice; the
                # second one was a no-op and has been removed.)
                q = state_dict[new_key]
                k = state_dict[new_key.replace("attn.add_q_proj", "attn.add_k_proj")]
                v = state_dict[new_key.replace("attn.add_q_proj", "attn.add_v_proj")]
                new_key = new_key.replace("transformer_blocks", "blocks")
                new_key = new_key.replace("attn.add_q_proj", "attn.b_to_qkv")
                new_state_dict[new_key] = torch.cat((q, k, v), dim=0)
            elif (
                "attn.to_k" in new_key
                or "attn.to_v" in new_key
                or "attn.add_k_proj" in new_key
                or "attn.add_v_proj" in new_key
            ):
                # k/v tensors are consumed when their matching q key is seen.
                continue
            else:
                new_key = new_key.replace("transformer_blocks", "blocks")
                new_key = new_key.replace("controlnet_blocks", "blocks_proj")
                new_key = new_key.replace("time_text_embed.guidance_embedder", "guidance_embedder")
                new_key = new_key.replace("time_text_embed.timestep_embedder", "time_embedder")
                new_key = new_key.replace("time_text_embed.text_embedder.linear_1", "pooled_text_embedder.0")
                new_key = new_key.replace("time_text_embed.text_embedder.linear_2", "pooled_text_embedder.2")
                new_key = new_key.replace("time_embedder.linear_1", "time_embedder.timestep_embedder.0")
                new_key = new_key.replace("time_embedder.linear_2", "time_embedder.timestep_embedder.2")
                new_key = new_key.replace("guidance_embedder.linear_1", "guidance_embedder.timestep_embedder.0")
                new_key = new_key.replace("guidance_embedder.linear_2", "guidance_embedder.timestep_embedder.2")
                # joint block
                new_key = new_key.replace("norm1.linear", "norm1_a.linear")
                new_key = new_key.replace("norm1_context.linear", "norm1_b.linear")
                new_key = new_key.replace("attn.to_out.0", "attn.a_to_out")
                new_key = new_key.replace("attn.to_add_out", "attn.b_to_out")
                new_key = new_key.replace("attn.norm_q", "attn.norm_q_a")
                new_key = new_key.replace("attn.norm_k", "attn.norm_k_a")
                new_key = new_key.replace("attn.norm_added_q", "attn.norm_q_b")
                new_key = new_key.replace("attn.norm_added_k", "attn.norm_k_b")
                new_key = new_key.replace("ff.net", "ff_a")
                new_key = new_key.replace("ff_context.net", "ff_b")
                new_key = new_key.replace("0.proj", "0")
                new_state_dict[new_key] = value
        return new_state_dict

    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Entry point: currently only diffusers-format checkpoints are handled."""
        return self._from_diffusers(state_dict)
72
+
73
+
74
class FluxControlNet(PreTrainedModel):
    """Flux ControlNet: a short stack of double-stream DiT blocks producing per-block residuals.

    ``forward`` returns ``(double_block_outputs, None)``: one scaled residual
    per controlnet block, to be added to the main FluxDiT double-stream
    hidden states; the second element is a placeholder for single-block
    outputs (not produced by this model).
    """

    converter = FluxControlNetStateDictConverter()

    def __init__(
        self,
        condition_channels: int = 64,
        attn_impl: Optional[str] = None,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.time_embedder = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
        self.guidance_embedder = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
        self.pooled_text_embedder = nn.Sequential(
            nn.Linear(768, 3072, device=device, dtype=dtype),
            nn.SiLU(),
            nn.Linear(3072, 3072, device=device, dtype=dtype),
        )
        self.context_embedder = nn.Linear(4096, 3072, device=device, dtype=dtype)
        self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
        # device/dtype were missing on this layer only; without them it was
        # fully initialized in the default dtype on the default device (even
        # under skip_init) and converted later by .to() — place it like its
        # siblings instead.
        self.controlnet_x_embedder = nn.Linear(condition_channels, 3072, device=device, dtype=dtype)
        self.blocks = nn.ModuleList(
            [FluxJointTransformerBlock(3072, 24, attn_impl=attn_impl, device=device, dtype=dtype) for _ in range(6)]
        )
        # One output projection per controlnet block.
        self.blocks_proj = nn.ModuleList(
            [nn.Linear(3072, 3072, device=device, dtype=dtype) for _ in range(len(self.blocks))]
        )

    def patchify(self, hidden_states):
        # [B, C, 2H, 2W] -> [B, H*W, 4C]: each 2x2 patch becomes one token.
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        return hidden_states

    def forward(
        self,
        hidden_states,
        control_condition,
        control_scale,
        timestep,
        prompt_emb,
        pooled_prompt_emb,
        guidance,
        image_ids,
        text_ids,
    ):
        """Run the controlnet blocks; returns (double_block_outputs, None)."""
        hidden_states = self.patchify(hidden_states)
        control_condition = self.patchify(control_condition)
        # Inject the control signal additively at the embedding stage.
        hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition)
        condition = (
            self.time_embedder(timestep, hidden_states.dtype)
            + self.guidance_embedder(guidance * 1000, hidden_states.dtype)
            + self.pooled_text_embedder(pooled_prompt_emb)
        )
        prompt_emb = self.context_embedder(prompt_emb)
        # RoPE positions cover text tokens followed by image tokens.
        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

        # double block
        double_block_outputs = []
        for i, block in enumerate(self.blocks):
            hidden_states, prompt_emb = block(hidden_states, prompt_emb, condition, image_rotary_emb)
            double_block_outputs.append(self.blocks_proj[i](hidden_states))

        # apply control scale
        double_block_outputs = [control_scale * output for output in double_block_outputs]
        return double_block_outputs, None

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
        attn_impl: Optional[str] = None,
    ):
        """Build the model without weight init and load the given state dict."""
        # Infer the conditioning width from the checkpoint when present.
        if "controlnet_x_embedder.weight" in state_dict:
            condition_channels = state_dict["controlnet_x_embedder.weight"].shape[1]
        else:
            condition_channels = 64

        with no_init_weights():
            model = torch.nn.utils.skip_init(
                cls, condition_channels=condition_channels, attn_impl=attn_impl, device=device, dtype=dtype
            )
        model.load_state_dict(state_dict)
        model.to(device=device, dtype=dtype, non_blocking=True)
        return model
@@ -1,6 +1,7 @@
1
1
  import json
2
2
  import torch
3
3
  import torch.nn as nn
4
+ import numpy as np
4
5
  from typing import Dict, Optional
5
6
  from einops import rearrange
6
7
 
@@ -327,7 +328,6 @@ class FluxDiT(PreTrainedModel):
327
328
 
328
329
  def __init__(
329
330
  self,
330
- disable_guidance_embedder=False,
331
331
  attn_impl: Optional[str] = None,
332
332
  device: str = "cuda:0",
333
333
  dtype: torch.dtype = torch.bfloat16,
@@ -335,9 +335,7 @@ class FluxDiT(PreTrainedModel):
335
335
  super().__init__()
336
336
  self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
337
337
  self.time_embedder = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
338
- self.guidance_embedder = (
339
- None if disable_guidance_embedder else TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
340
- )
338
+ self.guidance_embedder = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
341
339
  self.pooled_text_embedder = nn.Sequential(
342
340
  nn.Linear(768, 3072, device=device, dtype=dtype),
343
341
  nn.SiLU(),
@@ -392,6 +390,8 @@ class FluxDiT(PreTrainedModel):
392
390
  text_ids,
393
391
  image_ids=None,
394
392
  use_gradient_checkpointing=False,
393
+ controlnet_double_block_output=None,
394
+ controlnet_single_block_output=None,
395
395
  **kwargs,
396
396
  ):
397
397
  fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -413,16 +413,10 @@ class FluxDiT(PreTrainedModel):
413
413
  hidden_states = self.patchify(hidden_states)
414
414
  hidden_states = self.x_embedder(hidden_states)
415
415
 
416
- def create_custom_forward(module):
417
- def custom_forward(*inputs):
418
- return module(*inputs)
419
-
420
- return custom_forward
421
-
422
- for block in self.blocks:
416
+ for i, block in enumerate(self.blocks):
423
417
  if self.training and use_gradient_checkpointing:
424
418
  hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
425
- create_custom_forward(block),
419
+ block,
426
420
  hidden_states,
427
421
  prompt_emb,
428
422
  conditioning,
@@ -431,12 +425,16 @@ class FluxDiT(PreTrainedModel):
431
425
  )
432
426
  else:
433
427
  hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
428
+ if controlnet_double_block_output is not None:
429
+ interval_control = len(self.blocks) / len(controlnet_double_block_output)
430
+ interval_control = int(np.ceil(interval_control))
431
+ hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
434
432
 
435
433
  hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
436
434
  for block in self.single_blocks:
437
435
  if self.training and use_gradient_checkpointing:
438
436
  hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
439
- create_custom_forward(block),
437
+ block,
440
438
  hidden_states,
441
439
  prompt_emb,
442
440
  conditioning,
@@ -445,12 +443,15 @@ class FluxDiT(PreTrainedModel):
445
443
  )
446
444
  else:
447
445
  hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
448
- hidden_states = hidden_states[:, prompt_emb.shape[1] :]
446
+ if controlnet_single_block_output is not None:
447
+ interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
448
+ interval_control = int(np.ceil(interval_control))
449
+ hidden_states = hidden_states + controlnet_single_block_output[i // interval_control]
449
450
 
451
+ hidden_states = hidden_states[:, prompt_emb.shape[1] :]
450
452
  hidden_states = self.final_norm_out(hidden_states, conditioning)
451
453
  hidden_states = self.final_proj_out(hidden_states)
452
454
  hidden_states = self.unpatchify(hidden_states, height, width)
453
-
454
455
  return hidden_states
455
456
 
456
457
  @classmethod
@@ -459,7 +460,6 @@ class FluxDiT(PreTrainedModel):
459
460
  state_dict: Dict[str, torch.Tensor],
460
461
  device: str,
461
462
  dtype: torch.dtype,
462
- disable_guidance_embedder: bool = False,
463
463
  attn_impl: Optional[str] = None,
464
464
  ):
465
465
  with no_init_weights():
@@ -467,7 +467,6 @@ class FluxDiT(PreTrainedModel):
467
467
  cls,
468
468
  device=device,
469
469
  dtype=dtype,
470
- disable_guidance_embedder=disable_guidance_embedder,
471
470
  attn_impl=attn_impl,
472
471
  )
473
472
  model = model.requires_grad_(False) # for loading gguf
@@ -268,16 +268,10 @@ class SD3DiT(PreTrainedModel):
268
268
  height, width = hidden_states.shape[-2:]
269
269
  hidden_states = self.pos_embedder(hidden_states)
270
270
 
271
- def create_custom_forward(module):
272
- def custom_forward(*inputs):
273
- return module(*inputs)
274
-
275
- return custom_forward
276
-
277
271
  for block in self.blocks:
278
272
  if self.training and use_gradient_checkpointing:
279
273
  hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
280
- create_custom_forward(block),
274
+ block,
281
275
  hidden_states,
282
276
  prompt_emb,
283
277
  conditioning,
@@ -260,12 +260,6 @@ class SDXLUNet(PreTrainedModel):
260
260
  res_stack = [hidden_states]
261
261
 
262
262
  # 3. blocks
263
- def create_custom_forward(module):
264
- def custom_forward(*inputs):
265
- return module(*inputs)
266
-
267
- return custom_forward
268
-
269
263
  for i, block in enumerate(self.blocks):
270
264
  if (
271
265
  self.training
@@ -273,7 +267,7 @@ class SDXLUNet(PreTrainedModel):
273
267
  and not (isinstance(block, PushBlock) or isinstance(block, PopBlock))
274
268
  ):
275
269
  hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
276
- create_custom_forward(block),
270
+ block,
277
271
  hidden_states,
278
272
  time_emb,
279
273
  text_emb,
@@ -166,6 +166,7 @@ class CrossAttention(nn.Module):
166
166
  if self.has_image_input:
167
167
  k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
168
168
  k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
169
+ v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
169
170
  y = attention(q, k_img, v_img, attn_impl=self.attn_impl).flatten(2)
170
171
  x = x + y
171
172
  return self.o(x)
@@ -1,5 +1,5 @@
1
1
  from .base import BasePipeline, LoRAStateDictConverter
2
- from .flux_image import FluxImagePipeline, FluxModelConfig
2
+ from .flux_image import FluxImagePipeline, FluxModelConfig, ControlNetParams
3
3
  from .sdxl_image import SDXLImagePipeline, SDXLModelConfig
4
4
  from .sd_image import SDImagePipeline, SDModelConfig
5
5
  from .wan_video import WanVideoPipeline, WanModelConfig
@@ -15,4 +15,5 @@ __all__ = [
15
15
  "SDModelConfig",
16
16
  "WanVideoPipeline",
17
17
  "WanModelConfig",
18
+ "ControlNetParams",
18
19
  ]