diffsynth-engine 0.6.1.dev27__tar.gz → 0.6.1.dev29__tar.gz

This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (212)
  1. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/PKG-INFO +1 -1
  2. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/configs/pipeline.py +5 -0
  3. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/base.py +1 -1
  4. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/lora.py +1 -0
  5. diffsynth_engine-0.6.1.dev29/diffsynth_engine/models/basic/lora_nunchaku.py +221 -0
  6. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/video_sparse_attention.py +15 -3
  7. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/qwen_image/__init__.py +8 -0
  8. diffsynth_engine-0.6.1.dev29/diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py +341 -0
  9. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/base.py +11 -4
  10. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/qwen_image.py +64 -2
  11. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/wan_video.py +25 -1
  12. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/flag.py +24 -0
  13. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/parallel.py +23 -107
  14. diffsynth_engine-0.6.1.dev29/diffsynth_engine/utils/process_group.py +149 -0
  15. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine.egg-info/PKG-INFO +1 -1
  16. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine.egg-info/SOURCES.txt +3 -0
  17. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/.gitattributes +0 -0
  18. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/.gitignore +0 -0
  19. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/.pre-commit-config.yaml +0 -0
  20. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/LICENSE +0 -0
  21. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/MANIFEST.in +0 -0
  22. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/README.md +0 -0
  23. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/assets/dingtalk.png +0 -0
  24. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/assets/showcase.jpeg +0 -0
  25. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/assets/tongyi.svg +0 -0
  26. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/__init__.py +0 -0
  27. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/__init__.py +0 -0
  28. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  29. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  30. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  31. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  32. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  33. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  34. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  35. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  36. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  37. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  38. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  39. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  40. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  41. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  42. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  43. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  44. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  45. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  46. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  47. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  48. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  49. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  50. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  51. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  52. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  53. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  54. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/components/vae.json +0 -0
  55. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  56. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  57. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  58. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json +0 -0
  59. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_vision_config.json +0 -0
  60. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae.json +0 -0
  61. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae_keymap.json +0 -0
  62. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  63. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  64. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  65. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  66. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  67. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  68. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.1_flf2v_14b.json +0 -0
  69. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.1_i2v_14b.json +0 -0
  70. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_1.3b.json +0 -0
  71. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_14b.json +0 -0
  72. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.2_i2v_a14b.json +0 -0
  73. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.2_s2v_14b.json +0 -0
  74. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.2_t2v_a14b.json +0 -0
  75. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.2_ti2v_5b.json +0 -0
  76. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +0 -0
  77. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/vae/wan2.1_vae.json +0 -0
  78. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/vae/wan2.2_vae.json +0 -0
  79. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/vae/wan_vae_keymap.json +0 -0
  80. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  81. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  82. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  83. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  84. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  85. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  86. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  87. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  88. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json +0 -0
  89. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/added_tokens.json +0 -0
  90. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/merges.txt +0 -0
  91. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/special_tokens_map.json +0 -0
  92. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer.json +0 -0
  93. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer_config.json +0 -0
  94. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/vocab.json +0 -0
  95. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  96. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  97. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  98. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  99. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  100. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  101. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  102. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  103. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  104. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  105. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  106. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  107. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/configs/__init__.py +0 -0
  108. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/configs/controlnet.py +0 -0
  109. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/kernels/__init__.py +0 -0
  110. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/__init__.py +0 -0
  111. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/__init__.py +0 -0
  112. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/attention.py +0 -0
  113. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  114. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/timestep.py +0 -0
  115. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
  116. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  117. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/__init__.py +0 -0
  118. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_controlnet.py +0 -0
  119. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_dit.py +0 -0
  120. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_dit_fbcache.py +0 -0
  121. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_ipadapter.py +0 -0
  122. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_redux.py +0 -0
  123. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  124. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  125. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/__init__.py +0 -0
  126. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +0 -0
  127. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/hunyuan3d_dit.py +0 -0
  128. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/hunyuan3d_vae.py +0 -0
  129. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/moe.py +0 -0
  130. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/surface_extractor.py +0 -0
  131. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/volume_decoder.py +0 -0
  132. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/qwen_image/qwen2_5_vl.py +0 -0
  133. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/qwen_image/qwen_image_dit.py +0 -0
  134. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +0 -0
  135. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/qwen_image/qwen_image_vae.py +0 -0
  136. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd/__init__.py +0 -0
  137. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd/sd_controlnet.py +0 -0
  138. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  139. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  140. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  141. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd3/__init__.py +0 -0
  142. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  143. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  144. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  145. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  146. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +0 -0
  147. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  148. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
  149. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  150. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
  151. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/text_encoder/clip.py +0 -0
  152. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
  153. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/text_encoder/t5.py +0 -0
  154. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/vae/__init__.py +0 -0
  155. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/vae/vae.py +0 -0
  156. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/__init__.py +0 -0
  157. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/wan_audio_encoder.py +0 -0
  158. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/wan_dit.py +0 -0
  159. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  160. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/wan_s2v_dit.py +0 -0
  161. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  162. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  163. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/__init__.py +0 -0
  164. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/flux_image.py +0 -0
  165. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/hunyuan3d_shape.py +0 -0
  166. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/sd_image.py +0 -0
  167. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/sdxl_image.py +0 -0
  168. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/utils.py +0 -0
  169. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/wan_s2v.py +0 -0
  170. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/processor/__init__.py +0 -0
  171. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/processor/canny_processor.py +0 -0
  172. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/processor/depth_processor.py +0 -0
  173. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/__init__.py +0 -0
  174. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/base.py +0 -0
  175. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/clip.py +0 -0
  176. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/qwen2.py +0 -0
  177. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/qwen2_vl_image_processor.py +0 -0
  178. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/qwen2_vl_processor.py +0 -0
  179. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/t5.py +0 -0
  180. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/wan.py +0 -0
  181. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tools/__init__.py +0 -0
  182. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tools/flux_inpainting_tool.py +0 -0
  183. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tools/flux_outpainting_tool.py +0 -0
  184. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tools/flux_reference_tool.py +0 -0
  185. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tools/flux_replace_tool.py +0 -0
  186. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/__init__.py +0 -0
  187. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/cache.py +0 -0
  188. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/constants.py +0 -0
  189. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/download.py +0 -0
  190. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/env.py +0 -0
  191. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/fp8_linear.py +0 -0
  192. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/gguf.py +0 -0
  193. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/image.py +0 -0
  194. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/loader.py +0 -0
  195. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/lock.py +0 -0
  196. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/logging.py +0 -0
  197. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/memory/__init__.py +0 -0
  198. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/memory/linear_regression.py +0 -0
  199. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/memory/memory_predcit_model.py +0 -0
  200. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/offload.py +0 -0
  201. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/onnx.py +0 -0
  202. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/platform.py +0 -0
  203. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/prompt.py +0 -0
  204. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/video.py +0 -0
  205. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  206. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine.egg-info/requires.txt +0 -0
  207. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine.egg-info/top_level.txt +0 -0
  208. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/docs/tutorial.md +0 -0
  209. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/docs/tutorial_zh.md +0 -0
  210. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/pyproject.toml +0 -0
  211. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/setup.cfg +0 -0
  212. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/setup.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.6.1.dev27
+Version: 0.6.1.dev29
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent
diffsynth_engine/configs/pipeline.py
@@ -251,6 +251,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
     # override OptimizationConfig
     fbcache_relative_l1_threshold = 0.009
 
+    # svd
+    use_nunchaku: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_awq: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_attn: Optional[bool] = field(default=None, init=False)
+
     @classmethod
     def basic_config(
         cls,
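Note: the three new fields are dataclass fields declared with init=False, so they are not constructor arguments. A minimal sketch of enabling them on an existing config (assuming QwenImagePipelineConfig is re-exported from diffsynth_engine.configs; how the config itself is built is outside this hunk):

    from diffsynth_engine.configs import QwenImagePipelineConfig  # assumed export path

    config: QwenImagePipelineConfig = ...  # built elsewhere, e.g. via basic_config(...)
    config.use_nunchaku = True       # route the Qwen image DiT through the nunchaku (SVDQuant) variant
    config.use_nunchaku_awq = True   # also quantize the modulation linears with AWQ W4A16
    config.use_nunchaku_attn = True  # use the fused-QKV nunchaku attention blocks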
diffsynth_engine/models/base.py
@@ -40,7 +40,7 @@ class PreTrainedModel(nn.Module):
 
     def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
         for args in lora_args:
-            key = args["name"]
+            key = args["key"]
             module = self.get_submodule(key)
             if not isinstance(module, (LoRALinear, LoRAConv2d)):
                 raise ValueError(f"Unsupported lora key: {key}")
diffsynth_engine/models/basic/lora.py
@@ -132,6 +132,7 @@ class LoRALinear(nn.Linear):
         device: str,
         dtype: torch.dtype,
         save_original_weight: bool = True,
+        **kwargs,
    ):
        if save_original_weight and self._original_weight is None:
            if self.weight.dtype == torch.float8_e4m3fn:
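Note: together, the two hunks above change the lora_args contract: each entry now carries the target submodule under "key" (the "name" slot is repurposed for the LoRA's source; see the pipelines/base.py hunk below), and LoRALinear tolerates the extra fields via **kwargs. A sketch of the resulting entry shape, with illustrative values:

    lora_args = [
        {
            "name": "/path/to/lora.safetensors",      # identifies the LoRA (the lora_path)
            "key": "transformer_blocks.0.attn.to_q",  # submodule to patch (illustrative key)
            "scale": 1.0,
            "rank": 32,
            "alpha": 32,
            # plus the "up"/"down" weight tensors and "device"/"dtype" fields,
            # as assembled in the pipelines/base.py hunk below
        },
    ]
    model.load_loras(lora_args, fused=True)  # model: any PreTrainedModel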
diffsynth_engine/models/basic/lora_nunchaku.py (new file)
@@ -0,0 +1,221 @@
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+from .lora import LoRA
+from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
+from nunchaku.lora.flux.nunchaku_converter import (
+    pack_lowrank_weight,
+    unpack_lowrank_weight,
+)
+
+
+class LoRASVDQW4A4Linear(nn.Module):
+    def __init__(
+        self,
+        origin_linear: SVDQW4A4Linear,
+    ):
+        super().__init__()
+
+        self.origin_linear = origin_linear
+        self.base_rank = self.origin_linear.rank
+        self._lora_dict = OrderedDict()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.origin_linear(x)
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.origin_linear, name)
+
+    def _apply_lora_weights(self, name: str, down: torch.Tensor, up: torch.Tensor, alpha: int, scale: float, rank: int):
+        final_scale = scale * (alpha / rank)
+
+        up_scaled = (up * final_scale).to(
+            dtype=self.origin_linear.proj_up.dtype, device=self.origin_linear.proj_up.device
+        )
+        down_final = down.to(dtype=self.origin_linear.proj_down.dtype, device=self.origin_linear.proj_down.device)
+
+        with torch.no_grad():
+            pd_packed = self.origin_linear.proj_down.data
+            pu_packed = self.origin_linear.proj_up.data
+            pd = unpack_lowrank_weight(pd_packed, down=True)
+            pu = unpack_lowrank_weight(pu_packed, down=False)
+
+            new_proj_down = torch.cat([pd, down_final], dim=0)
+            new_proj_up = torch.cat([pu, up_scaled], dim=1)
+
+            self.origin_linear.proj_down.data = pack_lowrank_weight(new_proj_down, down=True)
+            self.origin_linear.proj_up.data = pack_lowrank_weight(new_proj_up, down=False)
+
+        current_total_rank = self.origin_linear.rank
+        self.origin_linear.rank += rank
+        self._lora_dict[name] = {"rank": rank, "alpha": alpha, "scale": scale, "start_idx": current_total_rank}
+
+    def add_frozen_lora(
+        self,
+        name: str,
+        scale: float,
+        rank: int,
+        alpha: int,
+        up: torch.Tensor,
+        down: torch.Tensor,
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        if name in self._lora_dict:
+            raise ValueError(f"LoRA with name '{name}' already exists.")
+
+        self._apply_lora_weights(name, down, up, alpha, scale, rank)
+
+    def add_qkv_lora(
+        self,
+        name: str,
+        scale: float,
+        rank: int,
+        alpha: int,
+        q_up: torch.Tensor,
+        q_down: torch.Tensor,
+        k_up: torch.Tensor,
+        k_down: torch.Tensor,
+        v_up: torch.Tensor,
+        v_down: torch.Tensor,
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        if name in self._lora_dict:
+            raise ValueError(f"LoRA with name '{name}' already exists.")
+
+        fused_down = torch.cat([q_down, k_down, v_down], dim=0)
+
+        fused_rank = 3 * rank
+        out_q, out_k = q_up.shape[0], k_up.shape[0]
+        fused_up = torch.zeros((self.out_features, fused_rank), device=q_up.device, dtype=q_up.dtype)
+        fused_up[:out_q, :rank] = q_up
+        fused_up[out_q : out_q + out_k, rank : 2 * rank] = k_up
+        fused_up[out_q + out_k :, 2 * rank :] = v_up
+
+        self._apply_lora_weights(name, fused_down, fused_up, alpha, scale, rank)
+
+    def modify_scale(self, name: str, scale: float):
+        if name not in self._lora_dict:
+            raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}")
+
+        info = self._lora_dict[name]
+        old_scale = info["scale"]
+
+        if old_scale == scale:
+            return
+
+        if old_scale == 0:
+            scale_factor = 0.0
+        else:
+            scale_factor = scale / old_scale
+
+        with torch.no_grad():
+            lora_rank = info["rank"]
+            start_idx = info["start_idx"]
+            end_idx = start_idx + lora_rank
+
+            pu_packed = self.origin_linear.proj_up.data
+            pu = unpack_lowrank_weight(pu_packed, down=False)
+            pu[:, start_idx:end_idx] *= scale_factor
+
+            self.origin_linear.proj_up.data = pack_lowrank_weight(pu, down=False)
+
+        self._lora_dict[name]["scale"] = scale
+
+    def clear(self, release_all_cpu_memory: bool = False):
+        if not self._lora_dict:
+            return
+
+        with torch.no_grad():
+            pd_packed = self.origin_linear.proj_down.data
+            pu_packed = self.origin_linear.proj_up.data
+
+            pd = unpack_lowrank_weight(pd_packed, down=True)
+            pu = unpack_lowrank_weight(pu_packed, down=False)
+
+            pd_reset = pd[: self.base_rank, :].clone()
+            pu_reset = pu[:, : self.base_rank].clone()
+
+            self.origin_linear.proj_down.data = pack_lowrank_weight(pd_reset, down=True)
+            self.origin_linear.proj_up.data = pack_lowrank_weight(pu_reset, down=False)
+
+        self.origin_linear.rank = self.base_rank
+
+        self._lora_dict.clear()
+
+
+class LoRAAWQW4A16Linear(nn.Module):
+    def __init__(self, origin_linear: AWQW4A16Linear):
+        super().__init__()
+        self.origin_linear = origin_linear
+        self._lora_dict = OrderedDict()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        quantized_output = self.origin_linear(x)
+
+        for name, lora in self._lora_dict.items():
+            quantized_output += lora(x.to(lora.dtype)).to(quantized_output.dtype)
+
+        return quantized_output
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.origin_linear, name)
+
+    def add_lora(
+        self,
+        name: str,
+        scale: float,
+        rank: int,
+        alpha: int,
+        up: torch.Tensor,
+        down: torch.Tensor,
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        up_linear = nn.Linear(rank, self.out_features, bias=False, device="meta", dtype=dtype).to_empty(device=device)
+        down_linear = nn.Linear(self.in_features, rank, bias=False, device="meta", dtype=dtype).to_empty(device=device)
+
+        up_linear.weight.data = up.reshape(self.out_features, rank)
+        down_linear.weight.data = down.reshape(rank, self.in_features)
+
+        lora = LoRA(scale, rank, alpha, up_linear, down_linear, device, dtype)
+        self._lora_dict[name] = lora
+
+    def modify_scale(self, name: str, scale: float):
+        if name not in self._lora_dict:
+            raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}")
+        self._lora_dict[name].scale = scale
+
+    def add_frozen_lora(self, *args, **kwargs):
+        raise NotImplementedError("Frozen LoRA (merging weights) is not supported for AWQW4A16Linear.")
+
+    def clear(self, *args, **kwargs):
+        self._lora_dict.clear()
+
+
+def patch_nunchaku_model_for_lora(model: nn.Module):
+    def _recursive_patch(module: nn.Module):
+        for name, child_module in module.named_children():
+            replacement = None
+            if isinstance(child_module, AWQW4A16Linear):
+                replacement = LoRAAWQW4A16Linear(child_module)
+            elif isinstance(child_module, SVDQW4A4Linear):
+                replacement = LoRASVDQW4A4Linear(child_module)
+
+            if replacement:
+                setattr(module, name, replacement)
+            else:
+                _recursive_patch(child_module)
+
+    _recursive_patch(model)
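Note: patch_nunchaku_model_for_lora rewrites a quantized module tree in place so that its linears accept LoRAs. A minimal usage sketch (model is any nn.Module that already contains nunchaku AWQW4A16Linear / SVDQW4A4Linear layers):

    from diffsynth_engine.models.basic.lora_nunchaku import patch_nunchaku_model_for_lora

    patch_nunchaku_model_for_lora(model)
    # SVDQW4A4Linear layers are now LoRASVDQW4A4Linear: add_frozen_lora/add_qkv_lora
    # fold the LoRA factors into the packed proj_down/proj_up low-rank weights.
    # AWQW4A16Linear layers are now LoRAAWQW4A16Linear: add_lora keeps each LoRA as a
    # separate branch evaluated at forward time, so weights are never merged.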
diffsynth_engine/models/basic/video_sparse_attention.py
@@ -3,10 +3,15 @@ import math
 import functools
 
 from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
-from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
+from diffsynth_engine.utils.process_group import get_sp_ulysses_group, get_sp_ring_world_size
 
+
+vsa_core = None
 if VIDEO_SPARSE_ATTN_AVAILABLE:
-    from vsa import video_sparse_attn as vsa_core
+    try:
+        from vsa import video_sparse_attn as vsa_core
+    except Exception:
+        vsa_core = None
 
 VSA_TILE_SIZE = (4, 4, 4)
 
@@ -171,6 +176,12 @@ def video_sparse_attn(
     variable_block_sizes: torch.LongTensor,
     non_pad_index: torch.LongTensor,
 ):
+    if vsa_core is None:
+        raise RuntimeError(
+            "Video sparse attention (VSA) is not available. "
+            "Please install the 'vsa' package and ensure all its dependencies (including pytest) are installed."
+        )
+
     q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
     k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
     v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
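Note: with the guarded import above, vsa_core can legitimately be None even when VIDEO_SPARSE_ATTN_AVAILABLE is set, so video_sparse_attn now fails fast with an actionable error instead of a NameError. A hedged sketch of probing availability before selecting this attention path (is_vsa_usable is a hypothetical helper, not part of the package):

    def is_vsa_usable() -> bool:
        # Hypothetical helper mirroring the module's own guard: the flag only says
        # the capability check passed; the actual import can still fail.
        from diffsynth_engine.models.basic import video_sparse_attention as vsa_mod
        return vsa_mod.vsa_core is not None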
@@ -212,7 +223,8 @@ def distributed_video_sparse_attn(
 ):
     from yunchang.comm.all_to_all import SeqAllToAll4D
 
-    assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
+    ring_world_size = get_sp_ring_world_size()
+    assert ring_world_size == 1, "distributed video sparse attention requires ring degree to be 1"
     sp_ulysses_group = get_sp_ulysses_group()
 
     q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
diffsynth_engine/models/qwen_image/__init__.py
@@ -11,3 +11,11 @@ __all__ = [
     "Qwen2_5_VLVisionConfig",
     "Qwen2_5_VLConfig",
 ]
+
+try:
+    from .qwen_image_dit_nunchaku import QwenImageDiTNunchaku
+
+    __all__.append("QwenImageDiTNunchaku")
+
+except (ImportError, ModuleNotFoundError):
+    pass
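Note: because the import is wrapped in try/except, QwenImageDiTNunchaku only appears in the package namespace when nunchaku is importable. A small feature-detection sketch for downstream code (assuming the unquantized QwenImageDiT remains available from its own module, as the new file's imports below confirm):

    import diffsynth_engine.models.qwen_image as qwen_image

    if "QwenImageDiTNunchaku" in qwen_image.__all__:
        dit_cls = qwen_image.QwenImageDiTNunchaku  # nunchaku is installed
    else:
        from diffsynth_engine.models.qwen_image.qwen_image_dit import QwenImageDiT as dit_cls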
diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py (new file)
@@ -0,0 +1,341 @@
+import torch
+import torch.nn as nn
+from typing import Any, Dict, List, Tuple, Optional
+from einops import rearrange
+
+from diffsynth_engine.models.basic import attention as attention_ops
+from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
+from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, RMSNorm
+from diffsynth_engine.models.qwen_image.qwen_image_dit import (
+    QwenFeedForward,
+    apply_rotary_emb_qwen,
+    QwenDoubleStreamAttention,
+    QwenImageTransformerBlock,
+    QwenImageDiT,
+    QwenEmbedRope,
+)
+
+from nunchaku.models.utils import fuse_linears
+from nunchaku.ops.fused import fused_gelu_mlp
+from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
+from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
+from diffsynth_engine.models.basic.lora_nunchaku import LoRASVDQW4A4Linear, LoRAAWQW4A16Linear
+
+
+class QwenDoubleStreamAttentionNunchaku(QwenDoubleStreamAttention):
+    def __init__(
+        self,
+        dim_a,
+        dim_b,
+        num_heads,
+        head_dim,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        nunchaku_rank: int = 32,
+    ):
+        super().__init__(dim_a, dim_b, num_heads, head_dim, device=device, dtype=dtype)
+
+        to_qkv = fuse_linears([self.to_q, self.to_k, self.to_v])
+        self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, rank=nunchaku_rank)
+        self.to_out = SVDQW4A4Linear.from_linear(self.to_out, rank=nunchaku_rank)
+
+        del self.to_q, self.to_k, self.to_v
+
+        add_qkv_proj = fuse_linears([self.add_q_proj, self.add_k_proj, self.add_v_proj])
+        self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, rank=nunchaku_rank)
+        self.to_add_out = SVDQW4A4Linear.from_linear(self.to_add_out, rank=nunchaku_rank)
+
+        del self.add_q_proj, self.add_k_proj, self.add_v_proj
+
+    def forward(
+        self,
+        image: torch.FloatTensor,
+        text: torch.FloatTensor,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        img_q, img_k, img_v = self.to_qkv(image).chunk(3, dim=-1)
+        txt_q, txt_k, txt_v = self.add_qkv_proj(text).chunk(3, dim=-1)
+
+        img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads)
+        img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads)
+        img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads)
+
+        txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads)
+
+        img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
+        txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
+
+        if rotary_emb is not None:
+            img_freqs, txt_freqs = rotary_emb
+            img_q = apply_rotary_emb_qwen(img_q, img_freqs)
+            img_k = apply_rotary_emb_qwen(img_k, img_freqs)
+            txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
+            txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)
+
+        joint_q = torch.cat([txt_q, img_q], dim=1)
+        joint_k = torch.cat([txt_k, img_k], dim=1)
+        joint_v = torch.cat([txt_v, img_v], dim=1)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)
+
+        joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)
+
+        txt_attn_output = joint_attn_out[:, : text.shape[1], :]
+        img_attn_output = joint_attn_out[:, text.shape[1] :, :]
+
+        img_attn_output = self.to_out(img_attn_output)
+        txt_attn_output = self.to_add_out(txt_attn_output)
+
+        return img_attn_output, txt_attn_output
+
+
+class QwenFeedForwardNunchaku(QwenFeedForward):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        dropout: float = 0.0,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        rank: int = 32,
+    ):
+        super().__init__(dim, dim_out, dropout, device=device, dtype=dtype)
+        self.net[0].proj = SVDQW4A4Linear.from_linear(self.net[0].proj, rank=rank)
+        self.net[2] = SVDQW4A4Linear.from_linear(self.net[2], rank=rank)
+        self.net[2].act_unsigned = self.net[2].precision != "nvfp4"
+
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        return fused_gelu_mlp(hidden_states, self.net[0].proj, self.net[2])
+
+
+class QwenImageTransformerBlockNunchaku(QwenImageTransformerBlock):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        eps: float = 1e-6,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        scale_shift: float = 1.0,
+        use_nunchaku_awq: bool = True,
+        use_nunchaku_attn: bool = True,
+        nunchaku_rank: int = 32,
+    ):
+        super().__init__(dim, num_attention_heads, attention_head_dim, eps, device=device, dtype=dtype)
+
+        self.use_nunchaku_awq = use_nunchaku_awq
+        if use_nunchaku_awq:
+            self.img_mod[1] = AWQW4A16Linear.from_linear(self.img_mod[1], rank=nunchaku_rank)
+
+        if use_nunchaku_attn:
+            self.attn = QwenDoubleStreamAttentionNunchaku(
+                dim_a=dim,
+                dim_b=dim,
+                num_heads=num_attention_heads,
+                head_dim=attention_head_dim,
+                device=device,
+                dtype=dtype,
+                nunchaku_rank=nunchaku_rank,
+            )
+        else:
+            self.attn = QwenDoubleStreamAttention(
+                dim_a=dim,
+                dim_b=dim,
+                num_heads=num_attention_heads,
+                head_dim=attention_head_dim,
+                device=device,
+                dtype=dtype,
+            )
+
+        self.img_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank)
+
+        if use_nunchaku_awq:
+            self.txt_mod[1] = AWQW4A16Linear.from_linear(self.txt_mod[1], rank=nunchaku_rank)
+
+        self.txt_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank)
+
+        self.scale_shift = scale_shift
+
+    def _modulate(self, x, mod_params):
+        shift, scale, gate = mod_params.chunk(3, dim=-1)
+        if self.use_nunchaku_awq:
+            if self.scale_shift != 0:
+                scale.add_(self.scale_shift)
+            return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)
+        else:
+            return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+    def forward(
+        self,
+        image: torch.Tensor,
+        text: torch.Tensor,
+        temb: torch.Tensor,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.use_nunchaku_awq:
+            img_mod_params = self.img_mod(temb) # [B, 6*dim]
+            txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
+
+            # nunchaku's mod_params is [B, 6*dim] instead of [B, dim*6]
+            img_mod_params = (
+                img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
+            )
+            txt_mod_params = (
+                txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
+            )
+
+            img_mod_attn, img_mod_mlp = img_mod_params.chunk(2, dim=-1) # [B, 3*dim] each
+            txt_mod_attn, txt_mod_mlp = txt_mod_params.chunk(2, dim=-1) # [B, 3*dim] each
+        else:
+            img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
+            txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
+
+        img_normed = self.img_norm1(image)
+        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
+
+        txt_normed = self.txt_norm1(text)
+        txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
+
+        img_attn_out, txt_attn_out = self.attn(
+            image=img_modulated,
+            text=txt_modulated,
+            rotary_emb=rotary_emb,
+            attn_mask=attn_mask,
+            attn_kwargs=attn_kwargs,
+        )
+
+        image = image + img_gate * img_attn_out
+        text = text + txt_gate * txt_attn_out
+
+        img_normed_2 = self.img_norm2(image)
+        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
+
+        txt_normed_2 = self.txt_norm2(text)
+        txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
+
+        img_mlp_out = self.img_mlp(img_modulated_2)
+        txt_mlp_out = self.txt_mlp(txt_modulated_2)
+
+        image = image + img_gate_2 * img_mlp_out
+        text = text + txt_gate_2 * txt_mlp_out
+
+        return text, image
+
+
+class QwenImageDiTNunchaku(QwenImageDiT):
+    def __init__(
+        self,
+        num_layers: int = 60,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        use_nunchaku_awq: bool = True,
+        use_nunchaku_attn: bool = True,
+        nunchaku_rank: int = 32,
+    ):
+        super().__init__()
+
+        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True, device=device)
+
+        self.time_text_embed = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
+
+        self.txt_norm = RMSNorm(3584, eps=1e-6, device=device, dtype=dtype)
+
+        self.img_in = nn.Linear(64, 3072, device=device, dtype=dtype)
+        self.txt_in = nn.Linear(3584, 3072, device=device, dtype=dtype)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                QwenImageTransformerBlockNunchaku(
+                    dim=3072,
+                    num_attention_heads=24,
+                    attention_head_dim=128,
+                    device=device,
+                    dtype=dtype,
+                    scale_shift=0,
+                    use_nunchaku_awq=use_nunchaku_awq,
+                    use_nunchaku_attn=use_nunchaku_attn,
+                    nunchaku_rank=nunchaku_rank,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
+        self.proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+        num_layers: int = 60,
+        use_nunchaku_awq: bool = True,
+        use_nunchaku_attn: bool = True,
+        nunchaku_rank: int = 32,
+    ):
+        model = cls(
+            device="meta",
+            dtype=dtype,
+            num_layers=num_layers,
+            use_nunchaku_awq=use_nunchaku_awq,
+            use_nunchaku_attn=use_nunchaku_attn,
+            nunchaku_rank=nunchaku_rank,
+        )
+        model = model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, non_blocking=True)
+        return model
+
+    def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = False):
+        fuse_dict = {}
+        for args in lora_args:
+            key = args["key"]
+            if any(suffix in key for suffix in {"add_q_proj", "add_k_proj", "add_v_proj"}):
+                fuse_key = f"{key.rsplit('.', 1)[0]}.add_qkv_proj"
+                type = key.rsplit(".", 1)[-1].split("_")[1]
+                fuse_dict[fuse_key] = fuse_dict.get(fuse_key, {})
+                fuse_dict[fuse_key][type] = args
+                continue
+
+            if any(suffix in key for suffix in {"to_q", "to_k", "to_v"}):
+                fuse_key = f"{key.rsplit('.', 1)[0]}.to_qkv"
+                type = key.rsplit(".", 1)[-1].split("_")[1]
+                fuse_dict[fuse_key] = fuse_dict.get(fuse_key, {})
+                fuse_dict[fuse_key][type] = args
+                continue
+
+            module = self.get_submodule(key)
+            if not isinstance(module, (LoRALinear, LoRAConv2d, LoRASVDQW4A4Linear, LoRAAWQW4A16Linear)):
+                raise ValueError(f"Unsupported lora key: {key}")
+
+            if fused and not isinstance(module, LoRAAWQW4A16Linear):
+                module.add_frozen_lora(**args)
+            else:
+                module.add_lora(**args)
+
+        for key in fuse_dict.keys():
+            module = self.get_submodule(key)
+            if not isinstance(module, LoRASVDQW4A4Linear):
+                raise ValueError(f"Unsupported lora key: {key}")
+            module.add_qkv_lora(
+                name=args["name"],
+                scale=fuse_dict[key]["q"]["scale"],
+                rank=fuse_dict[key]["q"]["rank"],
+                alpha=fuse_dict[key]["q"]["alpha"],
+                q_up=fuse_dict[key]["q"]["up"],
+                q_down=fuse_dict[key]["q"]["down"],
+                k_up=fuse_dict[key]["k"]["up"],
+                k_down=fuse_dict[key]["k"]["down"],
+                v_up=fuse_dict[key]["v"]["up"],
+                v_down=fuse_dict[key]["v"]["down"],
+                device=fuse_dict[key]["q"]["device"],
+                dtype=fuse_dict[key]["q"]["dtype"],
+            )
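Note: from_state_dict builds the model on the meta device and loads the checkpoint with assign=True, so the already-quantized tensors are adopted as-is rather than copied or cast. A hedged construction sketch (obtaining the packed nunchaku state dict is outside this hunk; see the convert_dtype=False change below):

    import torch
    from diffsynth_engine.models.qwen_image.qwen_image_dit_nunchaku import QwenImageDiTNunchaku

    state_dict = ...  # packed nunchaku tensors, loaded without dtype conversion
    model = QwenImageDiTNunchaku.from_state_dict(
        state_dict,
        device="cuda:0",
        dtype=torch.bfloat16,
        use_nunchaku_awq=True,
        use_nunchaku_attn=True,
    )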
diffsynth_engine/pipelines/base.py
@@ -106,7 +106,8 @@ class BasePipeline:
         for key, param in state_dict.items():
             lora_args.append(
                 {
-                    "name": key,
+                    "name": lora_path,
+                    "key": key,
                     "scale": lora_scale,
                     "rank": param["rank"],
                     "alpha": param["alpha"],
@@ -130,7 +131,10 @@ class BasePipeline:
 
     @staticmethod
     def load_model_checkpoint(
-        checkpoint_path: str | List[str], device: str = "cpu", dtype: torch.dtype = torch.float16
+        checkpoint_path: str | List[str],
+        device: str = "cpu",
+        dtype: torch.dtype = torch.float16,
+        convert_dtype: bool = True,
     ) -> Dict[str, torch.Tensor]:
         if isinstance(checkpoint_path, str):
             checkpoint_path = [checkpoint_path]
@@ -140,8 +144,11 @@ class BasePipeline:
                 raise FileNotFoundError(f"{path} is not a file")
             elif path.endswith(".safetensors"):
                 state_dict_ = load_file(path, device=device)
-                for key, value in state_dict_.items():
-                    state_dict[key] = value.to(dtype)
+                if convert_dtype:
+                    for key, value in state_dict_.items():
+                        state_dict[key] = value.to(dtype)
+                else:
+                    state_dict.update(state_dict_)
 
             elif path.endswith(".gguf"):
                 state_dict.update(**load_gguf_checkpoint(path, device=device, dtype=dtype))
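Note: the new convert_dtype flag exists because casting packed quantized tensors (e.g. nunchaku int4/AWQ weights) to a float dtype would corrupt them, so callers can now opt out of the cast. A minimal sketch (the checkpoint filename is illustrative):

    import torch
    from diffsynth_engine.pipelines.base import BasePipeline

    # Default behavior, unchanged: every .safetensors tensor is cast to dtype.
    sd = BasePipeline.load_model_checkpoint("model.safetensors", dtype=torch.bfloat16)

    # New: keep tensors exactly as stored, e.g. for nunchaku-quantized weights.
    sd = BasePipeline.load_model_checkpoint("model.safetensors", convert_dtype=False)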