diffsynth-engine 0.2.4__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157) hide show
  1. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/PKG-INFO +2 -1
  2. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/README.md +9 -2
  3. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/attention.py +0 -2
  4. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/flux_dit.py +70 -26
  5. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/utils.py +0 -1
  6. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/vae/vae.py +4 -0
  7. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/wan/wan_dit.py +17 -42
  8. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/pipelines/base.py +65 -10
  9. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/pipelines/flux_image.py +137 -25
  10. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/pipelines/sd_image.py +3 -0
  11. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/pipelines/sdxl_image.py +3 -0
  12. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/pipelines/wan_video.py +11 -39
  13. diffsynth_engine-0.2.6/diffsynth_engine/processor/canny_processor.py +21 -0
  14. diffsynth_engine-0.2.6/diffsynth_engine/processor/depth_processor.py +42 -0
  15. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tools/flux_inpainting_tool.py +3 -1
  16. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tools/flux_outpainting_tool.py +3 -1
  17. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tools/flux_reference_tool.py +1 -1
  18. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tools/flux_replace_tool.py +1 -1
  19. diffsynth_engine-0.2.6/diffsynth_engine/utils/__init__.py +0 -0
  20. diffsynth_engine-0.2.6/diffsynth_engine/utils/onnx.py +33 -0
  21. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/parallel.py +64 -2
  22. diffsynth_engine-0.2.6/diffsynth_engine/utils/platform.py +12 -0
  23. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine.egg-info/PKG-INFO +2 -1
  24. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine.egg-info/SOURCES.txt +5 -0
  25. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine.egg-info/requires.txt +1 -0
  26. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/pyproject.toml +2 -1
  27. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/.gitignore +0 -0
  28. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/.pre-commit-config.yaml +0 -0
  29. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/LICENSE +0 -0
  30. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/MANIFEST.in +0 -0
  31. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/assets/dingtalk.png +0 -0
  32. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/assets/showcase.jpeg +0 -0
  33. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/__init__.py +0 -0
  34. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/__init__.py +0 -0
  35. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  36. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  37. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  38. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  39. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  40. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  41. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  42. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  43. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  44. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  45. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  46. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  47. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  48. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  49. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  50. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  51. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  52. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  53. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  54. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  55. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  56. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  57. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  58. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  59. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  60. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  61. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/components/vae.json +0 -0
  62. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  63. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  64. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  65. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  66. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  67. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  68. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  69. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  70. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  71. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +0 -0
  72. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/wan/dit/14b-flf2v.json +0 -0
  73. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/wan/dit/14b-i2v.json +0 -0
  74. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/wan/dit/14b-t2v.json +0 -0
  75. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  76. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  77. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  78. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  79. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  80. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  81. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  82. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  83. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  84. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  85. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  86. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  87. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  88. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  89. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  90. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  91. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  92. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  93. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  94. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  95. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/kernels/__init__.py +0 -0
  96. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/__init__.py +0 -0
  97. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/base.py +0 -0
  98. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/__init__.py +0 -0
  99. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/lora.py +0 -0
  100. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  101. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/timestep.py +0 -0
  102. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
  103. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  104. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/__init__.py +0 -0
  105. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/flux_controlnet.py +0 -0
  106. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/flux_ipadapter.py +0 -0
  107. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/flux_redux.py +0 -0
  108. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  109. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  110. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd/__init__.py +0 -0
  111. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  112. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  113. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  114. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd3/__init__.py +0 -0
  115. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  116. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  117. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  118. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  119. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  120. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
  121. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  122. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
  123. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/text_encoder/clip.py +0 -0
  124. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
  125. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/text_encoder/t5.py +0 -0
  126. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/vae/__init__.py +0 -0
  127. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/wan/__init__.py +0 -0
  128. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  129. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  130. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  131. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/pipelines/__init__.py +0 -0
  132. {diffsynth_engine-0.2.4/diffsynth_engine/utils → diffsynth_engine-0.2.6/diffsynth_engine/processor}/__init__.py +0 -0
  133. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tokenizers/__init__.py +0 -0
  134. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tokenizers/base.py +0 -0
  135. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tokenizers/clip.py +0 -0
  136. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tokenizers/t5.py +0 -0
  137. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tokenizers/wan.py +0 -0
  138. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tools/__init__.py +0 -0
  139. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/constants.py +0 -0
  140. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/download.py +0 -0
  141. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/env.py +0 -0
  142. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/flag.py +0 -0
  143. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/fp8_linear.py +0 -0
  144. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/gguf.py +0 -0
  145. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/image.py +0 -0
  146. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/loader.py +0 -0
  147. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/lock.py +0 -0
  148. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/logging.py +0 -0
  149. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/offload.py +0 -0
  150. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/prompt.py +0 -0
  151. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/video.py +0 -0
  152. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  153. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine.egg-info/top_level.txt +0 -0
  154. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/docs/tutorial.md +0 -0
  155. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/docs/tutorial_zh.md +0 -0
  156. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/setup.cfg +0 -0
  157. {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.2.4
3
+ Version: 0.2.6
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -23,6 +23,7 @@ Requires-Dist: torchsde
23
23
  Requires-Dist: pillow
24
24
  Requires-Dist: imageio[ffmpeg]
25
25
  Requires-Dist: yunchang; sys_platform == "linux"
26
+ Requires-Dist: onnxruntime
26
27
  Provides-Extra: dev
27
28
  Requires-Dist: diffusers==0.31.0; extra == "dev"
28
29
  Requires-Dist: transformers==4.45.2; extra == "dev"
@@ -45,7 +45,7 @@ Text to image
45
45
  ```python
46
46
  from diffsynth_engine import fetch_model, FluxImagePipeline
47
47
 
48
- model_path = fetch_model("muse/flux-with-vae", path="flux_with_vae.safetensors")
48
+ model_path = fetch_model("muse/flux-with-vae", path="flux1-dev-with-vae.safetensors")
49
49
  pipe = FluxImagePipeline.from_pretrained(model_path, device='cuda:0')
50
50
  image = pipe(prompt="a cat")
51
51
  image.save("image.png")
@@ -54,7 +54,7 @@ Text to image with LoRA
54
54
  ```python
55
55
  from diffsynth_engine import fetch_model, FluxImagePipeline
56
56
 
57
- model_path = fetch_model("muse/flux-with-vae", path="flux_with_vae.safetensors")
57
+ model_path = fetch_model("muse/flux-with-vae", path="flux1-dev-with-vae.safetensors")
58
58
  lora_path = fetch_model("DonRat/MAJICFLUS_SuperChinesestyleheongsam", path="麦橘超国风旗袍.safetensors")
59
59
 
60
60
  pipe = FluxImagePipeline.from_pretrained(model_path, device='cuda:0')
@@ -77,6 +77,13 @@ If you have any questions or feedback, please scan the QR code below, or send em
77
77
  <img src="assets/dingtalk.png" alt="dingtalk" width="400" />
78
78
  </div>
79
79
 
80
+ ## Contributing
81
+ We welcome contributions to DiffSynth-Engine. After installing from source, we recommend that developers install this project with the following command to set up the development environment.
82
+ ```bash
83
+ pip install -e '.[dev]'
84
+ ```
85
+ TODO: Please refer to [CONTRIBUTING.md](./CONTRIBUTING.md) for more details.
86
+
80
87
  ## License
81
88
  This project is licensed under the Apache License 2.0. See the LICENSE file for details.
82
89
 
@@ -201,10 +201,8 @@ def long_context_attention(
201
201
  assert attn_impl in [
202
202
  None,
203
203
  "auto",
204
- "eager",
205
204
  "flash_attn_2",
206
205
  "flash_attn_3",
207
- "xformers",
208
206
  "sdpa",
209
207
  "sage_attn",
210
208
  "sparge_attn",
@@ -13,11 +13,12 @@ from diffsynth_engine.models.basic.transformer_helper import (
13
13
  )
14
14
  from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
15
15
  from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
16
+ from diffsynth_engine.models.basic import attention as attention_ops
16
17
  from diffsynth_engine.models.utils import no_init_weights
17
18
  from diffsynth_engine.utils.gguf import gguf_inference
18
19
  from diffsynth_engine.utils.fp8_linear import fp8_inference
19
20
  from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
20
- from diffsynth_engine.models.basic.attention import attention
21
+ from diffsynth_engine.utils.parallel import sequence_parallel, sequence_parallel_unshard
21
22
  from diffsynth_engine.utils import logging
22
23
 
23
24
 
@@ -198,7 +199,7 @@ class FluxDoubleAttention(nn.Module):
198
199
  k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
199
200
  v = torch.cat([v_b, v_a], dim=1)
200
201
  q, k = apply_rope(q, k, rope_emb)
201
- attn_out = attention(q, k, v, attn_impl=self.attn_impl)
202
+ attn_out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl)
202
203
  attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
203
204
  text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
204
205
  image_out, text_out = self.attention_callback(
@@ -286,7 +287,7 @@ class FluxSingleAttention(nn.Module):
286
287
  def forward(self, x, rope_emb, image_emb):
287
288
  q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
288
289
  q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
289
- attn_out = attention(q, k, v, attn_impl=self.attn_impl)
290
+ attn_out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl)
290
291
  attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
291
292
  return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)
292
293
 
@@ -322,7 +323,9 @@ class FluxDiT(PreTrainedModel):
322
323
 
323
324
  def __init__(
324
325
  self,
326
+ in_channel: int = 64,
325
327
  attn_impl: Optional[str] = None,
328
+ use_usp: bool = False,
326
329
  device: str = "cuda:0",
327
330
  dtype: torch.dtype = torch.bfloat16,
328
331
  ):
@@ -336,7 +339,8 @@ class FluxDiT(PreTrainedModel):
336
339
  nn.Linear(3072, 3072, device=device, dtype=dtype),
337
340
  )
