diffsynth-engine 0.3.6.dev12__tar.gz → 0.3.6.dev14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/PKG-INFO +1 -1
  2. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/README.md +6 -0
  3. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +2 -3
  4. diffsynth_engine-0.3.6.dev12/diffsynth_engine/conf/models/wan/dit/14b-i2v.json → diffsynth_engine-0.3.6.dev14/diffsynth_engine/conf/models/wan/dit/wan2.1-flf2v-14b.json +5 -2
  5. diffsynth_engine-0.3.6.dev12/diffsynth_engine/conf/models/wan/dit/14b-flf2v.json → diffsynth_engine-0.3.6.dev14/diffsynth_engine/conf/models/wan/dit/wan2.1-i2v-14b.json +2 -2
  6. diffsynth_engine-0.3.6.dev12/diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json → diffsynth_engine-0.3.6.dev14/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-1.3b.json +0 -1
  7. diffsynth_engine-0.3.6.dev12/diffsynth_engine/conf/models/wan/dit/14b-t2v.json → diffsynth_engine-0.3.6.dev14/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-14b.json +0 -1
  8. diffsynth_engine-0.3.6.dev14/diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json +16 -0
  9. diffsynth_engine-0.3.6.dev14/diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json +16 -0
  10. diffsynth_engine-0.3.6.dev14/diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json +14 -0
  11. diffsynth_engine-0.3.6.dev14/diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json +48 -0
  12. diffsynth_engine-0.3.6.dev14/diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json +112 -0
  13. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/configs/pipeline.py +6 -1
  14. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/basic/attention.py +53 -33
  15. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/wan/wan_dit.py +52 -32
  16. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/wan/wan_vae.py +355 -60
  17. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/pipelines/base.py +15 -11
  18. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/pipelines/wan_video.py +175 -74
  19. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/constants.py +10 -4
  20. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/parallel.py +3 -1
  21. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine.egg-info/PKG-INFO +1 -1
  22. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine.egg-info/SOURCES.txt +9 -4
  23. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/.gitignore +0 -0
  24. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/.pre-commit-config.yaml +0 -0
  25. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/LICENSE +0 -0
  26. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/MANIFEST.in +0 -0
  27. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/assets/dingtalk.png +0 -0
  28. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/assets/showcase.jpeg +0 -0
  29. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/__init__.py +0 -0
  30. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/__init__.py +0 -0
  31. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  32. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  33. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  34. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  35. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  36. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  37. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  38. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  39. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  40. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  41. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  42. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  43. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  44. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  45. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  46. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  47. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  48. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  49. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  50. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  51. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  52. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  53. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  54. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  55. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  56. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/models/components/vae.json +0 -0
  57. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  58. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  59. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  60. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  61. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  62. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  63. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  64. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  65. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  66. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  67. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  68. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  69. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  70. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  71. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  72. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  73. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  74. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  75. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  76. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  77. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  78. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  79. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  80. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  81. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  82. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  83. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  84. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  85. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  86. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/configs/__init__.py +0 -0
  87. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/configs/controlnet.py +0 -0
  88. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/kernels/__init__.py +0 -0
  89. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/__init__.py +0 -0
  90. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/base.py +0 -0
  91. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/basic/__init__.py +0 -0
  92. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/basic/lora.py +0 -0
  93. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  94. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/basic/timestep.py +0 -0
  95. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
  96. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  97. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/flux/__init__.py +0 -0
  98. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/flux/flux_controlnet.py +0 -0
  99. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/flux/flux_dit.py +0 -0
  100. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/flux/flux_dit_fbcache.py +0 -0
  101. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/flux/flux_ipadapter.py +0 -0
  102. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/flux/flux_redux.py +0 -0
  103. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  104. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  105. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sd/__init__.py +0 -0
  106. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sd/sd_controlnet.py +0 -0
  107. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  108. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  109. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  110. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sd3/__init__.py +0 -0
  111. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  112. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  113. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  114. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  115. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +0 -0
  116. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  117. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
  118. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  119. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
  120. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/text_encoder/clip.py +0 -0
  121. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
  122. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/text_encoder/t5.py +0 -0
  123. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/utils.py +0 -0
  124. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/vae/__init__.py +0 -0
  125. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/vae/vae.py +0 -0
  126. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/wan/__init__.py +0 -0
  127. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  128. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  129. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/pipelines/__init__.py +0 -0
  130. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/pipelines/flux_image.py +0 -0
  131. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/pipelines/sd_image.py +0 -0
  132. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/pipelines/sdxl_image.py +0 -0
  133. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/pipelines/utils.py +0 -0
  134. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/processor/__init__.py +0 -0
  135. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/processor/canny_processor.py +0 -0
  136. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/processor/depth_processor.py +0 -0
  137. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/tokenizers/__init__.py +0 -0
  138. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/tokenizers/base.py +0 -0
  139. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/tokenizers/clip.py +0 -0
  140. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/tokenizers/t5.py +0 -0
  141. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/tokenizers/wan.py +0 -0
  142. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/tools/__init__.py +0 -0
  143. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/tools/flux_inpainting_tool.py +0 -0
  144. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/tools/flux_outpainting_tool.py +0 -0
  145. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/tools/flux_reference_tool.py +0 -0
  146. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/tools/flux_replace_tool.py +0 -0
  147. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/__init__.py +0 -0
  148. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/download.py +0 -0
  149. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/env.py +0 -0
  150. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/flag.py +0 -0
  151. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/fp8_linear.py +0 -0
  152. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/gguf.py +0 -0
  153. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/image.py +0 -0
  154. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/loader.py +0 -0
  155. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/lock.py +0 -0
  156. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/logging.py +0 -0
  157. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/offload.py +0 -0
  158. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/onnx.py +0 -0
  159. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/platform.py +0 -0
  160. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/prompt.py +0 -0
  161. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine/utils/video.py +0 -0
  162. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  163. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine.egg-info/requires.txt +0 -0
  164. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/diffsynth_engine.egg-info/top_level.txt +0 -0
  165. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/docs/tutorial.md +0 -0
  166. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/docs/tutorial_zh.md +0 -0
  167. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/pyproject.toml +0 -0
  168. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/setup.cfg +0 -0
  169. {diffsynth_engine-0.3.6.dev12 → diffsynth_engine-0.3.6.dev14}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.3.6.dev12
