diffsynth-engine 0.2.1__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228)
  1. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/PKG-INFO +1 -1
  2. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/basic/transformer_helper.py +21 -9
  3. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/components/vae.py +15 -2
  4. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/flux/flux_dit.py +11 -4
  5. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/sd/sd_vae.py +18 -7
  6. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/sdxl/sdxl_vae.py +18 -7
  7. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/wan/wan_dit.py +1 -20
  8. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/pipelines/flux_image.py +25 -2
  9. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/pipelines/sd_image.py +6 -2
  10. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/pipelines/sdxl_image.py +11 -2
  11. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/download.py +4 -2
  12. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine.egg-info/PKG-INFO +1 -1
  13. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/.github/workflows/python-publish.yml +0 -0
  14. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/.gitignore +0 -0
  15. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/.pre-commit-config.yaml +0 -0
  16. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/LICENSE +0 -0
  17. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/README.md +0 -0
  18. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/assets/dingtalk.png +0 -0
  19. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/assets/showcase.jpeg +0 -0
  20. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/__init__.py +0 -0
  21. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/__init__.py +0 -0
  22. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  23. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  24. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  25. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  26. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  27. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  28. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  29. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  30. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  31. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  32. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  33. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  34. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  35. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  36. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  37. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  38. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  39. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  40. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  41. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  42. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  43. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  44. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  45. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  46. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  47. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  48. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/components/vae.json +0 -0
  49. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  50. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  51. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  52. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  53. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  54. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  55. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  56. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  57. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  58. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +0 -0
  59. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/wan/dit/14b-i2v.json +0 -0
  60. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/models/wan/dit/14b-t2v.json +0 -0
  61. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  62. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  63. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  64. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  65. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  66. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  67. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  68. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  69. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  70. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  71. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  72. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  73. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  74. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  75. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  76. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  77. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  78. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  79. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  80. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  81. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/kernels/__init__.py +0 -0
  82. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/__init__.py +0 -0
  83. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/base.py +0 -0
  84. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/basic/__init__.py +0 -0
  85. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/basic/attention.py +0 -0
  86. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/basic/lora.py +0 -0
  87. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  88. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/basic/timestep.py +0 -0
  89. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  90. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/components/__init__.py +0 -0
  91. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/components/clip.py +0 -0
  92. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/components/siglip.py +0 -0
  93. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/components/t5.py +0 -0
  94. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/flux/__init__.py +0 -0
  95. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/flux/flux_controlnet.py +0 -0
  96. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  97. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  98. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/sd/__init__.py +0 -0
  99. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  100. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  101. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/sd3/__init__.py +0 -0
  102. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  103. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  104. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  105. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  106. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  107. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
  108. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/utils.py +0 -0
  109. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/wan/__init__.py +0 -0
  110. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  111. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  112. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  113. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/pipelines/__init__.py +0 -0
  114. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/pipelines/base.py +0 -0
  115. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/pipelines/wan_video.py +0 -0
  116. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/tokenizers/__init__.py +0 -0
  117. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/tokenizers/base.py +0 -0
  118. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/tokenizers/clip.py +0 -0
  119. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/tokenizers/t5.py +0 -0
  120. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/tokenizers/wan.py +0 -0
  121. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/tools/__init__.py +0 -0
  122. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/tools/flux_inpainting.py +0 -0
  123. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/tools/flux_outpainting.py +0 -0
  124. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/__init__.py +0 -0
  125. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/constants.py +0 -0
  126. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/env.py +0 -0
  127. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/flag.py +0 -0
  128. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/fp8_linear.py +0 -0
  129. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/gguf.py +0 -0
  130. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/image.py +0 -0
  131. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/loader.py +0 -0
  132. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/lock.py +0 -0
  133. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/logging.py +0 -0
  134. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/offload.py +0 -0
  135. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/parallel.py +0 -0
  136. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/prompt.py +0 -0
  137. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine/utils/video.py +0 -0
  138. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine.egg-info/SOURCES.txt +0 -0
  139. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  140. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine.egg-info/requires.txt +0 -0
  141. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/diffsynth_engine.egg-info/top_level.txt +0 -0
  142. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/docs/tutorial.md +0 -0
  143. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/docs/tutorial_zh.md +0 -0
  144. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/examples/flux_lora.py +0 -0
  145. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/examples/flux_text_to_image.py +0 -0
  146. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/examples/i2v_input.jpg +0 -0
  147. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/examples/sdxl_text_to_image.py +0 -0
  148. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/examples/wan_image_to_video.py +0 -0
  149. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/examples/wan_lora.py +0 -0
  150. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/examples/wan_text_to_video.py +0 -0
  151. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/pyproject.toml +0 -0
  152. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/setup.cfg +0 -0
  153. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/setup.py +0 -0
  154. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/__init__.py +0 -0
  155. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/common/__init__.py +0 -0
  156. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/common/test_case.py +0 -0
  157. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/common/utils.py +0 -0
  158. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/algorithm/beta_20steps.safetensors +0 -0
  159. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/algorithm/ddim_20steps.safetensors +0 -0
  160. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/algorithm/euler_i10.safetensors +0 -0
  161. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/algorithm/exponential_20steps.safetensors +0 -0
  162. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/algorithm/flow_match_euler_i10.safetensors +0 -0
  163. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/algorithm/karras_20steps.safetensors +0 -0
  164. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/algorithm/output.safetensors +0 -0
  165. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/algorithm/recifited_flow_20steps_flux.safetensors +0 -0
  166. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/algorithm/scaled_linear_20steps.safetensors +0 -0
  167. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/algorithm/sgm_uniform_20steps.safetensors +0 -0
  168. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/flux/flux_dit.safetensors +0 -0
  169. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/flux/flux_inpainting.png +0 -0
  170. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/flux/flux_lora.png +0 -0
  171. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/flux/flux_outpainting.png +0 -0
  172. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/flux/flux_text_encoder_1.safetensors +0 -0
  173. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/flux/flux_text_encoder_2.safetensors +0 -0
  174. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/flux/flux_txt2img.png +0 -0
  175. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/flux/flux_union_pro_canny.png +0 -0
  176. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/flux/flux_vae.safetensors +0 -0
  177. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sd/sd_inpainting.png +0 -0
  178. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sd/sd_lora.png +0 -0
  179. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sd/sd_text_encoder.safetensors +0 -0
  180. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sd/sd_txt2img.png +0 -0
  181. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sd/sd_unet.safetensors +0 -0
  182. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sd/sd_vae.safetensors +0 -0
  183. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sdxl/sdxl_inpainting.png +0 -0
  184. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sdxl/sdxl_lora.png +0 -0
  185. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sdxl/sdxl_text_encoder_1.safetensors +0 -0
  186. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sdxl/sdxl_text_encoder_2.safetensors +0 -0
  187. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sdxl/sdxl_txt2img.png +0 -0
  188. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sdxl/sdxl_unet.safetensors +0 -0
  189. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/sdxl/sdxl_vae.safetensors +0 -0
  190. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/test_siglip_image_encoder.safetensors +0 -0
  191. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/expect/wan/wan_vae.safetensors +0 -0
  192. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/input/astronaut_320_320.mp4 +0 -0
  193. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/input/canny.png +0 -0
  194. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/input/mask_image.png +0 -0
  195. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/input/test_image.png +0 -0
  196. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/input/wukong_1024_1024.png +0 -0
  197. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/data/input/wukong_480_480.png +0 -0
  198. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_algorithm/__init__.py +0 -0
  199. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_algorithm/test_sampler.py +0 -0
  200. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_algorithm/test_scheduler.py +0 -0
  201. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/__init__.py +0 -0
  202. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/flux/__init__.py +0 -0
  203. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/flux/test_flux_dit.py +0 -0
  204. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/flux/test_flux_text_encoder.py +0 -0
  205. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/flux/test_flux_vae.py +0 -0
  206. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/sd/__init__.py +0 -0
  207. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/sd/test_sd_text_encoder.py +0 -0
  208. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/sd/test_sd_unet.py +0 -0
  209. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/sd/test_sd_vae.py +0 -0
  210. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/sdxl/__init__.py +0 -0
  211. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/sdxl/test_sdxl_text_encoder.py +0 -0
  212. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/sdxl/test_sdxl_unet.py +0 -0
  213. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/sdxl/test_sdxl_vae.py +0 -0
  214. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/test_siglip.py +0 -0
  215. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_models/wan/test_wan_vae.py +0 -0
  216. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_pipelines/__init__.py +0 -0
  217. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_pipelines/test_flux_controlnet.py +0 -0
  218. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_pipelines/test_flux_image.py +0 -0
  219. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_pipelines/test_sd_image.py +0 -0
  220. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_pipelines/test_sdxl_image.py +0 -0
  221. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_pipelines/test_wan_video.py +0 -0
  222. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_pipelines/test_wan_video_gguf.py +0 -0
  223. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_pipelines/test_wan_video_tp.py +0 -0
  224. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_tokenizers/__init__.py +0 -0
  225. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_tokenizers/test_clip.py +0 -0
  226. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_tokenizers/test_t5.py +0 -0
  227. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_tools/__init__.py +0 -0
  228. {diffsynth_engine-0.2.1 → diffsynth_engine-0.2.2}/tests/test_tools/test_flux_tools.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -65,17 +65,29 @@ class RoPEEmbedding(nn.Module):
