diffsynth-engine 0.6.1.dev27__tar.gz → 0.6.1.dev28__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/PKG-INFO +1 -1
  2. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/configs/pipeline.py +5 -0
  3. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/base.py +1 -1
  4. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/basic/lora.py +1 -0
  5. diffsynth_engine-0.6.1.dev28/diffsynth_engine/models/basic/lora_nunchaku.py +221 -0
  6. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/qwen_image/__init__.py +8 -0
  7. diffsynth_engine-0.6.1.dev28/diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py +341 -0
  8. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/pipelines/base.py +11 -4
  9. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/pipelines/qwen_image.py +64 -2
  10. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/flag.py +24 -0
  11. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine.egg-info/PKG-INFO +1 -1
  12. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine.egg-info/SOURCES.txt +2 -0
  13. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/.gitattributes +0 -0
  14. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/.gitignore +0 -0
  15. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/.pre-commit-config.yaml +0 -0
  16. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/LICENSE +0 -0
  17. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/MANIFEST.in +0 -0
  18. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/README.md +0 -0
  19. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/assets/dingtalk.png +0 -0
  20. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/assets/showcase.jpeg +0 -0
  21. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/assets/tongyi.svg +0 -0
  22. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/__init__.py +0 -0
  23. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/__init__.py +0 -0
  24. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  25. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  26. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  27. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  28. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  29. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  30. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  31. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  32. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  33. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  34. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  35. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  36. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  37. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  38. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  39. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  40. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  41. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  42. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  43. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  44. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  45. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  46. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  47. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  48. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  49. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  50. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/components/vae.json +0 -0
  51. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  52. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  53. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  54. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json +0 -0
  55. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_vision_config.json +0 -0
  56. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae.json +0 -0
  57. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae_keymap.json +0 -0
  58. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  59. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  60. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  61. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  62. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  63. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  64. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/wan/dit/wan2.1_flf2v_14b.json +0 -0
  65. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/wan/dit/wan2.1_i2v_14b.json +0 -0
  66. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_1.3b.json +0 -0
  67. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_14b.json +0 -0
  68. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/wan/dit/wan2.2_i2v_a14b.json +0 -0
  69. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/wan/dit/wan2.2_s2v_14b.json +0 -0
  70. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/wan/dit/wan2.2_t2v_a14b.json +0 -0
  71. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/wan/dit/wan2.2_ti2v_5b.json +0 -0
  72. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +0 -0
  73. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/wan/vae/wan2.1_vae.json +0 -0
  74. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/wan/vae/wan2.2_vae.json +0 -0
  75. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/models/wan/vae/wan_vae_keymap.json +0 -0
  76. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  77. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  78. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  79. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  80. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  81. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  82. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  83. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  84. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json +0 -0
  85. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/added_tokens.json +0 -0
  86. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/merges.txt +0 -0
  87. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/special_tokens_map.json +0 -0
  88. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer.json +0 -0
  89. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer_config.json +0 -0
  90. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/vocab.json +0 -0
  91. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  92. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  93. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  94. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  95. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  96. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  97. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  98. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  99. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  100. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  101. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  102. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  103. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/configs/__init__.py +0 -0
  104. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/configs/controlnet.py +0 -0
  105. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/kernels/__init__.py +0 -0
  106. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/__init__.py +0 -0
  107. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/basic/__init__.py +0 -0
  108. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/basic/attention.py +0 -0
  109. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  110. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/basic/timestep.py +0 -0
  111. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
  112. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  113. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/basic/video_sparse_attention.py +0 -0
  114. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/flux/__init__.py +0 -0
  115. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/flux/flux_controlnet.py +0 -0
  116. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/flux/flux_dit.py +0 -0
  117. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/flux/flux_dit_fbcache.py +0 -0
  118. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/flux/flux_ipadapter.py +0 -0
  119. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/flux/flux_redux.py +0 -0
  120. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  121. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  122. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/hunyuan3d/__init__.py +0 -0
  123. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +0 -0
  124. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/hunyuan3d/hunyuan3d_dit.py +0 -0
  125. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/hunyuan3d/hunyuan3d_vae.py +0 -0
  126. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/hunyuan3d/moe.py +0 -0
  127. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/hunyuan3d/surface_extractor.py +0 -0
  128. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/hunyuan3d/volume_decoder.py +0 -0
  129. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/qwen_image/qwen2_5_vl.py +0 -0
  130. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/qwen_image/qwen_image_dit.py +0 -0
  131. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +0 -0
  132. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/qwen_image/qwen_image_vae.py +0 -0
  133. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sd/__init__.py +0 -0
  134. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sd/sd_controlnet.py +0 -0
  135. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  136. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  137. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  138. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sd3/__init__.py +0 -0
  139. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  140. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  141. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  142. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  143. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +0 -0
  144. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  145. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
  146. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  147. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
  148. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/text_encoder/clip.py +0 -0
  149. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
  150. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/text_encoder/t5.py +0 -0
  151. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/vae/__init__.py +0 -0
  152. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/vae/vae.py +0 -0
  153. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/wan/__init__.py +0 -0
  154. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/wan/wan_audio_encoder.py +0 -0
  155. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/wan/wan_dit.py +0 -0
  156. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  157. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/wan/wan_s2v_dit.py +0 -0
  158. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  159. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  160. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/pipelines/__init__.py +0 -0
  161. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/pipelines/flux_image.py +0 -0
  162. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/pipelines/hunyuan3d_shape.py +0 -0
  163. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/pipelines/sd_image.py +0 -0
  164. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/pipelines/sdxl_image.py +0 -0
  165. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/pipelines/utils.py +0 -0
  166. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/pipelines/wan_s2v.py +0 -0
  167. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/pipelines/wan_video.py +0 -0
  168. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/processor/__init__.py +0 -0
  169. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/processor/canny_processor.py +0 -0
  170. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/processor/depth_processor.py +0 -0
  171. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tokenizers/__init__.py +0 -0
  172. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tokenizers/base.py +0 -0
  173. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tokenizers/clip.py +0 -0
  174. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tokenizers/qwen2.py +0 -0
  175. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tokenizers/qwen2_vl_image_processor.py +0 -0
  176. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tokenizers/qwen2_vl_processor.py +0 -0
  177. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tokenizers/t5.py +0 -0
  178. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tokenizers/wan.py +0 -0
  179. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tools/__init__.py +0 -0
  180. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tools/flux_inpainting_tool.py +0 -0
  181. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tools/flux_outpainting_tool.py +0 -0
  182. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tools/flux_reference_tool.py +0 -0
  183. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/tools/flux_replace_tool.py +0 -0
  184. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/__init__.py +0 -0
  185. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/cache.py +0 -0
  186. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/constants.py +0 -0
  187. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/download.py +0 -0
  188. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/env.py +0 -0
  189. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/fp8_linear.py +0 -0
  190. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/gguf.py +0 -0
  191. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/image.py +0 -0
  192. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/loader.py +0 -0
  193. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/lock.py +0 -0
  194. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/logging.py +0 -0
  195. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/memory/__init__.py +0 -0
  196. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/memory/linear_regression.py +0 -0
  197. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/memory/memory_predcit_model.py +0 -0
  198. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/offload.py +0 -0
  199. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/onnx.py +0 -0
  200. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/parallel.py +0 -0
  201. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/platform.py +0 -0
  202. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/prompt.py +0 -0
  203. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine/utils/video.py +0 -0
  204. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  205. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine.egg-info/requires.txt +0 -0
  206. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/diffsynth_engine.egg-info/top_level.txt +0 -0
  207. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/docs/tutorial.md +0 -0
  208. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/docs/tutorial_zh.md +0 -0
  209. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/pyproject.toml +0 -0
  210. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/setup.cfg +0 -0
  211. {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev28}/setup.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.6.1.dev27
+Version: 0.6.1.dev28
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent
@@ -251,6 +251,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
     # override OptimizationConfig
     fbcache_relative_l1_threshold = 0.009
 
+    # svd
+    use_nunchaku: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_awq: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_attn: Optional[bool] = field(default=None, init=False)
+
     @classmethod
     def basic_config(
         cls,
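The three new flags gate the Nunchaku-quantized DiT introduced elsewhere in this release. A minimal usage sketch, on the assumption that the flags are toggled on an existing config (they are declared with init=False, so they cannot be constructor arguments); the basic_config arguments and the exact effect of each flag are inferred from file names in this diff, not confirmed:

    config = QwenImagePipelineConfig.basic_config(...)  # constructor args elided
    config.use_nunchaku = True       # assumed to select QwenImageDiTNunchaku
    config.use_nunchaku_awq = True   # assumed: AWQ W4A16 modulation layers
    config.use_nunchaku_attn = True  # assumed: SVDQ W4A4 attention projections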
@@ -40,7 +40,7 @@ class PreTrainedModel(nn.Module):
 
     def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
         for args in lora_args:
-            key = args["name"]
+            key = args["key"]
             module = self.get_submodule(key)
             if not isinstance(module, (LoRALinear, LoRAConv2d)):
                 raise ValueError(f"Unsupported lora key: {key}")
@@ -132,6 +132,7 @@ class LoRALinear(nn.Linear):
         device: str,
         dtype: torch.dtype,
         save_original_weight: bool = True,
+        **kwargs,
     ):
         if save_original_weight and self._original_weight is None:
            if self.weight.dtype == torch.float8_e4m3fn:
@@ -0,0 +1,221 @@
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+from .lora import LoRA
+from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
+from nunchaku.lora.flux.nunchaku_converter import (
+    pack_lowrank_weight,
+    unpack_lowrank_weight,
+)
+
+
+class LoRASVDQW4A4Linear(nn.Module):
+    def __init__(
+        self,
+        origin_linear: SVDQW4A4Linear,
+    ):
+        super().__init__()
+
+        self.origin_linear = origin_linear
+        self.base_rank = self.origin_linear.rank
+        self._lora_dict = OrderedDict()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.origin_linear(x)
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.origin_linear, name)
+
+    def _apply_lora_weights(self, name: str, down: torch.Tensor, up: torch.Tensor, alpha: int, scale: float, rank: int):
+        final_scale = scale * (alpha / rank)
+
+        up_scaled = (up * final_scale).to(
+            dtype=self.origin_linear.proj_up.dtype, device=self.origin_linear.proj_up.device
+        )
+        down_final = down.to(dtype=self.origin_linear.proj_down.dtype, device=self.origin_linear.proj_down.device)
+
+        with torch.no_grad():
+            pd_packed = self.origin_linear.proj_down.data
+            pu_packed = self.origin_linear.proj_up.data
+            pd = unpack_lowrank_weight(pd_packed, down=True)
+            pu = unpack_lowrank_weight(pu_packed, down=False)
+
+            new_proj_down = torch.cat([pd, down_final], dim=0)
+            new_proj_up = torch.cat([pu, up_scaled], dim=1)
+
+            self.origin_linear.proj_down.data = pack_lowrank_weight(new_proj_down, down=True)
+            self.origin_linear.proj_up.data = pack_lowrank_weight(new_proj_up, down=False)
+
+        current_total_rank = self.origin_linear.rank
+        self.origin_linear.rank += rank
+        self._lora_dict[name] = {"rank": rank, "alpha": alpha, "scale": scale, "start_idx": current_total_rank}
+
+    def add_frozen_lora(
+        self,
+        name: str,
+        scale: float,
+        rank: int,
+        alpha: int,
+        up: torch.Tensor,
+        down: torch.Tensor,
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        if name in self._lora_dict:
+            raise ValueError(f"LoRA with name '{name}' already exists.")
+
+        self._apply_lora_weights(name, down, up, alpha, scale, rank)
+
+    def add_qkv_lora(
+        self,
+        name: str,
+        scale: float,
+        rank: int,
+        alpha: int,
+        q_up: torch.Tensor,
+        q_down: torch.Tensor,
+        k_up: torch.Tensor,
+        k_down: torch.Tensor,
+        v_up: torch.Tensor,
+        v_down: torch.Tensor,
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        if name in self._lora_dict:
+            raise ValueError(f"LoRA with name '{name}' already exists.")
+
+        fused_down = torch.cat([q_down, k_down, v_down], dim=0)
+
+        fused_rank = 3 * rank
+        out_q, out_k = q_up.shape[0], k_up.shape[0]
+        fused_up = torch.zeros((self.out_features, fused_rank), device=q_up.device, dtype=q_up.dtype)
+        fused_up[:out_q, :rank] = q_up
+        fused_up[out_q : out_q + out_k, rank : 2 * rank] = k_up
+        fused_up[out_q + out_k :, 2 * rank :] = v_up
+
+        self._apply_lora_weights(name, fused_down, fused_up, alpha, scale, rank)
+
+    def modify_scale(self, name: str, scale: float):
+        if name not in self._lora_dict:
+            raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}")
+
+        info = self._lora_dict[name]
+        old_scale = info["scale"]
+
+        if old_scale == scale:
+            return
+
+        if old_scale == 0:
+            scale_factor = 0.0
+        else:
+            scale_factor = scale / old_scale
+
+        with torch.no_grad():
+            lora_rank = info["rank"]
+            start_idx = info["start_idx"]
+            end_idx = start_idx + lora_rank
+
+            pu_packed = self.origin_linear.proj_up.data
+            pu = unpack_lowrank_weight(pu_packed, down=False)
+            pu[:, start_idx:end_idx] *= scale_factor
+
+            self.origin_linear.proj_up.data = pack_lowrank_weight(pu, down=False)
+
+        self._lora_dict[name]["scale"] = scale
+
+    def clear(self, release_all_cpu_memory: bool = False):
+        if not self._lora_dict:
+            return
+
+        with torch.no_grad():
+            pd_packed = self.origin_linear.proj_down.data
+            pu_packed = self.origin_linear.proj_up.data
+
+            pd = unpack_lowrank_weight(pd_packed, down=True)
+            pu = unpack_lowrank_weight(pu_packed, down=False)
+
+            pd_reset = pd[: self.base_rank, :].clone()
+            pu_reset = pu[:, : self.base_rank].clone()
+
+            self.origin_linear.proj_down.data = pack_lowrank_weight(pd_reset, down=True)
+            self.origin_linear.proj_up.data = pack_lowrank_weight(pu_reset, down=False)
+
+        self.origin_linear.rank = self.base_rank
+
+        self._lora_dict.clear()
+
+
+class LoRAAWQW4A16Linear(nn.Module):
+    def __init__(self, origin_linear: AWQW4A16Linear):
+        super().__init__()
+        self.origin_linear = origin_linear
+        self._lora_dict = OrderedDict()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        quantized_output = self.origin_linear(x)
+
+        for name, lora in self._lora_dict.items():
+            quantized_output += lora(x.to(lora.dtype)).to(quantized_output.dtype)
+
+        return quantized_output
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.origin_linear, name)
+
+    def add_lora(
+        self,
+        name: str,
+        scale: float,
+        rank: int,
+        alpha: int,
+        up: torch.Tensor,
+        down: torch.Tensor,
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        up_linear = nn.Linear(rank, self.out_features, bias=False, device="meta", dtype=dtype).to_empty(device=device)
+        down_linear = nn.Linear(self.in_features, rank, bias=False, device="meta", dtype=dtype).to_empty(device=device)
+
+        up_linear.weight.data = up.reshape(self.out_features, rank)
+        down_linear.weight.data = down.reshape(rank, self.in_features)
+
+        lora = LoRA(scale, rank, alpha, up_linear, down_linear, device, dtype)
+        self._lora_dict[name] = lora
+
+    def modify_scale(self, name: str, scale: float):
+        if name not in self._lora_dict:
+            raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}")
+        self._lora_dict[name].scale = scale
+
+    def add_frozen_lora(self, *args, **kwargs):
+        raise NotImplementedError("Frozen LoRA (merging weights) is not supported for AWQW4A16Linear.")
+
+    def clear(self, *args, **kwargs):
+        self._lora_dict.clear()
+
+
+def patch_nunchaku_model_for_lora(model: nn.Module):
+    def _recursive_patch(module: nn.Module):
+        for name, child_module in module.named_children():
+            replacement = None
+            if isinstance(child_module, AWQW4A16Linear):
+                replacement = LoRAAWQW4A16Linear(child_module)
+            elif isinstance(child_module, SVDQW4A4Linear):
+                replacement = LoRASVDQW4A4Linear(child_module)
+
+            if replacement:
+                setattr(module, name, replacement)
+            else:
+                _recursive_patch(child_module)
+
+    _recursive_patch(model)
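A hedged usage sketch of the patch helper above. Only patch_nunchaku_model_for_lora and the add_frozen_lora/modify_scale/clear methods come from this file; the model, submodule key, and LoRA tensors are placeholders:

    import torch
    from diffsynth_engine.models.basic.lora_nunchaku import patch_nunchaku_model_for_lora

    patch_nunchaku_model_for_lora(dit)  # wraps AWQW4A16Linear / SVDQW4A4Linear in place
    layer = dit.get_submodule("transformer_blocks.0.attn.to_out")  # hypothetical key
    # For SVDQ layers the LoRA is folded into the packed low-rank factors:
    layer.add_frozen_lora(name="style_lora", scale=1.0, rank=16, alpha=16,
                          up=up_w, down=down_w, device="cuda:0", dtype=torch.bfloat16)
    layer.modify_scale("style_lora", 0.5)  # rescales only the appended rank slice
    layer.clear()                          # truncates back to the base rank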
@@ -11,3 +11,11 @@ __all__ = [
     "Qwen2_5_VLVisionConfig",
     "Qwen2_5_VLConfig",
 ]
+
+try:
+    from .qwen_image_dit_nunchaku import QwenImageDiTNunchaku
+
+    __all__.append("QwenImageDiTNunchaku")
+
+except (ImportError, ModuleNotFoundError):
+    pass
@@ -0,0 +1,341 @@
+import torch
+import torch.nn as nn
+from typing import Any, Dict, List, Tuple, Optional
+from einops import rearrange
+
+from diffsynth_engine.models.basic import attention as attention_ops
+from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
+from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, RMSNorm
+from diffsynth_engine.models.qwen_image.qwen_image_dit import (
+    QwenFeedForward,
+    apply_rotary_emb_qwen,
+    QwenDoubleStreamAttention,
+    QwenImageTransformerBlock,
+    QwenImageDiT,
+    QwenEmbedRope,
+)
+
+from nunchaku.models.utils import fuse_linears
+from nunchaku.ops.fused import fused_gelu_mlp
+from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
+from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
+from diffsynth_engine.models.basic.lora_nunchaku import LoRASVDQW4A4Linear, LoRAAWQW4A16Linear
+
+
+class QwenDoubleStreamAttentionNunchaku(QwenDoubleStreamAttention):
+    def __init__(
+        self,
+        dim_a,
+        dim_b,
+        num_heads,
+        head_dim,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        nunchaku_rank: int = 32,
+    ):
+        super().__init__(dim_a, dim_b, num_heads, head_dim, device=device, dtype=dtype)
+
+        to_qkv = fuse_linears([self.to_q, self.to_k, self.to_v])
+        self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, rank=nunchaku_rank)
+        self.to_out = SVDQW4A4Linear.from_linear(self.to_out, rank=nunchaku_rank)
+
+        del self.to_q, self.to_k, self.to_v
+
+        add_qkv_proj = fuse_linears([self.add_q_proj, self.add_k_proj, self.add_v_proj])
+        self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, rank=nunchaku_rank)
+        self.to_add_out = SVDQW4A4Linear.from_linear(self.to_add_out, rank=nunchaku_rank)
+
+        del self.add_q_proj, self.add_k_proj, self.add_v_proj
+
+    def forward(
+        self,
+        image: torch.FloatTensor,
+        text: torch.FloatTensor,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        img_q, img_k, img_v = self.to_qkv(image).chunk(3, dim=-1)
+        txt_q, txt_k, txt_v = self.add_qkv_proj(text).chunk(3, dim=-1)
+
+        img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads)
+        img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads)
+        img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads)
+
+        txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads)
+
+        img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
+        txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
+
+        if rotary_emb is not None:
+            img_freqs, txt_freqs = rotary_emb
+            img_q = apply_rotary_emb_qwen(img_q, img_freqs)
+            img_k = apply_rotary_emb_qwen(img_k, img_freqs)
+            txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
+            txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)
+
+        joint_q = torch.cat([txt_q, img_q], dim=1)
+        joint_k = torch.cat([txt_k, img_k], dim=1)
+        joint_v = torch.cat([txt_v, img_v], dim=1)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)
+
+        joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)
+
+        txt_attn_output = joint_attn_out[:, : text.shape[1], :]
+        img_attn_output = joint_attn_out[:, text.shape[1] :, :]
+
+        img_attn_output = self.to_out(img_attn_output)
+        txt_attn_output = self.to_add_out(txt_attn_output)
+
+        return img_attn_output, txt_attn_output
+
+
+class QwenFeedForwardNunchaku(QwenFeedForward):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        dropout: float = 0.0,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        rank: int = 32,
+    ):
+        super().__init__(dim, dim_out, dropout, device=device, dtype=dtype)
+        self.net[0].proj = SVDQW4A4Linear.from_linear(self.net[0].proj, rank=rank)
+        self.net[2] = SVDQW4A4Linear.from_linear(self.net[2], rank=rank)
+        self.net[2].act_unsigned = self.net[2].precision != "nvfp4"
+
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        return fused_gelu_mlp(hidden_states, self.net[0].proj, self.net[2])
+
+
+class QwenImageTransformerBlockNunchaku(QwenImageTransformerBlock):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        eps: float = 1e-6,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        scale_shift: float = 1.0,
+        use_nunchaku_awq: bool = True,
+        use_nunchaku_attn: bool = True,
+        nunchaku_rank: int = 32,
+    ):
+        super().__init__(dim, num_attention_heads, attention_head_dim, eps, device=device, dtype=dtype)
+
+        self.use_nunchaku_awq = use_nunchaku_awq
+        if use_nunchaku_awq:
+            self.img_mod[1] = AWQW4A16Linear.from_linear(self.img_mod[1], rank=nunchaku_rank)
+
+        if use_nunchaku_attn:
+            self.attn = QwenDoubleStreamAttentionNunchaku(
+                dim_a=dim,
+                dim_b=dim,
+                num_heads=num_attention_heads,
+                head_dim=attention_head_dim,
+                device=device,
+                dtype=dtype,
+                nunchaku_rank=nunchaku_rank,
+            )
+        else:
+            self.attn = QwenDoubleStreamAttention(
+                dim_a=dim,
+                dim_b=dim,
+                num_heads=num_attention_heads,
+                head_dim=attention_head_dim,
+                device=device,
+                dtype=dtype,
+            )
+
+        self.img_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank)
+
+        if use_nunchaku_awq:
+            self.txt_mod[1] = AWQW4A16Linear.from_linear(self.txt_mod[1], rank=nunchaku_rank)
+
+        self.txt_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank)
+
+        self.scale_shift = scale_shift
+
+    def _modulate(self, x, mod_params):
+        shift, scale, gate = mod_params.chunk(3, dim=-1)
+        if self.use_nunchaku_awq:
+            if self.scale_shift != 0:
+                scale.add_(self.scale_shift)
+            return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)
+        else:
+            return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+    def forward(
+        self,
+        image: torch.Tensor,
+        text: torch.Tensor,
+        temb: torch.Tensor,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.use_nunchaku_awq:
+            img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+            txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
+
+            # nunchaku's mod_params is [B, 6*dim] instead of [B, dim*6]
+            img_mod_params = (
+                img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
+            )
+            txt_mod_params = (
+                txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
+            )
+
+            img_mod_attn, img_mod_mlp = img_mod_params.chunk(2, dim=-1)  # [B, 3*dim] each
+            txt_mod_attn, txt_mod_mlp = txt_mod_params.chunk(2, dim=-1)  # [B, 3*dim] each
+        else:
+            img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
+            txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
+
+        img_normed = self.img_norm1(image)
+        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
+
+        txt_normed = self.txt_norm1(text)
+        txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
+
+        img_attn_out, txt_attn_out = self.attn(
+            image=img_modulated,
+            text=txt_modulated,
+            rotary_emb=rotary_emb,
+            attn_mask=attn_mask,
+            attn_kwargs=attn_kwargs,
+        )
+
+        image = image + img_gate * img_attn_out
+        text = text + txt_gate * txt_attn_out
+
+        img_normed_2 = self.img_norm2(image)
+        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
+
+        txt_normed_2 = self.txt_norm2(text)
+        txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
+
+        img_mlp_out = self.img_mlp(img_modulated_2)
+        txt_mlp_out = self.txt_mlp(txt_modulated_2)
+
+        image = image + img_gate_2 * img_mlp_out
+        text = text + txt_gate_2 * txt_mlp_out
+
+        return text, image
+
+
+class QwenImageDiTNunchaku(QwenImageDiT):
+    def __init__(
+        self,
+        num_layers: int = 60,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        use_nunchaku_awq: bool = True,
+        use_nunchaku_attn: bool = True,
+        nunchaku_rank: int = 32,
+    ):
+        super().__init__()
+
+        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True, device=device)
+
+        self.time_text_embed = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
+
+        self.txt_norm = RMSNorm(3584, eps=1e-6, device=device, dtype=dtype)
+
+        self.img_in = nn.Linear(64, 3072, device=device, dtype=dtype)
+        self.txt_in = nn.Linear(3584, 3072, device=device, dtype=dtype)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                QwenImageTransformerBlockNunchaku(
+                    dim=3072,
+                    num_attention_heads=24,
+                    attention_head_dim=128,
+                    device=device,
+                    dtype=dtype,
+                    scale_shift=0,
+                    use_nunchaku_awq=use_nunchaku_awq,
+                    use_nunchaku_attn=use_nunchaku_attn,
+                    nunchaku_rank=nunchaku_rank,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
+        self.proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+        num_layers: int = 60,
+        use_nunchaku_awq: bool = True,
+        use_nunchaku_attn: bool = True,
+        nunchaku_rank: int = 32,
+    ):
+        model = cls(
+            device="meta",
+            dtype=dtype,
+            num_layers=num_layers,
+            use_nunchaku_awq=use_nunchaku_awq,
+            use_nunchaku_attn=use_nunchaku_attn,
+            nunchaku_rank=nunchaku_rank,
+        )
+        model = model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, non_blocking=True)
+        return model
+
+    def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = False):
+        fuse_dict = {}
+        for args in lora_args:
+            key = args["key"]
+            if any(suffix in key for suffix in {"add_q_proj", "add_k_proj", "add_v_proj"}):
+                fuse_key = f"{key.rsplit('.', 1)[0]}.add_qkv_proj"
+                type = key.rsplit(".", 1)[-1].split("_")[1]
+                fuse_dict[fuse_key] = fuse_dict.get(fuse_key, {})
+                fuse_dict[fuse_key][type] = args
+                continue
+
+            if any(suffix in key for suffix in {"to_q", "to_k", "to_v"}):
+                fuse_key = f"{key.rsplit('.', 1)[0]}.to_qkv"
+                type = key.rsplit(".", 1)[-1].split("_")[1]
+                fuse_dict[fuse_key] = fuse_dict.get(fuse_key, {})
+                fuse_dict[fuse_key][type] = args
+                continue
+
+            module = self.get_submodule(key)
+            if not isinstance(module, (LoRALinear, LoRAConv2d, LoRASVDQW4A4Linear, LoRAAWQW4A16Linear)):
+                raise ValueError(f"Unsupported lora key: {key}")
+
+            if fused and not isinstance(module, LoRAAWQW4A16Linear):
+                module.add_frozen_lora(**args)
+            else:
+                module.add_lora(**args)
+
+        for key in fuse_dict.keys():
+            module = self.get_submodule(key)
+            if not isinstance(module, LoRASVDQW4A4Linear):
+                raise ValueError(f"Unsupported lora key: {key}")
+            module.add_qkv_lora(
+                name=args["name"],
+                scale=fuse_dict[key]["q"]["scale"],
+                rank=fuse_dict[key]["q"]["rank"],
+                alpha=fuse_dict[key]["q"]["alpha"],
+                q_up=fuse_dict[key]["q"]["up"],
+                q_down=fuse_dict[key]["q"]["down"],
+                k_up=fuse_dict[key]["k"]["up"],
+                k_down=fuse_dict[key]["k"]["down"],
+                v_up=fuse_dict[key]["v"]["up"],
+                v_down=fuse_dict[key]["v"]["down"],
+                device=fuse_dict[key]["q"]["device"],
+                dtype=fuse_dict[key]["q"]["dtype"],
+            )
@@ -106,7 +106,8 @@ class BasePipeline:
         for key, param in state_dict.items():
             lora_args.append(
                 {
-                    "name": key,
+                    "name": lora_path,
+                    "key": key,
                     "scale": lora_scale,
                     "rank": param["rank"],
                     "alpha": param["alpha"],
@@ -130,7 +131,10 @@ class BasePipeline:
 
     @staticmethod
     def load_model_checkpoint(
-        checkpoint_path: str | List[str], device: str = "cpu", dtype: torch.dtype = torch.float16
+        checkpoint_path: str | List[str],
+        device: str = "cpu",
+        dtype: torch.dtype = torch.float16,
+        convert_dtype: bool = True,
     ) -> Dict[str, torch.Tensor]:
         if isinstance(checkpoint_path, str):
             checkpoint_path = [checkpoint_path]
@@ -140,8 +144,11 @@ class BasePipeline:
                 raise FileNotFoundError(f"{path} is not a file")
             elif path.endswith(".safetensors"):
                 state_dict_ = load_file(path, device=device)
-                for key, value in state_dict_.items():
-                    state_dict[key] = value.to(dtype)
+                if convert_dtype:
+                    for key, value in state_dict_.items():
+                        state_dict[key] = value.to(dtype)
+                else:
+                    state_dict.update(state_dict_)
 
             elif path.endswith(".gguf"):
                 state_dict.update(**load_gguf_checkpoint(path, device=device, dtype=dtype))
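The new convert_dtype flag is presumably motivated by quantized checkpoints such as the Nunchaku weights above, where a blanket .to(dtype) would corrupt packed integer tensors and their scale factors. A hedged usage sketch (the path is a placeholder):

    state_dict = BasePipeline.load_model_checkpoint(
        "qwen-image-nunchaku.safetensors",  # placeholder path
        device="cpu",
        convert_dtype=False,  # keep each tensor's stored dtype as loaded
    )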