diffsynth-engine 0.6.1.dev9__tar.gz → 0.6.1.dev11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208) hide show
  1. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/PKG-INFO +1 -1
  2. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/__init__.py +6 -0
  3. diffsynth_engine-0.6.1.dev11/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +29 -0
  4. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
  5. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/configs/__init__.py +11 -2
  6. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/configs/controlnet.py +13 -0
  7. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/configs/pipeline.py +6 -0
  8. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
  9. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/qwen_image/qwen_image_dit.py +112 -22
  10. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/wan/wan_audio_encoder.py +6 -2
  11. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/pipelines/base.py +26 -6
  12. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/pipelines/hunyuan3d_shape.py +1 -1
  13. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/pipelines/qwen_image.py +181 -50
  14. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/download.py +1 -5
  15. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/loader.py +25 -6
  16. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/video.py +3 -1
  17. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine.egg-info/PKG-INFO +1 -1
  18. diffsynth_engine-0.6.1.dev9/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -10
  19. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/.gitattributes +0 -0
  20. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/.gitignore +0 -0
  21. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/.pre-commit-config.yaml +0 -0
  22. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/LICENSE +0 -0
  23. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/MANIFEST.in +0 -0
  24. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/README.md +0 -0
  25. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/assets/dingtalk.png +0 -0
  26. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/assets/showcase.jpeg +0 -0
  27. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/assets/tongyi.svg +0 -0
  28. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/__init__.py +0 -0
  29. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  30. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  31. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  32. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  33. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  34. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  35. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  36. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  37. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  38. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  39. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  40. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  41. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  42. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  43. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  44. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  45. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  46. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  47. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  48. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  49. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  50. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  51. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  52. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  53. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/components/vae.json +0 -0
  54. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  55. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  56. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  57. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json +0 -0
  58. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_vision_config.json +0 -0
  59. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae.json +0 -0
  60. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae_keymap.json +0 -0
  61. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  62. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  63. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  64. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  65. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  66. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  67. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/wan/dit/wan2.1-flf2v-14b.json +0 -0
  68. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/wan/dit/wan2.1-i2v-14b.json +0 -0
  69. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-1.3b.json +0 -0
  70. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-14b.json +0 -0
  71. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json +0 -0
  72. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/wan/dit/wan2.2-s2v-14b.json +0 -0
  73. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json +0 -0
  74. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json +0 -0
  75. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/wan/vae/wan-vae-keymap.json +0 -0
  76. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json +0 -0
  77. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json +0 -0
  78. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  79. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  80. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  81. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  82. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  83. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  84. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  85. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  86. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json +0 -0
  87. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/added_tokens.json +0 -0
  88. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/merges.txt +0 -0
  89. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/special_tokens_map.json +0 -0
  90. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer.json +0 -0
  91. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer_config.json +0 -0
  92. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/vocab.json +0 -0
  93. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  94. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  95. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  96. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  97. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  98. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  99. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  100. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  101. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  102. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  103. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  104. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  105. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/kernels/__init__.py +0 -0
  106. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/__init__.py +0 -0
  107. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/base.py +0 -0
  108. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/basic/__init__.py +0 -0
  109. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/basic/attention.py +0 -0
  110. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/basic/lora.py +0 -0
  111. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  112. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/basic/timestep.py +0 -0
  113. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
  114. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  115. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/flux/__init__.py +0 -0
  116. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/flux/flux_controlnet.py +0 -0
  117. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/flux/flux_dit.py +0 -0
  118. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/flux/flux_dit_fbcache.py +0 -0
  119. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/flux/flux_ipadapter.py +0 -0
  120. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/flux/flux_redux.py +0 -0
  121. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  122. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  123. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/hunyuan3d/__init__.py +0 -0
  124. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/hunyuan3d/hunyuan3d_dit.py +0 -0
  125. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/hunyuan3d/hunyuan3d_vae.py +0 -0
  126. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/hunyuan3d/moe.py +0 -0
  127. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/hunyuan3d/surface_extractor.py +0 -0
  128. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/hunyuan3d/volume_decoder.py +0 -0
  129. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/qwen_image/__init__.py +0 -0
  130. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/qwen_image/qwen2_5_vl.py +0 -0
  131. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +0 -0
  132. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/qwen_image/qwen_image_vae.py +0 -0
  133. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sd/__init__.py +0 -0
  134. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sd/sd_controlnet.py +0 -0
  135. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  136. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  137. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  138. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sd3/__init__.py +0 -0
  139. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  140. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  141. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  142. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  143. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +0 -0
  144. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  145. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
  146. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  147. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
  148. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/text_encoder/clip.py +0 -0
  149. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
  150. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/text_encoder/t5.py +0 -0
  151. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/vae/__init__.py +0 -0
  152. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/vae/vae.py +0 -0
  153. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/wan/__init__.py +0 -0
  154. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/wan/wan_dit.py +0 -0
  155. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  156. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/wan/wan_s2v_dit.py +0 -0
  157. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  158. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  159. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/pipelines/__init__.py +0 -0
  160. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/pipelines/flux_image.py +0 -0
  161. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/pipelines/sd_image.py +0 -0
  162. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/pipelines/sdxl_image.py +0 -0
  163. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/pipelines/utils.py +0 -0
  164. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/pipelines/wan_s2v.py +0 -0
  165. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/pipelines/wan_video.py +0 -0
  166. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/processor/__init__.py +0 -0
  167. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/processor/canny_processor.py +0 -0
  168. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/processor/depth_processor.py +0 -0
  169. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tokenizers/__init__.py +0 -0
  170. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tokenizers/base.py +0 -0
  171. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tokenizers/clip.py +0 -0
  172. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tokenizers/qwen2.py +0 -0
  173. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tokenizers/qwen2_vl_image_processor.py +0 -0
  174. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tokenizers/qwen2_vl_processor.py +0 -0
  175. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tokenizers/t5.py +0 -0
  176. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tokenizers/wan.py +0 -0
  177. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tools/__init__.py +0 -0
  178. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tools/flux_inpainting_tool.py +0 -0
  179. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tools/flux_outpainting_tool.py +0 -0
  180. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tools/flux_reference_tool.py +0 -0
  181. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/tools/flux_replace_tool.py +0 -0
  182. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/__init__.py +0 -0
  183. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/cache.py +0 -0
  184. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/constants.py +0 -0
  185. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/env.py +0 -0
  186. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/flag.py +0 -0
  187. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/fp8_linear.py +0 -0
  188. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/gguf.py +0 -0
  189. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/image.py +0 -0
  190. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/lock.py +0 -0
  191. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/logging.py +0 -0
  192. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/memory/__init__.py +0 -0
  193. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/memory/linear_regression.py +0 -0
  194. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/memory/memory_predcit_model.py +0 -0
  195. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/offload.py +0 -0
  196. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/onnx.py +0 -0
  197. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/parallel.py +0 -0
  198. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/platform.py +0 -0
  199. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine/utils/prompt.py +0 -0
  200. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine.egg-info/SOURCES.txt +0 -0
  201. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  202. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine.egg-info/requires.txt +0 -0
  203. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/diffsynth_engine.egg-info/top_level.txt +0 -0
  204. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/docs/tutorial.md +0 -0
  205. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/docs/tutorial_zh.md +0 -0
  206. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/pyproject.toml +0 -0
  207. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/setup.cfg +0 -0
  208. {diffsynth_engine-0.6.1.dev9 → diffsynth_engine-0.6.1.dev11}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.6.1.dev9