65
65
 
66
66
 
67
67
  class RMSNorm(nn.Module):
68
- def __init__(self, dim, eps, device: str, dtype: torch.dtype):
68
+ def __init__(
69
+ self,
70
+ dim,
71
+ eps=1e-5,
72
+ elementwise_affine=True,
73
+ device: str = "cuda:0",
74
+ dtype: torch.dtype = torch.bfloat16,
75
+ ):
69
76
  super().__init__()
70
- self.weight = nn.Parameter(torch.ones((dim,), device=device, dtype=dtype))
71
77
  self.eps = eps
72
-
73
- def forward(self, hidden_states):
74
- input_dtype = hidden_states.dtype
75
- variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
76
- hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
77
- hidden_states = hidden_states.to(input_dtype) * self.weight
78
- return hidden_states
78
+ self.dim = dim
79
+ self.elementwise_affine = elementwise_affine
80
+ if elementwise_affine:
81
+ self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
82
+
83
+ def norm(self, x):
84
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
85
+
86
+ def forward(self, x):
87
+ norm_result = self.norm(x.float()).to(x.dtype)
88
+ if self.elementwise_affine:
89
+ return norm_result * self.weight
90
+ return norm_result
79
91
 
80
92
 
81
93
  class NewGELUActivation(nn.Module):