338
341
  self.context_embedder = nn.Linear(4096, 3072, device=device, dtype=dtype)
339
- self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
342
+ # normal flux has 64 channels, bfl canny and depth has 128 channels, bfl fill has 384 channels, bfl redux has 64 channels
343
+ self.x_embedder = nn.Linear(in_channel, 3072, device=device, dtype=dtype)
340
344
 
341
345
  self.blocks = nn.ModuleList(
342
346
  [FluxDoubleTransformerBlock(3072, 24, attn_impl=attn_impl, device=device, dtype=dtype) for _ in range(19)]
@@ -347,6 +351,8 @@ class FluxDiT(PreTrainedModel):
347
351
  self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
348
352
  self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
349
353
 
354
+ self.use_usp = use_usp
355
+
350
356
  def patchify(self, hidden_states):
351
357
  hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
352
358
  return hidden_states
@@ -357,7 +363,8 @@ class FluxDiT(PreTrainedModel):
357
363
  )
358
364
  return hidden_states
359
365
 
360
- def prepare_image_ids(self, latents):
366
+ @staticmethod
367
+ def prepare_image_ids(latents: torch.Tensor):
361
368
  batch_size, _, height, width = latents.shape
362
369
  latent_image_ids = torch.zeros(height // 2, width // 2, 3)
363
370
  latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
@@ -387,7 +394,14 @@ class FluxDiT(PreTrainedModel):
387
394
  controlnet_single_block_output=None,
388
395
  **kwargs,
389
396
  ):