3
+ Version: 0.6.1.dev11
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -14,6 +14,9 @@ from .configs import (
14
14
  AttnImpl,
15
15
  ControlNetParams,
16
16
  ControlType,
17
+ QwenImageControlNetParams,
18
+ QwenImageControlType,
19
+ LoraConfig,
17
20
  )
18
21
  from .pipelines import (
19
22
  SDImagePipeline,
@@ -58,6 +61,8 @@ __all__ = [
58
61
  "AttnImpl",
59
62
  "ControlNetParams",
60
63
  "ControlType",
64
+ "QwenImageControlNetParams",
65
+ "QwenImageControlType",
61
66
  "SDImagePipeline",
62
67
  "SDControlNet",
63
68
  "SDXLImagePipeline",
@@ -74,6 +79,7 @@ __all__ = [
74
79
  "FluxIPAdapterRefTool",
75
80
  "FluxReplaceByControlTool",
76
81
  "FluxReduxRefTool",
82
+ "LoraConfig",
77
83
  "fetch_model",
78
84
  "fetch_modelscope_model",
79
85
  "register_fetch_modelscope_model",
@@ -0,0 +1,29 @@
1
+ import torch
2
+
3
+
4
+ def append_zero(x):
5
+ return torch.cat([x, x.new_zeros([1])])
6
+
7
+
8
+ class BaseScheduler:
9
+ def __init__(self):
10
+ self._stored_config = {}
11
+
12
+ def store_config(self):
13
+ self._stored_config = {
14
+ config_name: config_value
15
+ for config_name, config_value in vars(self).items()
16
+ if not config_name.startswith("_")
17
+ }
18
+
19
+ def update_config(self, config_dict):
20
+ for config_name, new_value in config_dict.items():
21
+ if hasattr(self, config_name):
22
+ setattr(self, config_name, new_value)
23
+
24
+ def restore_config(self):
25
+ for config_name, config_value in self._stored_config.items():
26
+ setattr(self, config_name, config_value)
27
+
28
+ def schedule(self, num_inference_steps: int):
29
+ raise NotImplementedError()
@@ -12,16 +12,23 @@ class RecifitedFlowScheduler(BaseScheduler):
12
12
  def __init__(
13
13
  self,
14
14
  shift=1.0,
15
- sigma_min=0.001,
16
- sigma_max=1.0,
15
+ sigma_min=None,
16
+ sigma_max=None,
17
17
  num_train_timesteps=1000,
18
18
  use_dynamic_shifting=False,
19
+ shift_terminal=None,
20
+ exponential_shift_mu=None,
19
21
  ):
22
+ super().__init__()
20
23
  self.shift = shift
21
24
  self.sigma_min = sigma_min
22
25
  self.sigma_max = sigma_max
23
26
  self.num_train_timesteps = num_train_timesteps
24
27
  self.use_dynamic_shifting = use_dynamic_shifting
28
+ self.shift_terminal = shift_terminal
29
+ # static mu for distill model
30
+ self.exponential_shift_mu = exponential_shift_mu
31
+ self.store_config()
25
32
 
26
33
  def _sigma_to_t(self, sigma):
27
34
  return sigma * self.num_train_timesteps
@@ -35,21 +42,30 @@ class RecifitedFlowScheduler(BaseScheduler):
35
42
  def _shift_sigma(self, sigma: torch.Tensor, shift: float):
36
43
  return shift * sigma / (1 + (shift - 1) * sigma)
37
44
 
45
+ def _stretch_shift_to_terminal(self, sigma: torch.Tensor):
46
+ one_minus_z = 1 - sigma
47
+ scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
48
+ return 1 - (one_minus_z / scale_factor)
49
+
38
50
  def schedule(
39
51
  self,
40
52
  num_inference_steps: int,
41
53
  mu: float | None = None,
42
- sigma_min: float | None = None,
43
- sigma_max: float | None = None,
54
+ sigma_min: float = 0.001,
55
+ sigma_max: float = 1.0,
44
56
  append_value: float = 0,
45
57
  ):
46
- sigma_min = self.sigma_min if sigma_min is None else sigma_min
47
- sigma_max = self.sigma_max if sigma_max is None else sigma_max
58
+ sigma_min = sigma_min if self.sigma_min is None else self.sigma_min
59
+ sigma_max = sigma_max if self.sigma_max is None else self.sigma_max
48
60
  sigmas = torch.linspace(sigma_max, sigma_min, num_inference_steps)
61
+ if self.exponential_shift_mu is not None:
62
+ mu = self.exponential_shift_mu
49
63
  if self.use_dynamic_shifting:
50
64
  sigmas = self._time_shift(mu, 1.0, sigmas) # FLUX
51
65
  else:
52
66
  sigmas = self._shift_sigma(sigmas, self.shift)
67
+ if self.shift_terminal is not None:
68
+ sigmas = self._stretch_shift_to_terminal(sigmas)
53
69
  timesteps = sigmas * self.num_train_timesteps
54
70
  sigmas = append(sigmas, append_value)
55
71
  return sigmas, timesteps
@@ -17,9 +17,15 @@ from .pipeline import (
17
17
  WanStateDicts,
18
18
  WanS2VStateDicts,
19
19
  QwenImageStateDicts,
20
+ LoraConfig,
20
21
  AttnImpl,
21
22
  )
22
- from .controlnet import ControlType, ControlNetParams
23
+ from .controlnet import (
24
+ ControlType,
25
+ ControlNetParams,
26
+ QwenImageControlNetParams,
27
+ QwenImageControlType,
28
+ )
23
29
 
24
30
  __all__ = [
25
31
  "BaseConfig",
@@ -40,7 +46,10 @@ __all__ = [
40
46
  "WanStateDicts",
41
47
  "WanS2VStateDicts",
42
48
  "QwenImageStateDicts",
43
- "AttnImpl",
49
+ "QwenImageControlType",
50
+ "QwenImageControlNetParams",
44
51
  "ControlType",
45
52
  "ControlNetParams",
53
+ "LoraConfig",
54
+ "AttnImpl",
46
55
  ]
@@ -34,3 +34,16 @@ class ControlNetParams:
34
34
  control_start: float = 0
35
35
  control_end: float = 1
36
36
  processor_name: Optional[str] = None # only used for sdxl controlnet union now
37
+
38
+
39
+ class QwenImageControlType(Enum):
40
+ eligen = "eligen"
41
+ in_context = "in_context"
42
+
43
+
44
+ @dataclass
45
+ class QwenImageControlNetParams:
46
+ image: ImageType
47
+ model: str
48
+ control_type: QwenImageControlType
49
+ scale: float = 1.0
@@ -365,3 +365,9 @@ def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig |
365
365
  config.tp_degree = 1
366
366
  else:
367
367
  raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
368
+
369
+
370
+ @dataclass
371
+ class LoraConfig:
372
+ scale: float
373
+ scheduler_config: Optional[Dict] = None
@@ -2,7 +2,7 @@ import torch.nn as nn
2
2
  import torchvision.transforms as transforms
3
3
  import collections.abc
4
4
  import math
5
- from typing import Optional, Tuple, Dict
5
+ from typing import Optional, Dict
6
6
 
7
7
  import torch
8
8
  from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
@@ -112,7 +112,9 @@ class Dinov2SelfAttention(nn.Module):
112
112
  def __init__(self, hidden_size: int, num_attention_heads: int, qkv_bias: bool) -> None:
113
113
  super().__init__()
114
114
  if hidden_size % num_attention_heads != 0:
115
- raise ValueError(f"hidden_size {hidden_size} is not a multiple of num_attention_heads {num_attention_heads}.")
115
+ raise ValueError(
116
+ f"hidden_size {hidden_size} is not a multiple of num_attention_heads {num_attention_heads}."
117
+ )
116
118
 
117
119
  self.num_attention_heads = num_attention_heads
118
120
  self.attention_head_size = int(hidden_size / num_attention_heads)
@@ -1,6 +1,6 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
- from typing import Any, Dict, Tuple, Union, Optional
3
+ from typing import Any, Dict, List, Tuple, Union, Optional
4
4
  from einops import rearrange
5
5
 
6
6
  from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
@@ -190,7 +190,8 @@ class QwenDoubleStreamAttention(nn.Module):
190
190
  self,
191
191
  image: torch.FloatTensor,
192
192
  text: torch.FloatTensor,
193
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
193
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
194
+ attn_mask: Optional[torch.Tensor] = None,
194
195
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
195
196
  img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
196
197
  txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
@@ -206,8 +207,8 @@ class QwenDoubleStreamAttention(nn.Module):
206
207
  img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
207
208
  txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
208
209
 
209
- if image_rotary_emb is not None:
210
- img_freqs, txt_freqs = image_rotary_emb
210
+ if rotary_emb is not None:
211
+ img_freqs, txt_freqs = rotary_emb
211
212
  img_q = apply_rotary_emb_qwen(img_q, img_freqs)
212
213
  img_k = apply_rotary_emb_qwen(img_k, img_freqs)
213
214
  txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
@@ -221,7 +222,7 @@ class QwenDoubleStreamAttention(nn.Module):
221
222
  joint_k = joint_k.transpose(1, 2)
222
223
  joint_v = joint_v.transpose(1, 2)
223
224
 
224
- joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, **self.attn_kwargs)
225
+ joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **self.attn_kwargs)
225
226
 
226
227
  joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)
227
228
 
@@ -285,7 +286,8 @@ class QwenImageTransformerBlock(nn.Module):
285
286
  image: torch.Tensor,
286
287
  text: torch.Tensor,
287
288
  temb: torch.Tensor,
288
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
289
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
290
+ attn_mask: Optional[torch.Tensor] = None,
289
291
  ) -> Tuple[torch.Tensor, torch.Tensor]:
290
292
  img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
291
293
  txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
@@ -299,7 +301,8 @@ class QwenImageTransformerBlock(nn.Module):
299
301
  img_attn_out, txt_attn_out = self.attn(
300
302
  image=img_modulated,
301
303
  text=txt_modulated,
302
- image_rotary_emb=image_rotary_emb,
304
+ rotary_emb=rotary_emb,
305
+ attn_mask=attn_mask,
303
306
  )
304
307
 
305
308
  image = image + img_gate * img_attn_out
@@ -368,13 +371,74 @@ class QwenImageDiT(PreTrainedModel):
368
371
  )
369
372
  return hidden_states
370
373
 
374
+ def process_entity_masks(
375
+ self,
376
+ text: torch.Tensor,
377
+ text_seq_lens: torch.LongTensor,
378
+ rotary_emb: Tuple[torch.Tensor, torch.Tensor],
379
+ video_fhw: List[Tuple[int, int, int]],
380
+ entity_text: List[torch.Tensor],
381
+ entity_seq_lens: List[torch.LongTensor],
382
+ entity_masks: List[torch.Tensor],
383
+ device: str,
384
+ dtype: torch.dtype,
385
+ ):
386
+ entity_seq_lens = [seq_lens.max().item() for seq_lens in entity_seq_lens]
387
+ text_seq_lens = entity_seq_lens + [text_seq_lens.max().item()]
388
+ entity_text = [
389
+ self.txt_in(self.txt_norm(text[:, :seq_len])) for text, seq_len in zip(entity_text, entity_seq_lens)
390
+ ]
391
+ text = torch.cat(entity_text + [text], dim=1)
392
+
393
+ entity_txt_freqs = [self.pos_embed(video_fhw, seq_len, device)[1] for seq_len in entity_seq_lens]
394
+ img_freqs, txt_freqs = rotary_emb
395
+ txt_freqs = torch.cat(entity_txt_freqs + [txt_freqs], dim=0)
396
+ rotary_emb = (img_freqs, txt_freqs)
397
+
398
+ global_mask = torch.ones_like(entity_masks[0], device=device, dtype=dtype)
399
+ patched_masks = [self.patchify(mask) for mask in entity_masks + [global_mask]]
400
+ batch_size, image_seq_len = patched_masks[0].shape[:2]
401
+ total_seq_len = sum(text_seq_lens) + image_seq_len
402
+ attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), device=device, dtype=torch.bool)
403
+
404
+ # text-image attention mask
405
+ img_start, img_end = sum(text_seq_lens), total_seq_len
406
+ cumsum = [0]
407
+ for seq_len in text_seq_lens:
408
+ cumsum.append(cumsum[-1] + seq_len)
409
+ for i, patched_mask in enumerate(patched_masks):
410
+ txt_start, txt_end = cumsum[i], cumsum[i + 1]
411
+ mask = torch.sum(patched_mask, dim=-1) > 0
412
+ mask = mask.unsqueeze(1).repeat(1, text_seq_lens[i], 1)
413
+ # text-to-image attention
414
+ attention_mask[:, txt_start:txt_end, img_start:img_end] = mask
415
+ # image-to-text attention
416
+ attention_mask[:, img_start:img_end, txt_start:txt_end] = mask.transpose(1, 2)
417
+ # entity text tokens should not attend to each other
418
+ for i in range(len(text_seq_lens)):
419
+ for j in range(len(text_seq_lens)):
420
+ if i == j:
421
+ continue
422
+ i_start, i_end = cumsum[i], cumsum[i + 1]
423
+ j_start, j_end = cumsum[j], cumsum[j + 1]
424
+ attention_mask[:, i_start:i_end, j_start:j_end] = False
425
+
426
+ attn_mask = torch.zeros_like(attention_mask, device=device, dtype=dtype)
427
+ attn_mask[~attention_mask] = -torch.inf
428
+ attn_mask = attn_mask.unsqueeze(1)
429
+ return text, rotary_emb, attn_mask
430
+
371
431
  def forward(
372
432
  self,
373
433
  image: torch.Tensor,
374
434
  edit: torch.Tensor = None,
375
- text: torch.Tensor = None,
376
435
  timestep: torch.LongTensor = None,
377
- txt_seq_lens: torch.LongTensor = None,
436
+ text: torch.Tensor = None,
437
+ text_seq_lens: torch.LongTensor = None,
438
+ context_latents: Optional[torch.Tensor] = None,
439
+ entity_text: Optional[List[torch.Tensor]] = None,
440
+ entity_seq_lens: Optional[List[torch.LongTensor]] = None,
441
+ entity_masks: Optional[List[torch.Tensor]] = None,
378
442
  ):