@@ -67,6 +67,7 @@ class VAEAttentionBlock(nn.Module):
67
67
  num_layers=1,
68
68
  norm_num_groups=32,
69
69
  eps=1e-5,
70
+ attn_impl: str = "auto",
70
71
  device: str = "cuda:0",
71
72
  dtype: torch.dtype = torch.float32,
72
73
  ):
@@ -86,6 +87,7 @@ class VAEAttentionBlock(nn.Module):
86
87
  bias_q=True,
87
88
  bias_kv=True,
88
89
  bias_out=True,
90
+ attn_impl=attn_impl,
89
91
  device=device,
90
92
  dtype=dtype,
91
93
  )
@@ -119,6 +121,7 @@ class VAEDecoder(PreTrainedModel):
119
121
  scaling_factor: float = 0.18215,
120
122
  shift_factor: float = 0,
121
123
  use_post_quant_conv: bool = True,
124
+ attn_impl: str = "auto",
122
125
  device: str = "cuda:0",
123
126
  dtype: torch.dtype = torch.float32,
124
127
  ):
@@ -137,7 +140,7 @@ class VAEDecoder(PreTrainedModel):
137
140
  [
138
141
  # UNetMidBlock2D
139
142
  ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
140
- VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype),
143
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype, attn_impl=attn_impl),
141
144
  ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
142
145
  # UpDecoderBlock2D
143
146
  ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