390
- height, width = hidden_states.shape[-2:]
397
+ h, w = hidden_states.shape[-2:]
398
+ controlnet_double_block_output = (
399
+ controlnet_double_block_output if controlnet_double_block_output is not None else ()
400
+ )
401
+ controlnet_single_block_output = (
402
+ controlnet_single_block_output if controlnet_single_block_output is not None else ()
403
+ )
404
+
391
405
  fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
392
406
  with fp8_inference(fp8_linear_enabled), gguf_inference():
393
407
  if image_ids is None:
@@ -400,28 +414,54 @@ class FluxDiT(PreTrainedModel):
400
414
  guidance = guidance * 1000
401
415
  conditioning += self.guidance_embedder(guidance, hidden_states.dtype)
402
416
  conditioning += self.pooled_text_embedder(pooled_prompt_emb)
403
- prompt_emb = self.context_embedder(prompt_emb)
404
417
  rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
418
+ text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
419
+ image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
405
420
  hidden_states = self.patchify(hidden_states)
406
- hidden_states = self.x_embedder(hidden_states)
407
- for i, block in enumerate(self.blocks):
408
- hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
409
- if controlnet_double_block_output is not None:
410
- interval_control = len(self.blocks) / len(controlnet_double_block_output)
411
- interval_control = int(np.ceil(interval_control))
412
- hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
413
- hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
414
- for i, block in enumerate(self.single_blocks):
415
- hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
416
- if controlnet_single_block_output is not None:
417
- interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
418
- interval_control = int(np.ceil(interval_control))
419
- hidden_states = hidden_states + controlnet_single_block_output[i // interval_control]
420
-
421
- hidden_states = hidden_states[:, prompt_emb.shape[1] :]
422
- hidden_states = self.final_norm_out(hidden_states, conditioning)
423
- hidden_states = self.final_proj_out(hidden_states)
424
- hidden_states = self.unpatchify(hidden_states, height, width)
421
+
422
+ with sequence_parallel(
423
+ (
424
+ hidden_states,
425
+ prompt_emb,
426
+ text_rope_emb,
427
+ image_rope_emb,
428
+ *controlnet_double_block_output,
429
+ *controlnet_single_block_output,
430
+ ),
431
+ seq_dims=(
432
+ 1,
433
+ 1,
434
+ 2,
435
+ 2,
436
+ *(1 for _ in controlnet_double_block_output),
437
+ *(1 for _ in controlnet_single_block_output),
438
+ ),
439
+ enabled=self.use_usp,
440
+ ):
441
+ hidden_states = self.x_embedder(hidden_states)
442
+ prompt_emb = self.context_embedder(prompt_emb)
443
+ rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)
444
+
445
+ for i, block in enumerate(self.blocks):
446
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
447
+ if len(controlnet_double_block_output) > 0:
448
+ interval_control = len(self.blocks) / len(controlnet_double_block_output)
449
+ interval_control = int(np.ceil(interval_control))
450
+ hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
451
+ hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
452
+ for i, block in enumerate(self.single_blocks):
453
+ hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
454
+ if len(controlnet_single_block_output) > 0:
455
+ interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
456
+ interval_control = int(np.ceil(interval_control))
457
+ hidden_states = hidden_states + controlnet_single_block_output[i // interval_control]
458
+
459
+ hidden_states = hidden_states[:, prompt_emb.shape[1] :]
460
+ hidden_states = self.final_norm_out(hidden_states, conditioning)
461
+ hidden_states = self.final_proj_out(hidden_states)
462
+ (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))
463
+
464
+ hidden_states = self.unpatchify(hidden_states, h, w)
425
465
  return hidden_states