3
+ Version: 0.3.6.dev14
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -21,6 +21,12 @@ and offloading strategies, enabling loading of larger diffusion models (e.g., Fl
21
21
 
22
22
  - **Cross-Platform Support:** Runnable on Windows, macOS (Apple Silicon), and Linux, ensuring a smooth experience across different operating systems.
23
23
 
24
+ ## News
25
+
26
+ - **[v0.4.0](https://github.com/modelscope/DiffSynth-Engine/releases/tag/v0.4.0)** | **August 1, 2025**:
27
+ - 🔥Supports [Wan2.2](https://modelscope.cn/collections/tongyiwanxiang-22--shipinshengcheng-2bb5b1adef2840) video generation model
28
+ - ⚠️[**Breaking Change**] Improved `from_pretrained` method pipeline initialization
29
+
24
30
  ## Quick Start
25
31
  ### Requirements
26
32
 
@@ -9,13 +9,12 @@ class FlowMatchEulerSampler:
9
9
  self.mask = mask
10
10
 
11
11
  def step(self, latents, model_outputs, i):
12
- if self.mask is not None:
13
- model_outputs = model_outputs * self.mask + self.init_latents * (1 - self.mask)
14
-
15
12
  dt = self.sigmas[i + 1] - self.sigmas[i]
16
13
  latents = latents.to(dtype=torch.float32)
17
14
  latents = latents + model_outputs * dt
18
15
  latents = latents.to(dtype=model_outputs.dtype)
16
+ if self.mask is not None:
17
+ latents = latents * self.mask + self.init_latents * (1 - self.mask)
19
18
  return latents
20
19
 
21
20
  def add_noise(self, latents, noise, sigma):
@@ -1,5 +1,7 @@
1
1
  {
2
- "has_image_input": true,
2
+ "has_clip_feature": true,
3
+ "has_vae_feature": true,
4
+ "flf_pos_emb": true,
3
5
  "patch_size": [1, 2, 2],
4
6
  "in_dim": 36,
5
7
  "dim": 5120,
@@ -9,5 +11,6 @@
9
11
  "out_dim": 16,
10
12
  "num_heads": 40,
11
13
  "num_layers": 40,
12
- "eps": 1e-6
14
+ "eps": 1e-6,
15
+ "shift": 16.0
13
16
  }
@@ -1,6 +1,6 @@
1
1
  {
2
- "has_image_input": true,
3
- "flf_pos_emb": true,
2
+ "has_clip_feature": true,
3
+ "has_vae_feature": true,
4
4
  "patch_size": [1, 2, 2],
5
5
  "in_dim": 36,
6
6
  "dim": 5120,
@@ -0,0 +1,16 @@
1
+ {
2
+ "has_vae_feature": true,
3
+ "patch_size": [1, 2, 2],
4
+ "in_dim": 36,
5
+ "dim": 5120,
6
+ "ffn_dim": 13824,
7
+ "freq_dim": 256,
8
+ "text_dim": 4096,
9
+ "out_dim": 16,
10
+ "num_heads": 40,
11
+ "num_layers": 40,
12
+ "eps": 1e-6,
13
+ "boundary": 0.900,
14
+ "cfg_scale": [3.5, 3.5],
15
+ "num_inference_steps": 40
16
+ }
@@ -0,0 +1,16 @@
1
+ {
2
+ "patch_size": [1, 2, 2],
3
+ "in_dim": 16,
4
+ "dim": 5120,
5
+ "ffn_dim": 13824,
6
+ "freq_dim": 256,
7
+ "text_dim": 4096,
8
+ "out_dim": 16,
9
+ "num_heads": 40,
10
+ "num_layers": 40,
11
+ "eps": 1e-6,
12
+ "boundary": 0.875,
13
+ "shift": 12.0,
14
+ "cfg_scale": [3.0, 4.0],
15
+ "num_inference_steps": 40
16
+ }
@@ -0,0 +1,14 @@
1
+ {
2
+ "fuse_image_latents": true,
3
+ "patch_size": [1, 2, 2],
4
+ "in_dim": 48,
5
+ "dim": 3072,
6
+ "ffn_dim": 14336,
7
+ "freq_dim": 256,
8
+ "text_dim": 4096,
9
+ "out_dim": 48,
10
+ "num_heads": 24,
11
+ "num_layers": 30,
12
+ "eps": 1e-6,
13
+ "fps": 24
14
+ }
@@ -0,0 +1,48 @@
1
+ {
2
+ "in_channels": 3,
3
+ "out_channels": 3,
4
+ "encoder_dim": 96,
5
+ "decoder_dim": 96,
6
+ "z_dim": 16,
7
+ "dim_mult": [1, 2, 4, 4],
8
+ "num_res_blocks": 2,
9
+ "temperal_downsample": [false, true, true],
10
+ "dropout": 0.0,
11
+ "patch_size": 1,
12
+ "mean": [
13
+ -0.7571,
14
+ -0.7089,
15
+ -0.9113,
16
+ 0.1075,
17
+ -0.1745,
18
+ 0.9653,
19
+ -0.1517,
20
+ 1.5508,
21
+ 0.4134,
22
+ -0.0715,
23
+ 0.5517,
24
+ -0.3632,
25
+ -0.1922,
26
+ -0.9497,
27
+ 0.2503,
28
+ -0.2921
29
+ ],
30
+ "std": [
31
+ 2.8184,
32
+ 1.4541,
33
+ 2.3275,
34
+ 2.6558,
35
+ 1.2196,
36
+ 1.7708,
37
+ 2.6052,
38
+ 2.0743,
39
+ 3.2687,
40
+ 2.1526,
41
+ 2.8652,
42
+ 1.5579,
43
+ 1.6382,
44
+ 1.1253,
45
+ 2.8251,
46
+ 1.9160
47
+ ]
48
+ }
@@ -0,0 +1,112 @@
1
+ {
2
+ "in_channels": 12,
3
+ "out_channels": 12,
4
+ "encoder_dim": 160,
5
+ "decoder_dim": 256,
6
+ "z_dim": 48,
7
+ "dim_mult": [1, 2, 4, 4],
8
+ "num_res_blocks": 2,
9
+ "temperal_downsample": [false, true, true],
10
+ "dropout": 0.0,
11
+ "patch_size": 2,
12
+ "mean": [
13
+ -0.2289,
14
+ -0.0052,
15
+ -0.1323,
16
+ -0.2339,
17
+ -0.2799,
18
+ 0.0174,
19
+ 0.1838,
20
+ 0.1557,
21
+ -0.1382,
22
+ 0.0542,
23
+ 0.2813,
24
+ 0.0891,
25
+ 0.1570,
26
+ -0.0098,
27
+ 0.0375,
28
+ -0.1825,
29
+ -0.2246,
30
+ -0.1207,
31
+ -0.0698,
32
+ 0.5109,
33
+ 0.2665,
34
+ -0.2108,
35
+ -0.2158,
36
+ 0.2502,
37
+ -0.2055,
38
+ -0.0322,
39
+ 0.1109,
40
+ 0.1567,
41
+ -0.0729,
42
+ 0.0899,
43
+ -0.2799,
44
+ -0.1230,
45
+ -0.0313,
46
+ -0.1649,
47
+ 0.0117,
48
+ 0.0723,
49
+ -0.2839,
50
+ -0.2083,
51
+ -0.0520,
52
+ 0.3748,
53
+ 0.0152,
54
+ 0.1957,
55
+ 0.1433,
56
+ -0.2944,
57
+ 0.3573,
58
+ -0.0548,
59
+ -0.1681,
60
+ -0.0667
61
+ ],
62
+ "std": [
63
+ 0.4765,
64
+ 1.0364,
65
+ 0.4514,
66
+ 1.1677,
67
+ 0.5313,
68
+ 0.4990,
69
+ 0.4818,
70
+ 0.5013,
71
+ 0.8158,
72
+ 1.0344,
73
+ 0.5894,
74
+ 1.0901,
75
+ 0.6885,
76
+ 0.6165,
77
+ 0.8454,
78
+ 0.4978,
79
+ 0.5759,
80
+ 0.3523,
81
+ 0.7135,
82
+ 0.6804,
83
+ 0.5833,
84
+ 1.4146,
85
+ 0.8986,
86
+ 0.5659,
87
+ 0.7069,
88
+ 0.5338,
89
+ 0.4889,
90
+ 0.4917,
91
+ 0.4069,
92
+ 0.4999,
93
+ 0.6866,
94
+ 0.4093,
95
+ 0.5709,
96
+ 0.6065,
97
+ 0.6415,
98
+ 0.4944,
99
+ 0.5726,
100
+ 1.2042,
101
+ 0.5458,
102
+ 1.6887,
103
+ 0.3971,
104
+ 1.0600,
105
+ 0.3943,
106
+ 0.5537,
107
+ 0.5444,
108
+ 0.4089,
109
+ 0.7468,
110
+ 0.7744
111
+ ]
112
+ }
@@ -139,7 +139,12 @@ class WanPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, Bas
139
139
  vae_dtype: torch.dtype = torch.bfloat16
140
140
  image_encoder_dtype: torch.dtype = torch.bfloat16
141
141
 
142
- shift: Optional[float] = field(default=None, init=False) # RecifitedFlowScheduler shift factor, set by model type
142
+ # default params set by model type
143
+ boundary: Optional[float] = field(default=None, init=False) # boundary
144
+ shift: Optional[float] = field(default=None, init=False) # RecifitedFlowScheduler shift factor
145
+ cfg_scale: Optional[float | Tuple[float, float]] = field(default=None, init=False) # default CFG scale
146
+ num_inference_steps: Optional[int] = field(default=None, init=False) # default inference steps
147
+ fps: Optional[int] = field(default=None, init=False) # default FPS
143
148
 
144
149
  # override BaseConfig
145
150
  vae_tiled: bool = True
@@ -14,6 +14,8 @@ from diffsynth_engine.utils.flag import (
14
14
  SPARGE_ATTN_AVAILABLE,
15
15
  )
16
16
 
17
+ FA3_MAX_HEADDIM = 256
18
+
17
19
  logger = logging.get_logger(__name__)
18
20
 
19
21
 
@@ -130,31 +132,40 @@ def attention(
130
132
  "sage_attn",
131
133
  "sparge_attn",
132
134
  ]
135
+ flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
133
136
  if attn_impl is None or attn_impl == "auto":
134
137
  if FLASH_ATTN_3_AVAILABLE:
135
- return flash_attn3(q, k, v, softmax_scale=scale)
136
- elif XFORMERS_AVAILABLE:
138
+ if flash_attn3_compatible:
139
+ return flash_attn3(q, k, v, softmax_scale=scale)
140
+ else:
141
+ logger.warning(
142
+ f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
143
+ )
144
+ if XFORMERS_AVAILABLE:
137
145
  return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
138
- elif SDPA_AVAILABLE:
146
+ if SDPA_AVAILABLE:
139
147
  return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
140
- elif FLASH_ATTN_2_AVAILABLE:
148
+ if FLASH_ATTN_2_AVAILABLE:
141
149
  return flash_attn2(q, k, v, softmax_scale=scale)
142
- else:
143
- return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
150
+ return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
144
151
  else:
145
152
  if attn_impl == "eager":
146
153
  return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
147
- elif attn_impl == "flash_attn_3":
154
+ if attn_impl == "flash_attn_3":
155
+ if not flash_attn3_compatible:
156
+ raise RuntimeError(
157
+ f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
158
+ )
148
159
  return flash_attn3(q, k, v, softmax_scale=scale)
149
- elif attn_impl == "flash_attn_2":
160
+ if attn_impl == "flash_attn_2":
150
161
  return flash_attn2(q, k, v, softmax_scale=scale)
151
- elif attn_impl == "xformers":
162
+ if attn_impl == "xformers":
152
163
  return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
153
- elif attn_impl == "sdpa":
164
+ if attn_impl == "sdpa":
154
165
  return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
155
- elif attn_impl == "sage_attn":
166
+ if attn_impl == "sage_attn":
156
167
  return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
157
- elif attn_impl == "sparge_attn":
168
+ if attn_impl == "sparge_attn":
158
169
  return sparge_attn(
159
170
  q,
160
171
  k,
@@ -166,8 +177,7 @@ def attention(
166
177
  cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
167
178
  pvthreshd=kwargs.get("sparge_pvthreshd", 50),
168
179
  )
169
- else:
170
- raise ValueError(f"Invalid attention implementation: {attn_impl}")
180
+ raise ValueError(f"Invalid attention implementation: {attn_impl}")
171
181
 
172
182
 
173
183
  class Attention(nn.Module):
@@ -240,32 +250,42 @@ def long_context_attention(
240
250
  "sage_attn",
241
251
  "sparge_attn",
242
252
  ]
253
+ flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
243
254
  if attn_impl is None or attn_impl == "auto":
244
255
  if FLASH_ATTN_3_AVAILABLE:
245
- attn_func = LongContextAttention(attn_type=AttnType.FA3)
246
- elif SDPA_AVAILABLE:
247
- attn_func = LongContextAttention(attn_type=AttnType.TORCH)
248
- elif FLASH_ATTN_2_AVAILABLE:
249
- attn_func = LongContextAttention(attn_type=AttnType.FA)
250
- else:
251
- raise ValueError("No available long context attention implementation")
256
+ if flash_attn3_compatible:
257
+ return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
258
+ else:
259
+ logger.warning(
260
+ f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
261
+ )
262
+ if SDPA_AVAILABLE:
263
+ return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
264
+ if FLASH_ATTN_2_AVAILABLE:
265
+ return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
266
+ raise ValueError("No available long context attention implementation")
252
267
  else:
253
268
  if attn_impl == "flash_attn_3":
254
- attn_func = LongContextAttention(attn_type=AttnType.FA3)
255
- elif attn_impl == "flash_attn_2":
256
- attn_func = LongContextAttention(attn_type=AttnType.FA)
257
- elif attn_impl == "sdpa":
258
- attn_func = LongContextAttention(attn_type=AttnType.TORCH)
259
- elif attn_impl == "sage_attn":
260
- attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
261
- elif attn_impl == "sparge_attn":
269
+ if flash_attn3_compatible:
270
+ return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
271
+ else:
272
+ raise RuntimeError(
273
+ f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
274
+ )
275
+ if attn_impl == "flash_attn_2":
276
+ return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
277
+ if attn_impl == "sdpa":
278
+ return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
279
+ if attn_impl == "sage_attn":
280
+ return LongContextAttention(attn_type=AttnType.SAGE_FP8)(q, k, v, softmax_scale=scale)
281
+ if attn_impl == "sparge_attn":
262
282
  attn_processor = SparseAttentionMeansim()
263
283
  # default args from spas_sage2_attn_meansim_cuda
264
284
  attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
265
285
  attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
266
286
  attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
267
287
  attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
268
- attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)
269
- else:
270
- raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
271
- return attn_func(q, k, v, softmax_scale=scale)
288
+ return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
289
+ q, k, v, softmax_scale=scale
290
+ )
291
+ raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
@@ -10,10 +10,13 @@ from diffsynth_engine.models.basic import attention as attention_ops
10
10
  from diffsynth_engine.models.basic.transformer_helper import RMSNorm
11
11
  from diffsynth_engine.models.utils import no_init_weights
12
12
  from diffsynth_engine.utils.constants import (
13
- WAN_DIT_1_3B_T2V_CONFIG_FILE,
14
- WAN_DIT_14B_I2V_CONFIG_FILE,
15
- WAN_DIT_14B_T2V_CONFIG_FILE,
16
- WAN_DIT_14B_FLF2V_CONFIG_FILE,
13
+ WAN2_1_DIT_T2V_1_3B_CONFIG_FILE,
14
+ WAN2_1_DIT_I2V_14B_CONFIG_FILE,
15
+ WAN2_1_DIT_T2V_14B_CONFIG_FILE,
16
+ WAN2_1_DIT_FLF2V_14B_CONFIG_FILE,
17
+ WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
18
+ WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
19
+ WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
17
20
  )
18
21
  from diffsynth_engine.utils.gguf import gguf_inference
19
22
  from diffsynth_engine.utils.parallel import (
@@ -182,7 +185,9 @@ class DiTBlock(nn.Module):
182
185
 
183
186
  def forward(self, x, context, t_mod, freqs):
184
187
  # msa: multi-head self-attention mlp: multi-layer perceptron
185
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + t_mod).chunk(6, dim=1)
188
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
189
+ t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1)
190
+ ]
186
191
  input_x = modulate(self.norm1(x), shift_msa, scale_msa)
187
192
  x = x + gate_msa * self.self_attn(input_x, freqs)
188
193
  x = x + self.cross_attn(self.norm3(x), context)
@@ -237,7 +242,7 @@ class Head(nn.Module):
237
242
  self.modulation = nn.Parameter(torch.randn(1, 2, dim, device=device, dtype=dtype) / dim**0.5)
238
243
 
239
244
  def forward(self, x, t_mod):
240
- shift, scale = (self.modulation + t_mod).chunk(2, dim=1)
245
+ shift, scale = [t.squeeze(1) for t in (self.modulation + t_mod.unsqueeze(1)).chunk(2, dim=1)]
241
246
  x = self.head(self.norm(x) * (1 + scale) + shift)
242
247
  return x
243
248
 
@@ -263,17 +268,22 @@ class WanDiT(PreTrainedModel):
263
268
  patch_size: Tuple[int, int, int],
264
269
  num_heads: int,
265
270
  num_layers: int,
266
- has_image_input: bool,
271
+ has_clip_feature: bool = False,
272
+ has_vae_feature: bool = False,
273
+ fuse_image_latents: bool = False,
267
274
  flf_pos_emb: bool = False,
268
275
  attn_kwargs: Optional[Dict[str, Any]] = None,
269
- device: str = "cpu",
276
+ device: str = "cuda:0",
270
277
  dtype: torch.dtype = torch.bfloat16,
271
278
  ):