@@ -202,6 +205,7 @@ class VAEDecoder(PreTrainedModel):
202
205
  scaling_factor: float = 0.18215,
203
206
  shift_factor: float = 0,
204
207
  use_post_quant_conv: bool = True,
208
+ attn_impl: str = "auto",
205
209
  ):
206
210
  with no_init_weights():
207
211
  model = torch.nn.utils.skip_init(
@@ -210,6 +214,7 @@ class VAEDecoder(PreTrainedModel):
210
214
  scaling_factor=scaling_factor,
211
215
  shift_factor=shift_factor,
212
216
  use_post_quant_conv=use_post_quant_conv,
217
+ attn_impl=attn_impl,
213
218
  device=device,
214
219
  dtype=dtype,
215
220
  )
@@ -230,6 +235,7 @@ class VAEEncoder(PreTrainedModel):
230
235
  scaling_factor: float = 0.18215,
231
236
  shift_factor: float = 0,
232
237
  use_quant_conv: bool = True,
238
+ attn_impl: str = "auto",
233
239
  device: str = "cuda:0",
234
240
  dtype: torch.dtype = torch.float32,
235
241
  ):
@@ -263,7 +269,7 @@ class VAEEncoder(PreTrainedModel):
263
269
  ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
264
270
  # UNetMidBlock2D
265
271
  ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
266
- VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype),
272
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype, attn_impl=attn_impl),
267
273
  ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
268
274
  ]
269
275
  )
@@ -309,6 +315,7 @@ class VAEEncoder(PreTrainedModel):
309
315
  scaling_factor: float = 0.18215,
310
316
  shift_factor: float = 0,
311
317
  use_quant_conv: bool = True,
318
+ attn_impl: str = "auto",
312
319
  ):
313
320
  with no_init_weights():
314
321
  model = torch.nn.utils.skip_init(
@@ -317,6 +324,7 @@ class VAEEncoder(PreTrainedModel):
317
324
  scaling_factor=scaling_factor,
318
325
  shift_factor=shift_factor,
319
326
  use_quant_conv=use_quant_conv,
327
+ attn_impl=attn_impl,
320
328
  device=device,
321
329
  dtype=dtype,
322
330
  )
@@ -338,6 +346,7 @@ class VAE(PreTrainedModel):
338
346
  shift_factor: float = 0,
339
347
  use_quant_conv: bool = True,
340
348
  use_post_quant_conv: bool = True,
349
+ attn_impl: str = "auto",
341
350
  device: str = "cuda:0",
342
351
  dtype: torch.dtype = torch.float32,
343
352
  ):
@@ -347,6 +356,7 @@ class VAE(PreTrainedModel):
347
356
  scaling_factor=scaling_factor,
348
357
  shift_factor=shift_factor,
349
358
  use_quant_conv=use_quant_conv,
359
+ attn_impl=attn_impl,
350
360
  device=device,
351
361
  dtype=dtype,
352
362
  )
@@ -355,6 +365,7 @@ class VAE(PreTrainedModel):
355
365
  scaling_factor=scaling_factor,
356
366
  shift_factor=shift_factor,
357
367
  use_post_quant_conv=use_post_quant_conv,
368
+ attn_impl=attn_impl,
358
369
  device=device,
359
370
  dtype=dtype,
360
371
  )
@@ -376,6 +387,7 @@ class VAE(PreTrainedModel):
376
387
  shift_factor: float = 0,
377
388
  use_quant_conv: bool = True,
378
389
  use_post_quant_conv: bool = True,
390
+ attn_impl: str = "auto",
379
391
  ):
380
392
  with no_init_weights():
381
393
  model = torch.nn.utils.skip_init(
@@ -385,6 +397,7 @@ class VAE(PreTrainedModel):
385
397
  shift_factor=shift_factor,
386
398
  use_quant_conv=use_quant_conv,
387
399
  use_post_quant_conv=use_post_quant_conv,
400
+ attn_impl=attn_impl,
388
401
  device=device,
389
402
  dtype=dtype,
390
403
  )