426
466
 
427
467
  @classmethod
@@ -430,14 +470,18 @@ class FluxDiT(PreTrainedModel):
430
470
  state_dict: Dict[str, torch.Tensor],
431
471
  device: str,
432
472
  dtype: torch.dtype,
473
+ in_channel: int = 64,
433
474
  attn_impl: Optional[str] = None,
475
+ use_usp: bool = False,
434
476
  ):
435
477
  with no_init_weights():
436
478
  model = torch.nn.utils.skip_init(
437
479
  cls,
438
480
  device=device,
439
481
  dtype=dtype,
482
+ in_channel=in_channel,
440
483
  attn_impl=attn_impl,
484
+ use_usp=use_usp,
441
485
  )
442
486
  model = model.requires_grad_(False) # for loading gguf
443
487
  model.load_state_dict(state_dict, assign=True)
@@ -2,7 +2,6 @@ import torch
2
2
  import torch.nn as nn
3
3
  from contextlib import contextmanager
4
4
 
5
-
6
5
  # modified from transformers.modeling_utils
7
6
  TORCH_INIT_FUNCTIONS = {
8
7
  "uniform_": nn.init.uniform_,
@@ -167,6 +167,8 @@ class VAEDecoder(PreTrainedModel):
167
167
  self.conv_norm_out = nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6, device=device, dtype=dtype)