379
443
  h, w = image.shape[-2:]
380
444
  fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -385,36 +449,62 @@ class QwenImageDiT(PreTrainedModel):
385
449
  cfg_parallel(
386
450
  (
387
451
  image,
388
- edit,
389
- text,
452
+ *(edit if edit is not None else ()),
390
453
  timestep,
391
- txt_seq_lens,
454
+ text,
455
+ text_seq_lens,
456
+ *(entity_text if entity_text is not None else ()),
457
+ *(entity_seq_lens if entity_seq_lens is not None else ()),
458
+ *(entity_masks if entity_masks is not None else ()),
459
+ context_latents,
392
460
  ),
393
461
  use_cfg=use_cfg,
394
462
  ),
395
463
  ):
396
464
  conditioning = self.time_text_embed(timestep, image.dtype)
397
465
  video_fhw = [(1, h // 2, w // 2)] # frame, height, width
398
- max_length = txt_seq_lens.max().item()
466
+ text_seq_len = text_seq_lens.max().item()
399
467
  image = self.patchify(image)
400
468
  image_seq_len = image.shape[1]
469
+ if context_latents is not None:
470
+ context_latents = context_latents.to(dtype=image.dtype)
471
+ context_latents = self.patchify(context_latents)
472
+ image = torch.cat([image, context_latents], dim=1)
473
+ video_fhw += [(1, h // 2, w // 2)]
401
474
  if edit is not None:
402
- edit = edit.to(dtype=image.dtype)
403
- edit = self.patchify(edit)
404
- image = torch.cat([image, edit], dim=1)
405
- video_fhw += video_fhw
475
+ for img in edit:
476
+ img = img.to(dtype=image.dtype)
477
+ edit_h, edit_w = img.shape[-2:]
478
+ img = self.patchify(img)
479
+ image = torch.cat([image, img], dim=1)
480
+ video_fhw += [(1, edit_h // 2, edit_w // 2)]
406
481
 
407
- image_rotary_emb = self.pos_embed(video_fhw, max_length, image.device)
482
+ rotary_emb = self.pos_embed(video_fhw, text_seq_len, image.device)
408
483
 
409
484
  image = self.img_in(image)
410
- text = self.txt_in(self.txt_norm(text[:, :max_length]))
485
+ text = self.txt_in(self.txt_norm(text[:, :text_seq_len]))
486
+
487
+ attn_mask = None
488
+ if entity_text is not None:
489
+ text, rotary_emb, attn_mask = self.process_entity_masks(
490
+ text,
491
+ text_seq_lens,
492
+ rotary_emb,
493
+ video_fhw,
494
+ entity_text,
495
+ entity_seq_lens,
496
+ entity_masks,
497
+ image.device,
498
+ image.dtype,
499
+ )
411
500
 
412
501
  for block in self.transformer_blocks:
413
- text, image = block(image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb)
502
+ text, image = block(
503
+ image=image, text=text, temb=conditioning, rotary_emb=rotary_emb, attn_mask=attn_mask
504
+ )
414
505
  image = self.norm_out(image, conditioning)
415
506
  image = self.proj_out(image)
416
- if edit is not None:
417
- image = image[:, :image_seq_len]
507
+ image = image[:, :image_seq_len]
418
508
 
419
509
  image = self.unpatchify(image, h, w)
420
510
 
@@ -267,9 +267,13 @@ def linear_interpolation(features: torch.Tensor, input_fps: int, output_fps: int
267
267
  return output_features.transpose(1, 2) # [1, output_len, 512]
268
268
 
269
269
 
270
- def extract_audio_feat(audio_input: torch.Tensor, model: Wav2Vec2Model, dtype=torch.float32, device="cuda:0") -> torch.Tensor:
270
+ def extract_audio_feat(
271
+ audio_input: torch.Tensor, model: Wav2Vec2Model, dtype=torch.float32, device="cuda:0"
272
+ ) -> torch.Tensor:
271
273
  video_rate = 30
272
- input_values = (audio_input - audio_input.mean(dim=1, keepdim=True)) / torch.sqrt(audio_input.var(dim=1, keepdim=True) + 1e-7)
274
+ input_values = (audio_input - audio_input.mean(dim=1, keepdim=True)) / torch.sqrt(
275
+ audio_input.var(dim=1, keepdim=True) + 1e-7
276
+ )
273
277
  feat = torch.cat(model(input_values.to(device)))
274
278
  feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
275
279
  return feat.to(dtype) # Encoding for the motion
@@ -2,10 +2,10 @@ import os
2
2
  import torch
3
3
  import numpy as np
4
4
  from einops import rearrange
5
- from typing import Dict, List, Tuple
5
+ from typing import Dict, List, Tuple, Union
6
6
  from PIL import Image
7
7
 
8
- from diffsynth_engine.configs import BaseConfig, BaseStateDicts
8
+ from diffsynth_engine.configs import BaseConfig, BaseStateDicts, LoraConfig
9
9
  from diffsynth_engine.utils.offload import enable_sequential_cpu_offload, offload_model_to_dict, restore_model_from_dict
10
10
  from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
11
11
  from diffsynth_engine.utils.gguf import load_gguf_checkpoint
@@ -53,7 +53,7 @@ class BasePipeline:
53
53
 
54
54
  def update_weights(self, state_dicts: BaseStateDicts) -> None:
55
55
  raise NotImplementedError()
56
-
56
+
57
57
  @staticmethod
58
58
  def update_component(
59
59
  component: torch.nn.Module,
@@ -65,10 +65,27 @@ class BasePipeline:
65
65
  component.load_state_dict(state_dict, assign=True)
66
66
  component.to(device=device, dtype=dtype, non_blocking=True)
67
67
 
68
- def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
69
- for lora_path, lora_scale in lora_list:
70
- logger.info(f"loading lora from {lora_path} with scale {lora_scale}")
68
+ def load_loras(
69
+ self,
70
+ lora_list: List[Tuple[str, Union[float, LoraConfig]]],
71
+ fused: bool = True,
72
+ save_original_weight: bool = False,
73
+ ):
74
+ for lora_path, lora_item in lora_list:
75
+ if isinstance(lora_item, float):
76
+ lora_scale = lora_item
77
+ scheduler_config = None
78
+ if isinstance(lora_item, LoraConfig):
79
+ lora_scale = lora_item.scale
80
+ scheduler_config = lora_item.scheduler_config
81
+
82
+ logger.info(f"loading lora from {lora_path} with LoraConfig (scale={lora_scale})")
71
83
  state_dict = load_file(lora_path, device=self.device)
84
+
85
+ if scheduler_config is not None:
86
+ self.apply_scheduler_config(scheduler_config)
87
+ logger.info(f"Applied scheduler args from LoraConfig: {scheduler_config}")
88
+
72
89
  lora_state_dict = self.lora_converter.convert(state_dict)
73
90
  for model_name, state_dict in lora_state_dict.items():
74
91
  model = getattr(self, model_name)
@@ -92,6 +109,9 @@ class BasePipeline:
92
109
  def load_lora(self, path: str, scale: float, fused: bool = True, save_original_weight: bool = False):
93
110
  self.load_loras([(path, scale)], fused, save_original_weight)
94
111
 
112
+ def apply_scheduler_config(self, scheduler_config: Dict):
113
+ pass
114
+
95
115
  def unload_loras(self):
96
116
  raise NotImplementedError()
97
117
 
@@ -200,5 +200,5 @@ class Hunyuan3DShapePipeline(BasePipeline):
200
200
  model_outputs = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
201
201
  latents = self.sampler.step(latents, model_outputs, i)
202
202
  if progress_callback is not None:
203
- progress_callback(i, len(timesteps), "DENOISING")
203
+ progress_callback(i, len(timesteps), "DENOISING")
204
204
  return self.decode_latents(latents)