@@ -227,7 +227,7 @@ class FluxJointTransformerBlock(nn.Module):
227
227
  nn.Linear(dim * 4, dim, device=device, dtype=dtype),
228
228
  )
229
229
 
230
- def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
230
+ def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, image_emb):
231
231
  norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
232
232
  norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
233
233
 
@@ -293,7 +293,7 @@ class FluxSingleTransformerBlock(nn.Module):
293
293
  hidden_states = hidden_states.to(q.dtype)
294
294
  return hidden_states
295
295
 
296
- def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
296
+ def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, image_emb):
297
297
  residual = hidden_states_a
298
298
  norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
299
299
  hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
@@ -386,6 +386,7 @@ class FluxDiT(PreTrainedModel):
386
386
  timestep,
387
387
  prompt_emb,
388
388
  pooled_prompt_emb,
389
+ image_emb,
389
390
  guidance,
390
391
  text_ids,
391
392
  image_ids=None,
@@ -421,10 +422,13 @@ class FluxDiT(PreTrainedModel):
421
422
  prompt_emb,
422
423
  conditioning,
423
424
  image_rotary_emb,
425
+ image_emb,
424
426
  use_reentrant=False,
425
427
  )
426
428
  else:
427
- hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
429
+ hidden_states, prompt_emb = block(
430
+ hidden_states, prompt_emb, conditioning, image_rotary_emb, image_emb
431
+ )
428
432
  if controlnet_double_block_output is not None:
429
433
  interval_control = len(self.blocks) / len(controlnet_double_block_output)
430
434
  interval_control = int(np.ceil(interval_control))
@@ -439,10 +443,13 @@ class FluxDiT(PreTrainedModel):
439
443
  prompt_emb,
440
444
  conditioning,
441
445
  image_rotary_emb,
446
+ image_emb,
442
447
  use_reentrant=False,
443
448
  )
444
449
  else:
445
- hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
450
+ hidden_states, prompt_emb = block(
451
+ hidden_states, prompt_emb, conditioning, image_rotary_emb, image_emb
452
+ )
446
453
  if controlnet_single_block_output is not None:
447
454
  interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
448
455
  interval_control = int(np.ceil(interval_control))
@@ -6,33 +6,44 @@ from diffsynth_engine.models.utils import no_init_weights
6
6
 
7
7
 
8
8
  class SDVAEEncoder(VAEEncoder):
9
- def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
9
+ def __init__(self, attn_impl: str = "auto", device: str = "cuda:0", dtype: torch.dtype = torch.float32):
10
10
  super().__init__(
11
- latent_channels=4, scaling_factor=0.18215, shift_factor=0, use_quant_conv=True, device=device, dtype=dtype
11
+ latent_channels=4,
12
+ scaling_factor=0.18215,
13
+ shift_factor=0,
14
+ use_quant_conv=True,
15
+ attn_impl=attn_impl,
16
+ device=device,
17
+ dtype=dtype,
12
18
  )
13
19
 
14
20
  @classmethod
15
- def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
21
+ def from_state_dict(
22
+ cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, attn_impl: str = "auto"
23
+ ):
16
24
  with no_init_weights():
17
- model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
25
+ model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, attn_impl=attn_impl)
18
26
  model.load_state_dict(state_dict)
19
27
  return model
20
28
 
21
29
 
22
30
  class SDVAEDecoder(VAEDecoder):
23
- def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
31
+ def __init__(self, attn_impl: str = "auto", device: str = "cuda:0", dtype: torch.dtype = torch.float32):
24
32
  super().__init__(
25
33
  latent_channels=4,
26
34
  scaling_factor=0.18215,
27
35
  shift_factor=0,
28
36
  use_post_quant_conv=True,
37
+ attn_impl=attn_impl,
29
38
  device=device,
30
39
  dtype=dtype,
31
40
  )
32
41
 
33
42
  @classmethod
34
- def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
43
+ def from_state_dict(
44
+ cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, attn_impl: str = "auto"
45
+ ):
35
46
  with no_init_weights():
36
- model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
47
+ model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, attn_impl=attn_impl)
37
48
  model.load_state_dict(state_dict)