168
168
  self.conv_act = nn.SiLU()
169
169
  self.conv_out = nn.Conv2d(128, 3, kernel_size=3, padding=1, device=device, dtype=dtype)
170
+ self.device = device
171
+ self.dtype = dtype
170
172
 
171
173
  def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
172
174
  original_dtype = sample.dtype
@@ -277,6 +279,8 @@ class VAEEncoder(PreTrainedModel):
277
279
  self.conv_norm_out = nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6, device=device, dtype=dtype)
278
280
  self.conv_act = nn.SiLU()
279
281
  self.conv_out = nn.Conv2d(512, 2 * latent_channels, kernel_size=3, padding=1, device=device, dtype=dtype)
282
+ self.device = device
283
+ self.dtype = dtype
280
284
 
281
285
  def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
282
286
  original_dtype = sample.dtype
@@ -2,12 +2,11 @@ import math
2
2
  import json
3
3
  import torch
4
4
  import torch.nn as nn
5
- import torch.distributed as dist
6
5
  from typing import Tuple, Optional
7
6
  from einops import rearrange
8
7
 
9
8
  from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
10
- from diffsynth_engine.models.basic.attention import attention, long_context_attention
9
+ from diffsynth_engine.models.basic import attention as attention_ops
11
10
  from diffsynth_engine.models.basic.transformer_helper import RMSNorm
12
11
  from diffsynth_engine.models.utils import no_init_weights
13
12
  from diffsynth_engine.utils.constants import (
@@ -17,11 +16,7 @@ from diffsynth_engine.utils.constants import (
17
16
  WAN_DIT_14B_FLF2V_CONFIG_FILE,
18
17
  )
19
18
  from diffsynth_engine.utils.gguf import gguf_inference
20
- from diffsynth_engine.utils.parallel import (
21
- get_sp_group,
22
- get_sp_world_size,
23
- get_sp_rank,
24
- )
19
+ from diffsynth_engine.utils.parallel import sequence_parallel, sequence_parallel_unshard
25
20
 
26
21
  T5_TOKEN_NUM = 512
27
22
  FLF_TOKEN_NUM = 257 * 2
@@ -90,20 +85,12 @@ class SelfAttention(nn.Module):
90
85
  q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
91
86
  k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
92
87
  v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
93
- if getattr(self, "use_usp", False):
94
- x = long_context_attention(
95
- q=rope_apply(q, freqs),
96
- k=rope_apply(k, freqs),
97
- v=v,
98
- attn_impl=self.attn_impl,
99
- )
100
- else:
101
- x = attention(
102
- q=rope_apply(q, freqs),
103
- k=rope_apply(k, freqs),
104
- v=v,
105
- attn_impl=self.attn_impl,
106
- )
88
+ x = attention_ops.attention(
89
+ q=rope_apply(q, freqs),
90
+ k=rope_apply(k, freqs),
91
+ v=v,
92
+ attn_impl=self.attn_impl,
93
+ )
107
94
  x = x.flatten(2)
108
95
  return self.o(x)
109
96
 
@@ -148,12 +135,12 @@ class CrossAttention(nn.Module):
148
135
  k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
149
136
  v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
150
137
 
151
- x = attention(q, k, v, attn_impl=self.attn_impl).flatten(2)
138
+ x = attention_ops.attention(q, k, v, attn_impl=self.attn_impl).flatten(2)
152
139
  if self.has_image_input:
153
140
  k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
154
141
  k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
155
142
  v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
156
- y = attention(q, k_img, v_img, attn_impl=self.attn_impl).flatten(2)
143
+ y = attention_ops.attention(q, k_img, v_img, attn_impl=self.attn_impl).flatten(2)
157
144
  x = x + y
158
145
  return self.o(x)
159
146
 
@@ -316,10 +303,7 @@ class WanDiT(PreTrainedModel):
316
303
  if has_image_input:
317
304
  self.img_emb = MLP(1280, dim, flf_pos_emb, device=device, dtype=dtype) # clip_feature_dim = 1280
318
305
 
319
- if use_usp:
320
- setattr(self, "use_usp", True)
321
- for block in self.blocks:
322
- setattr(block.self_attn, "use_usp", True)
306
+ self.use_usp = use_usp
323
307
 
324
308
  def patchify(self, x: torch.Tensor):
325
309
  x = self.patch_embedding(x) # b c f h w -> b 4c f h/2 w/2
@@ -368,21 +352,12 @@ class WanDiT(PreTrainedModel):
368
352
  .reshape(f * h * w, 1, -1)
369
353
  .to(x.device)
370
354
  )