272
279
  super().__init__()
273
280
 
281
+ self.in_dim = in_dim
274
282
  self.dim = dim
275
283
  self.freq_dim = freq_dim
276
- self.has_image_input = has_image_input
284
+ self.has_clip_feature = has_clip_feature
285
+ self.has_vae_feature = has_vae_feature
286
+ self.fuse_image_latents = fuse_image_latents
277
287
  self.patch_size = patch_size
278
288
 
279
289
  self.patch_embedding = nn.Conv3d(
@@ -296,7 +306,7 @@ class WanDiT(PreTrainedModel):
296
306
  )
297
307
  self.blocks = nn.ModuleList(
298
308
  [
299
- DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
309
+ DiTBlock(has_clip_feature, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
300
310
  for _ in range(num_layers)
301
311
  ]
302
312
  )
@@ -305,7 +315,7 @@ class WanDiT(PreTrainedModel):
305
315
  head_dim = dim // num_heads
306
316
  self.freqs = precompute_freqs_cis_3d(head_dim)
307
317
 
308
- if has_image_input:
318
+ if has_clip_feature:
309
319
  self.img_emb = MLP(1280, dim, flf_pos_emb, device=device, dtype=dtype) # clip_feature_dim = 1280
310
320
 
311
321
  def patchify(self, x: torch.Tensor):
@@ -339,13 +349,14 @@ class WanDiT(PreTrainedModel):
339
349
  gguf_inference(),
340
350
  cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
341
351
  ):
342
- t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
343
- t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
352
+ t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) # (s, d)
353
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) # (s, 6, d)
344
354
  context = self.text_embedding(context)
345
- if self.has_image_input:
355
+ if self.has_vae_feature:
346
356
  x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
347
- clip_embdding = self.img_emb(clip_feature)
348
- context = torch.cat([clip_embdding, context], dim=1) # (b, s1 + s2, d)
357
+ if self.has_clip_feature:
358
+ clip_embedding = self.img_emb(clip_feature)
359
+ context = torch.cat([clip_embedding, context], dim=1) # (b, s1 + s2, d)
349
360
  x, (f, h, w) = self.patchify(x)
350
361
  freqs = (
351
362
  torch.cat(
@@ -360,7 +371,7 @@ class WanDiT(PreTrainedModel):
360
371
  .to(x.device)
361
372
  )
362
373
 
363
- with sequence_parallel((x, freqs), seq_dims=(1, 0)):
374
+ with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)):
364
375
  for block in self.blocks:
365
376
  x = block(x, context, t_mod, freqs)
366
377
  x = self.head(x, t)
@@ -369,26 +380,35 @@ class WanDiT(PreTrainedModel):
369
380
  (x,) = cfg_parallel_unshard((x,), use_cfg=use_cfg)
370
381
  return x