38
49
  return model
@@ -6,33 +6,44 @@ from diffsynth_engine.models.utils import no_init_weights
6
6
 
7
7
 
8
8
  class SDXLVAEEncoder(VAEEncoder):
9
- def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
9
+ def __init__(self, attn_impl: str = "auto", device: str = "cuda:0", dtype: torch.dtype = torch.float32):
10
10
  super().__init__(
11
- latent_channels=4, scaling_factor=0.13025, shift_factor=0, use_quant_conv=True, device=device, dtype=dtype
11
+ latent_channels=4,
12
+ scaling_factor=0.13025,
13
+ shift_factor=0,
14
+ use_quant_conv=True,
15
+ attn_impl=attn_impl,
16
+ device=device,
17
+ dtype=dtype,
12
18
  )
13
19
 
14
20
  @classmethod
15
- def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
21
+ def from_state_dict(
22
+ cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, attn_impl: str = "auto"
23
+ ):
16
24
  with no_init_weights():
17
- model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
25
+ model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, attn_impl=attn_impl)
18
26
  model.load_state_dict(state_dict)
19
27
  return model
20
28
 
21
29
 
22
30
  class SDXLVAEDecoder(VAEDecoder):
23
- def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
31
+ def __init__(self, attn_impl: str = "auto", device: str = "cuda:0", dtype: torch.dtype = torch.float32):
24
32
  super().__init__(
25
33
  latent_channels=4,
26
34
  scaling_factor=0.13025,
27
35
  shift_factor=0,
28
36
  use_post_quant_conv=True,
37
+ attn_impl=attn_impl,
29
38
  device=device,
30
39
  dtype=dtype,
31
40
  )
32
41
 
33
42
  @classmethod
34
- def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
43
+ def from_state_dict(
44
+ cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, attn_impl: str = "auto"
45
+ ):
35
46
  with no_init_weights():
36
- model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
47
+ model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, attn_impl=attn_impl)
37
48
  model.load_state_dict(state_dict)
38
49
  return model
@@ -8,6 +8,7 @@ from einops import rearrange
8
8
 
9
9
  from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
10
10
  from diffsynth_engine.models.basic.attention import attention, long_context_attention
11
+ from diffsynth_engine.models.basic.transformer_helper import RMSNorm
11
12
  from diffsynth_engine.models.utils import no_init_weights