371
- if getattr(self, "use_usp", False):
372
- s, p = x.size(1), get_sp_world_size() # (sequence_length, parallelism)
373
- split_size = [s // p + 1 if i < s % p else s // p for i in range(p)]
374
- x = torch.split(x, split_size, dim=1)[get_sp_rank()]
375
- freqs = torch.split(freqs, split_size, dim=0)[get_sp_rank()]
376
-
377
- for block in self.blocks:
378
- x = block(x, context, t_mod, freqs)
379
- x = self.head(x, t)
380
-
381
- if getattr(self, "use_usp", False):
382
- b, d = x.size(0), x.size(2) # (batch_size, out_dim)
383
- xs = [torch.zeros((b, s, d), dtype=x.dtype, device=x.device) for s in split_size]
384
- dist.all_gather(xs, x, group=get_sp_group())
385
- x = torch.concat(xs, dim=1)
355
+
356
+ with sequence_parallel([x, freqs], seq_dims=(1, 0), enabled=self.use_usp):
357
+ for block in self.blocks:
358
+ x = block(x, context, t_mod, freqs)
359
+ x = self.head(x, t)
360
+ (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
386
361
  x = self.unpatchify(x, (f, h, w))
387
362
  return x
388
363
 
@@ -4,10 +4,11 @@ import numpy as np
4
4
  from typing import Dict, List, Tuple
5
5
  from PIL import Image
6
6
  from dataclasses import dataclass
7
- from diffsynth_engine.utils.loader import load_file
8
7
  from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
9
8
  from diffsynth_engine.utils.gguf import load_gguf_checkpoint
10
9
  from diffsynth_engine.utils import logging
10
+ from diffsynth_engine.utils.loader import load_file
11
+ from diffsynth_engine.utils.platform import empty_cache
11
12
 
12
13
  logger = logging.get_logger(__name__)
13
14
 
@@ -25,14 +26,21 @@ class LoRAStateDictConverter:
25
26
  class BasePipeline:
26
27
  lora_converter = LoRAStateDictConverter()
27
28
 
28
- def __init__(self, vae_tiled, vae_tile_size, vae_tile_stride, device="cuda:0", dtype=torch.float16):
29
+ def __init__(
30
+ self,
31
+ vae_tiled: bool = False,
32
+ vae_tile_size: int = -1,
33
+ vae_tile_stride: int = -1,
34
+ device="cuda:0",
35
+ dtype=torch.float16,
36
+ ):
29
37
  super().__init__()
30
- self.device = device
31
- self.dtype = dtype
32
- self.offload_mode = None
33
38
  self.vae_tiled = vae_tiled
34
39
  self.vae_tile_size = vae_tile_size
35
40
  self.vae_tile_stride = vae_tile_stride
41
+ self.device = device
42
+ self.dtype = dtype
43
+ self.offload_mode = None
36
44
  self.model_names = []
37
45
 
38
46
  @classmethod
@@ -144,6 +152,7 @@ class BasePipeline:
144
152
  return noise
145
153
 
146
154
  def encode_image(self, image: torch.Tensor) -> torch.Tensor:
155
+ image = image.to(self.device, self.vae_encoder.dtype)
147
156
  latents = self.vae_encoder(
148
157
  image, tiled=self.vae_tiled, tile_size=self.vae_tile_size, tile_stride=self.vae_tile_stride
149
158
  )
@@ -151,8 +160,9 @@ class BasePipeline:
151
160
 
152
161
  def decode_image(self, latent: torch.Tensor) -> torch.Tensor:
153
162
  vae_dtype = self.vae_decoder.conv_in.weight.dtype
163
+ latent = latent.to(self.device, vae_dtype)
154
164
  image = self.vae_decoder(
155
- latent.to(vae_dtype), tiled=self.vae_tiled, tile_size=self.vae_tile_size, tile_stride=self.vae_tile_stride
165
+ latent, tiled=self.vae_tiled, tile_size=self.vae_tile_size, tile_stride=self.vae_tile_stride
156
166
  )
157
167
  return image
158
168
 
@@ -196,8 +206,53 @@ class BasePipeline:
196
206
  model.eval()
197
207
  return self
198
208
 
199
- def enable_fp8_linear(self):
200
- raise NotImplementedError()
209
+ @staticmethod
210
+ def init_parallel_config(
211
+ parallelism: int,
212
+ use_cfg_parallel: bool,
213
+ model_config: ModelConfig,
214
+ ):
215
+ assert parallelism in (2, 4, 8), "parallelism must be 2, 4 or 8"
216
+ cfg_degree = 2 if use_cfg_parallel else 1
217
+ sp_ulysses_degree = getattr(model_config, "sp_ulysses_degree", None)
218
+ sp_ring_degree = getattr(model_config, "sp_ring_degree", None)
219
+ tp_degree = getattr(model_config, "tp_degree", None)
220
+ use_fsdp = getattr(model_config, "use_fsdp", False)
221
+
222
+ if tp_degree is not None:
223
+ assert sp_ulysses_degree is None and sp_ring_degree is None, (
224
+ "not allowed to enable sequence parallel and tensor parallel together; "
225
+ "either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
226
+ )
227
+ assert use_fsdp is False, (
228
+ "not allowed to enable fully sharded data parallel and tensor parallel together; "
229
+ "either set use_fsdp=False or set tp_degree=None during pipeline initialization"
230
+ )
231
+ assert parallelism == cfg_degree * tp_degree, (
232
+ f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * tp_degree ({tp_degree})"
233
+ )
234
+ sp_ulysses_degree = 1
235
+ sp_ring_degree = 1
236
+ elif sp_ulysses_degree is None and sp_ring_degree is None:
237
+ # use ulysses if not specified
238
+ sp_ulysses_degree = parallelism // cfg_degree
239
+ sp_ring_degree = 1
240
+ tp_degree = 1
241
+ elif sp_ulysses_degree is not None and sp_ring_degree is not None:
242
+ assert parallelism == cfg_degree * sp_ulysses_degree * sp_ring_degree, (
243
+ f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * "
244
+ f"sp_ulysses_degree ({sp_ulysses_degree}) * sp_ring_degree ({sp_ring_degree})"
245
+ )
246
+ tp_degree = 1
247
+ else:
248
+ raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
249
+ return {
250
+ "cfg_degree": cfg_degree,
251
+ "sp_ulysses_degree": sp_ulysses_degree,
252
+ "sp_ring_degree": sp_ring_degree,
253
+ "tp_degree": tp_degree,
254
+ "use_fsdp": use_fsdp,
255
+ }
201
256
 
202
257
  @staticmethod
203
258
  def validate_offload_mode(offload_mode: str | None):
@@ -233,7 +288,7 @@ class BasePipeline:
233
288
  return
234
289
  if self.offload_mode == "sequential_cpu_offload":
235
290
  # fresh the cuda cache
236
- torch.cuda.empty_cache()
291
+ empty_cache()
237
292
  return
238
293
 
239
294
  # offload unnecessary models to cpu
@@ -248,4 +303,4 @@ class BasePipeline:
248
303
  if model is not None and (p := next(model.parameters(), None)) is not None and p.device != self.device:
249
304
  model.to(self.device)
250
305
  # fresh the cuda cache
251
- torch.cuda.empty_cache()
306
+ empty_cache()