371
382
 
383
+ @staticmethod
384
+ def get_model_config(model_type: str):
385
+ MODEL_CONFIG_FILES = {
386
+ "wan2.1-t2v-1.3b": WAN2_1_DIT_T2V_1_3B_CONFIG_FILE,
387
+ "wan2.1-t2v-14b": WAN2_1_DIT_T2V_14B_CONFIG_FILE,
388
+ "wan2.1-i2v-14b": WAN2_1_DIT_I2V_14B_CONFIG_FILE,
389
+ "wan2.1-flf2v-14b": WAN2_1_DIT_FLF2V_14B_CONFIG_FILE,
390
+ "wan2.2-ti2v-5b": WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
391
+ "wan2.2-t2v-a14b": WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
392
+ "wan2.2-i2v-a14b": WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
393
+ }
394
+ if model_type not in MODEL_CONFIG_FILES:
395
+ raise ValueError(f"Unsupported model type: {model_type}")
396
+
397
+ config_file = MODEL_CONFIG_FILES[model_type]
398
+ with open(config_file, "r") as f:
399
+ config = json.load(f)
400
+ return config
401
+
372
402
  @classmethod
373
403
  def from_state_dict(
374
404
  cls,
375
- state_dict,
376
- device,
377
- dtype,
378
- model_type="1.3b-t2v",
405
+ state_dict: Dict[str, torch.Tensor],
406
+ config: Dict[str, Any],
407
+ device: str = "cuda:0",
408
+ dtype: torch.dtype = torch.bfloat16,
379
409
  attn_kwargs: Optional[Dict[str, Any]] = None,
380
- assign=True,
410
+ assign: bool = True,
381
411
  ):
382
- if model_type == "1.3b-t2v":
383
- config = json.load(open(WAN_DIT_1_3B_T2V_CONFIG_FILE, "r"))
384
- elif model_type == "14b-t2v":
385
- config = json.load(open(WAN_DIT_14B_T2V_CONFIG_FILE, "r"))
386
- elif model_type == "14b-i2v":
387
- config = json.load(open(WAN_DIT_14B_I2V_CONFIG_FILE, "r"))
388
- elif model_type == "14b-flf2v":
389
- config = json.load(open(WAN_DIT_14B_FLF2V_CONFIG_FILE, "r"))
390
- else:
391
- raise ValueError(f"Unsupported model type: {model_type}")
392
412
  with no_init_weights():
393
413
  model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_kwargs=attn_kwargs)
394
414
  model = model.requires_grad_(False)