12
13
  from diffsynth_engine.utils.constants import (
13
14
  WAN_DIT_1_3B_T2V_CONFIG_FILE,
@@ -57,26 +58,6 @@ def rope_apply(x, freqs):
57
58
  return x_out.to(x.dtype).flatten(3)
58
59
 
59
60
 
60
- class RMSNorm(nn.Module):
61
- def __init__(
62
- self,
63
- dim,
64
- eps=1e-5,
65
- device: str = "cuda:0",
66
- dtype: torch.dtype = torch.bfloat16,
67
- ):
68
- super().__init__()
69
- self.eps = eps
70
- self.dim = dim
71
- self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
72
-
73
- def norm(self, x):
74
- return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
75
-
76
- def forward(self, x):
77
- return self.norm(x.float()).to(x.dtype) * self.weight
78
-
79
-
80
61
  class SelfAttention(nn.Module):
81
62
  def __init__(
82
63
  self,
@@ -366,6 +366,7 @@ class FluxImagePipeline(BasePipeline):
366
366
  negative_prompt_emb: torch.Tensor,
367
367
  positive_add_text_embeds: torch.Tensor,
368
368
  negative_add_text_embeds: torch.Tensor,
369
+ image_emb: torch.Tensor | None,
369
370
  image_ids: torch.Tensor,
370
371
  text_ids: torch.Tensor,
371
372
  cfg_scale: float,
@@ -382,6 +383,7 @@ class FluxImagePipeline(BasePipeline):
382
383
  timestep,
383
384
  positive_prompt_emb,
384
385
  positive_add_text_embeds,
386
+ image_emb,
385
387
  image_ids,
386
388
  text_ids,
387
389
  guidance,
@@ -396,6 +398,7 @@ class FluxImagePipeline(BasePipeline):
396
398
  timestep,
397
399
  positive_prompt_emb,
398
400
  positive_add_text_embeds,
401
+ image_emb,
399
402
  image_ids,
400
403
  text_ids,
401
404
  guidance,
@@ -408,6 +411,7 @@ class FluxImagePipeline(BasePipeline):
408
411
  timestep,
409
412
  negative_prompt_emb,
410
413
  negative_add_text_embeds,
414
+ image_emb,
411
415
  image_ids,
412
416
  text_ids,
413
417
  guidance,
@@ -428,6 +432,7 @@ class FluxImagePipeline(BasePipeline):
428
432
  timestep,
429
433
  prompt_emb,
430
434
  add_text_embeds,
435
+ image_emb,
431
436
  image_ids,
432
437
  text_ids,
433
438
  guidance,
@@ -444,6 +449,7 @@ class FluxImagePipeline(BasePipeline):
444
449
  timestep: torch.Tensor,
445
450
  prompt_emb: torch.Tensor,
446
451
  add_text_embeds: torch.Tensor,
452
+ image_emb: torch.Tensor | None,
447
453
  image_ids: torch.Tensor,
448
454
  text_ids: torch.Tensor,
449
455
  guidance: float,
@@ -468,6 +474,7 @@ class FluxImagePipeline(BasePipeline):
468
474
  timestep=timestep,
469
475
  prompt_emb=prompt_emb,
470
476
  pooled_prompt_emb=add_text_embeds,
477
+ image_emb=image_emb,
471
478
  guidance=guidance,
472
479
  text_ids=text_ids,
473
480
  image_ids=image_ids,
@@ -579,14 +586,24 @@ class FluxImagePipeline(BasePipeline):
579
586
  def enable_fp8_linear(self):
580
587
  enable_fp8_linear(self.dit)
581
588
 
589
+ def load_ip_adapter(self, ip_adapter):
590
+ self.ip_adapter = ip_adapter
591
+ self.ip_adapter.inject(self.dit)
592
+
593
+ def unload_ip_adapter(self):
594
+ if self.ip_adapter is not None:
595
+ self.ip_adapter.remove(self.dit)
596
+ self.ip_adapter = None
597
+
582
598
  @torch.no_grad()
583
599
  def __call__(
584
600
  self,
585
601
  prompt: str,
586
602
  negative_prompt: str = "",
587
- cfg_scale: float = 1.0,
603
+ ref_image: Image.Image | None = None, # use for ip-adapter, instance-id
604
+ cfg_scale: float = 1.0, # 官方的flux模型不支持cfg调整
588
605
  clip_skip: int = 2,
589
- input_image: Image.Image | None = None,
606
+ input_image: Image.Image | None = None, # use for img2img
590
607
  denoising_strength: float = 1.0,
591
608
  height: int = 1024,
592
609
  width: int = 1024,
@@ -624,6 +641,11 @@ class FluxImagePipeline(BasePipeline):
624
641
  # ControlNet
625
642
  controlnet_params = self.prepare_controlnet_params(controlnet_params, h=height, w=width)
626
643
 
644
+ # image_emb
645
+ image_emb = (
646
+ self.ip_adapter.encode_image(ref_image) if self.ip_adapter is not None and ref_image is not None else None
647
+ )
648
+
627
649
  # Denoise
628
650
  self.load_models_to_device(["dit"])
629
651
  for i, timestep in enumerate(tqdm(timesteps)):
@@ -635,6 +657,7 @@ class FluxImagePipeline(BasePipeline):
635
657
  negative_prompt_emb=negative_prompt_emb,
636
658
  positive_add_text_embeds=positive_add_text_embeds,
637
659
  negative_add_text_embeds=negative_add_text_embeds,
660
+ image_emb=image_emb,
638
661
  image_ids=image_ids,
639
662
  text_ids=text_ids,
640
663
  cfg_scale=cfg_scale,
@@ -217,8 +217,12 @@ class SDImagePipeline(BasePipeline):
217
217
  clip_state_dict, device=init_device, dtype=model_config.clip_dtype
218
218
  )
219
219
  unet = SDUNet.from_state_dict(unet_state_dict, device=init_device, dtype=model_config.unet_dtype)
220
- vae_decoder = SDVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
221
- vae_encoder = SDVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
220
+ vae_decoder = SDVAEDecoder.from_state_dict(
221
+ vae_state_dict, device=init_device, dtype=model_config.vae_dtype, attn_impl="sdpa"
222
+ )
223
+ vae_encoder = SDVAEEncoder.from_state_dict(
224
+ vae_state_dict, device=init_device, dtype=model_config.vae_dtype, attn_impl="sdpa"
225
+ )
222
226
 
223
227
  pipe = cls(
224
228
  tokenizer=tokenizer,
@@ -203,8 +203,12 @@ class SDXLImagePipeline(BasePipeline):
203
203
  clip_g_state_dict, device=init_device, dtype=model_config.clip_g_dtype
204
204
  )
205
205
  unet = SDXLUNet.from_state_dict(unet_state_dict, device=init_device, dtype=model_config.unet_dtype)
206
- vae_decoder = SDXLVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
207
- vae_encoder = SDXLVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
206
+ vae_decoder = SDXLVAEDecoder.from_state_dict(
207
+ vae_state_dict, device=init_device, dtype=model_config.vae_dtype, attn_impl="sdpa"
208
+ )
209
+ vae_encoder = SDXLVAEEncoder.from_state_dict(
210
+ vae_state_dict, device=init_device, dtype=model_config.vae_dtype, attn_impl="sdpa"
211
+ )
208
212
 
209
213
  pipe = cls(
210
214
  tokenizer=tokenizer,
@@ -387,6 +391,11 @@ class SDXLImagePipeline(BasePipeline):
387
391
  self.load_models_to_device(["unet"])
388
392
  for i, timestep in enumerate(tqdm(timesteps)):
389
393
  timestep = timestep.unsqueeze(0).to(dtype=self.dtype)
394
+ positive_prompt_emb = positive_prompt_emb.to(self.dtype)
395
+ negative_prompt_emb = negative_prompt_emb.to(self.dtype)
396
+ positive_add_text_embeds = positive_add_text_embeds.to(self.dtype)
397
+ negative_add_text_embeds = negative_add_text_embeds.to(self.dtype)
398
+ add_time_id = add_time_id.to(self.dtype)
390
399
  # Classifier-free guidance
391
400
  noise_pred = self.predict_noise_with_cfg(
392
401
  latents=latents,
@@ -26,9 +26,10 @@ def fetch_model(
26
26
  path: Optional[str] = None,
27
27
  access_token: Optional[str] = None,
28
28
  source: str = "modelscope",
29
+ fetch_safetensors: bool = True,
29
30
  ) -> str:
30
31
  if source == "modelscope":
31
- return fetch_modelscope_model(model_uri, revision, path, access_token)
32
+ return fetch_modelscope_model(model_uri, revision, path, access_token, fetch_safetensors)
32
33
  if source == "civitai":
33
34
  return fetch_civitai_model(model_uri)
34
35
  raise ValueError(f'source should be one of {MODEL_SOURCES} but got "{source}"')
@@ -39,6 +40,7 @@ def fetch_modelscope_model(
39
40
  revision: Optional[str] = None,
40
41
  path: Optional[str] = None,
41
42
  access_token: Optional[str] = None,
43
+ fetch_safetensors: bool = True,
42
44
  ) -> str:
43
45
  lock_file_name = f"modelscope.{model_id.replace('/', '--')}.{revision if revision else '__version'}.lock"
44
46
  lock_file_path = os.path.join(DIFFSYNTH_FILELOCK_DIR, lock_file_name)
@@ -55,7 +57,7 @@ def fetch_modelscope_model(
55
57
  else:
56
58
  path = dirpath
57
59
 
58
- if os.path.isdir(path):
60
+ if os.path.isdir(path) and fetch_safetensors:
59
61
  return _fetch_safetensors(path)
60
62
  return path
61
63
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent