diffsynth-engine 0.4.3.dev9__tar.gz → 0.4.3.dev10__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
Files changed (203)
  1. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/.gitignore +2 -1
  2. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/PKG-INFO +2 -2
  3. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json +2 -1
  4. diffsynth_engine-0.4.3.dev10/diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json +29 -0
  5. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/basic/attention.py +3 -3
  6. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/qwen_image/qwen2_5_vl.py +41 -57
  7. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/qwen_image/qwen_image_dit.py +45 -28
  8. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/pipelines/base.py +1 -1
  9. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/pipelines/qwen_image.py +125 -13
  10. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/pipelines/sd_image.py +3 -3
  11. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/pipelines/sdxl_image.py +10 -6
  12. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/tokenizers/__init__.py +4 -0
  13. diffsynth_engine-0.4.3.dev10/diffsynth_engine/tokenizers/qwen2_vl_image_processor.py +157 -0
  14. diffsynth_engine-0.4.3.dev10/diffsynth_engine/tokenizers/qwen2_vl_processor.py +100 -0
  15. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/constants.py +6 -0
  16. diffsynth_engine-0.4.3.dev10/diffsynth_engine/utils/image.py +238 -0
  17. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/offload.py +6 -5
  18. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine.egg-info/PKG-INFO +2 -2
  19. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine.egg-info/SOURCES.txt +3 -0
  20. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine.egg-info/requires.txt +1 -1
  21. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/pyproject.toml +1 -1
  22. diffsynth_engine-0.4.3.dev9/diffsynth_engine/utils/image.py +0 -25
  23. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/.pre-commit-config.yaml +0 -0
  24. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/LICENSE +0 -0
  25. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/MANIFEST.in +0 -0
  26. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/README.md +0 -0
  27. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/assets/dingtalk.png +0 -0
  28. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/assets/showcase.jpeg +0 -0
  29. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/__init__.py +0 -0
  30. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/__init__.py +0 -0
  31. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  32. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  33. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  34. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  35. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  36. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  37. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  38. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  39. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  40. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  41. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  42. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  43. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  44. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  45. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  46. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  47. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  48. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  49. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  50. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  51. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  52. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  53. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  54. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  55. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  56. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  57. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/components/vae.json +0 -0
  58. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  59. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  60. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  61. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_vision_config.json +0 -0
  62. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae.json +0 -0
  63. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae_keymap.json +0 -0
  64. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  65. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  66. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  67. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  68. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  69. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  70. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/wan/dit/wan2.1-flf2v-14b.json +0 -0
  71. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/wan/dit/wan2.1-i2v-14b.json +0 -0
  72. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-1.3b.json +0 -0
  73. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-14b.json +0 -0
  74. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json +0 -0
  75. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json +0 -0
  76. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json +0 -0
  77. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/wan/vae/wan-vae-keymap.json +0 -0
  78. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json +0 -0
  79. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json +0 -0
  80. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  81. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  82. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  83. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  84. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  85. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  86. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  87. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  88. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/added_tokens.json +0 -0
  89. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/merges.txt +0 -0
  90. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/special_tokens_map.json +0 -0
  91. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer.json +0 -0
  92. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer_config.json +0 -0
  93. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/vocab.json +0 -0
  94. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  95. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  96. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  97. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  98. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  99. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  100. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  101. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  102. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  103. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  104. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  105. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  106. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/configs/__init__.py +0 -0
  107. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/configs/controlnet.py +0 -0
  108. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/configs/pipeline.py +0 -0
  109. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/kernels/__init__.py +0 -0
  110. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/__init__.py +0 -0
  111. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/base.py +0 -0
  112. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/basic/__init__.py +0 -0
  113. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/basic/lora.py +0 -0
  114. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  115. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/basic/timestep.py +0 -0
  116. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
  117. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  118. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/flux/__init__.py +0 -0
  119. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/flux/flux_controlnet.py +0 -0
  120. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/flux/flux_dit.py +0 -0
  121. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/flux/flux_dit_fbcache.py +0 -0
  122. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/flux/flux_ipadapter.py +0 -0
  123. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/flux/flux_redux.py +0 -0
  124. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  125. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  126. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/hunyuan3d/__init__.py +0 -0
  127. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +0 -0
  128. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/hunyuan3d/hunyuan3d_dit.py +0 -0
  129. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/hunyuan3d/hunyuan3d_vae.py +0 -0
  130. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/hunyuan3d/moe.py +0 -0
  131. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/hunyuan3d/surface_extractor.py +0 -0
  132. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/hunyuan3d/volume_decoder.py +0 -0
  133. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/qwen_image/__init__.py +0 -0
  134. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +0 -0
  135. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/qwen_image/qwen_image_vae.py +0 -0
  136. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sd/__init__.py +0 -0
  137. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sd/sd_controlnet.py +0 -0
  138. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  139. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  140. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  141. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sd3/__init__.py +0 -0
  142. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  143. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  144. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  145. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  146. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +0 -0
  147. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  148. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
  149. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  150. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
  151. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/text_encoder/clip.py +0 -0
  152. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
  153. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/text_encoder/t5.py +0 -0
  154. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/utils.py +0 -0
  155. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/vae/__init__.py +0 -0
  156. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/vae/vae.py +0 -0
  157. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/wan/__init__.py +0 -0
  158. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/wan/wan_dit.py +0 -0
  159. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  160. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  161. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  162. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/pipelines/__init__.py +0 -0
  163. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/pipelines/flux_image.py +0 -0
  164. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/pipelines/hunyuan3d_shape.py +0 -0
  165. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/pipelines/utils.py +0 -0
  166. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/pipelines/wan_video.py +0 -0
  167. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/processor/__init__.py +0 -0
  168. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/processor/canny_processor.py +0 -0
  169. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/processor/depth_processor.py +0 -0
  170. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/tokenizers/base.py +0 -0
  171. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/tokenizers/clip.py +0 -0
  172. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/tokenizers/qwen2.py +0 -0
  173. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/tokenizers/t5.py +0 -0
  174. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/tokenizers/wan.py +0 -0
  175. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/tools/__init__.py +0 -0
  176. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/tools/flux_inpainting_tool.py +0 -0
  177. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/tools/flux_outpainting_tool.py +0 -0
  178. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/tools/flux_reference_tool.py +0 -0
  179. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/tools/flux_replace_tool.py +0 -0
  180. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/__init__.py +0 -0
  181. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/cache.py +0 -0
  182. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/download.py +0 -0
  183. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/env.py +0 -0
  184. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/flag.py +0 -0
  185. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/fp8_linear.py +0 -0
  186. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/gguf.py +0 -0
  187. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/loader.py +0 -0
  188. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/lock.py +0 -0
  189. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/logging.py +0 -0
  190. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/memory/__init__.py +0 -0
  191. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/memory/linear_regression.py +0 -0
  192. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/memory/memory_predcit_model.py +0 -0
  193. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/onnx.py +0 -0
  194. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/parallel.py +0 -0
  195. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/platform.py +0 -0
  196. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/prompt.py +0 -0
  197. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine/utils/video.py +0 -0
  198. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  199. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/diffsynth_engine.egg-info/top_level.txt +0 -0
  200. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/docs/tutorial.md +0 -0
  201. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/docs/tutorial_zh.md +0 -0
  202. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/setup.cfg +0 -0
  203. {diffsynth_engine-0.4.3.dev9 → diffsynth_engine-0.4.3.dev10}/setup.py +0 -0
@@ -8,4 +8,5 @@ dist/
8
8
  *.egg-info/
9
9
  .DS_Store/
10
10
  .pytest_cache/
11
- .ruff_cache/
11
+ .ruff_cache/
12
+ CLAUDE.md
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.4.3.dev9
3
+ Version: 0.4.3.dev10
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -29,7 +29,7 @@ Requires-Dist: scikit-image
29
29
  Requires-Dist: trimesh
30
30
  Provides-Extra: dev
31
31
  Requires-Dist: diffusers==0.31.0; extra == "dev"
32
- Requires-Dist: transformers==4.45.2; extra == "dev"
32
+ Requires-Dist: transformers==4.52.4; extra == "dev"
33
33
  Requires-Dist: accelerate; extra == "dev"
34
34
  Requires-Dist: build; extra == "dev"
35
35
  Requires-Dist: ruff; extra == "dev"
@@ -21,5 +21,6 @@
21
21
  "vision_start_token_id": 151652,
22
22
  "vision_end_token_id": 151653,
23
23
  "image_token_id": 151655,
24
- "video_token_id": 151656
24
+ "video_token_id": 151656,
25
+ "attn_impl": "sdpa"
25
26
  }
@@ -0,0 +1,29 @@
1
+ {
2
+ "do_convert_rgb": true,
3
+ "do_normalize": true,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.48145466,
8
+ 0.4578275,
9
+ 0.40821073
10
+ ],
11
+ "image_processor_type": "Qwen2VLImageProcessor",
12
+ "image_std": [
13
+ 0.26862954,
14
+ 0.26130258,
15
+ 0.27577711
16
+ ],
17
+ "max_pixels": 12845056,
18
+ "merge_size": 2,
19
+ "min_pixels": 3136,
20
+ "patch_size": 14,
21
+ "processor_class": "Qwen2_5_VLProcessor",
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "longest_edge": 12845056,
26
+ "shortest_edge": 3136
27
+ },
28
+ "temporal_patch_size": 2
29
+ }
@@ -1,9 +1,9 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
3
4
  from einops import rearrange, repeat
4
5
  from typing import Optional
5
6
 
6
- import torch.nn.functional as F
7
7
  from diffsynth_engine.utils import logging
8
8
  from diffsynth_engine.utils.flag import (
9
9
  FLASH_ATTN_3_AVAILABLE,
@@ -42,11 +42,11 @@ if XFORMERS_AVAILABLE:
42
42
 
43
43
  if SDPA_AVAILABLE:
44
44
 
45
- def sdpa_attn(q, k, v, attn_mask=None, scale=None):
45
+ def sdpa_attn(q, k, v, attn_mask=None, is_causal=False, scale=None):
46
46
  q = q.transpose(1, 2)
47
47
  k = k.transpose(1, 2)
48
48
  v = v.transpose(1, 2)
49
- out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
49
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal, scale=scale)
50
50
  return out.transpose(1, 2)
51
51
 
52
52
 
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Tuple, Optional
7
7
 
8
8
  from diffsynth_engine.models.base import PreTrainedModel
9
9
  from diffsynth_engine.models.basic.transformer_helper import RMSNorm
10
- from diffsynth_engine.models.basic.attention import attention
10
+ from diffsynth_engine.models.basic import attention as attention_ops
11
11
  from diffsynth_engine.models.utils import no_init_weights
12
12
  from diffsynth_engine.utils.cache import Cache, DynamicCache
13
13
  from diffsynth_engine.utils import logging
@@ -152,17 +152,15 @@ class Qwen2_5_VisionRotaryEmbedding(nn.Module):
152
152
  self,
153
153
  dim: int = 80,
154
154
  theta: float = 10000.0,
155
- device: str = "cuda:0",
156
- dtype: torch.dtype = torch.bfloat16,
157
155
  ):
158
156
  super().__init__()
159
- with torch.device(device):
160
- inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
161
- self.register_buffer("inv_freq", inv_freq, persistent=False)
157
+ with torch.device("cpu"):
158
+ self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
162
159
 
163
- def forward(self, seqlen: int) -> torch.Tensor:
164
- seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
165
- freqs = torch.outer(seq, self.inv_freq)
160
+ def forward(self, seqlen: int, device: str) -> torch.Tensor:
161
+ inv_freq = self.inv_freq.to(device=device)
162
+ seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
163
+ freqs = torch.outer(seq, inv_freq)
166
164
  return freqs
167
165
 
168
166
 
@@ -222,7 +220,7 @@ class Qwen2_5_VisionAttention(nn.Module):
222
220
  q = rearrange(q, "s n d -> 1 s n d")
223
221
  k = rearrange(k, "s n d -> 1 s n d")
224
222
  v = rearrange(v, "s n d -> 1 s n d")
225
- out = attention(q, k, v, attn_impl=self.attn_impl, attn_mask=attention_mask)
223
+ out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl, attn_mask=attention_mask)
226
224
  out = rearrange(out, "1 s n d -> s (n d)")
227
225
  out = self.proj(out)
228
226
  return out
@@ -301,7 +299,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
301
299
  dtype=dtype,
302
300
  )
303
301
  head_dim = config.hidden_size // config.num_heads
304
- self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2, device=device, dtype=dtype)
302
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
305
303
  self.blocks = nn.ModuleList(
306
304
  [
307
305
  Qwen2_5_VisionBlock(
@@ -348,7 +346,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
348
346
  pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
349
347
  pos_ids = torch.cat(pos_ids, dim=0)
350
348
  max_grid_size = grid_thw[:, 1:].max()
351
- rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
349
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device=grid_thw.device)
352
350
  rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
353
351
  return rotary_pos_emb
354
352
 
@@ -488,7 +486,6 @@ class Qwen2_5_Attention(nn.Module):
488
486
  hidden_size: int = 3584,
489
487
  num_attention_heads: int = 28,
490
488
  num_key_value_heads: int = 4,
491
- # dropout: float = 0.0,
492
489
  mrope_section: List[int] = [16, 24, 24],
493
490
  attn_impl: Optional[str] = None,
494
491
  device: str = "cuda:0",
@@ -501,7 +498,6 @@ class Qwen2_5_Attention(nn.Module):
501
498
  self.head_dim = hidden_size // num_attention_heads
502
499
  self.num_key_value_heads = num_key_value_heads
503
500
  self.num_key_value_groups = num_attention_heads // num_key_value_heads
504
- # self.dropout = dropout
505
501
  self.mrope_section = mrope_section
506
502
  self.attn_impl = attn_impl
507
503
 
@@ -521,8 +517,6 @@ class Qwen2_5_Attention(nn.Module):
521
517
  self.num_attention_heads * self.head_dim, self.hidden_size, bias=False, device=device, dtype=dtype
522
518
  )
523
519
 
524
- self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=self.head_dim, device=device, dtype=dtype)
525
-
526
520
  def forward(
527
521
  self,
528
522
  hidden_states: torch.Tensor,
@@ -556,14 +550,18 @@ class Qwen2_5_Attention(nn.Module):
556
550
  if attention_mask is not None: # no matter the length, we just slice it
557
551
  causal_mask = attention_mask[:, :, :, : key_states.shape[1]]
558
552
 
559
- # TODO: attention_mask for flash attention 2
560
- out = attention(
561
- query_states,
562
- key_states,
563
- value_states,
564
- attn_impl=self.attn_impl,
565
- attn_mask=causal_mask,
566
- )
553
+ # TODO: use is_causal when attention mask is causal
554
+ if self.attn_impl == "sdpa":
555
+ out = attention_ops.sdpa_attn(query_states, key_states, value_states, is_causal=True)
556
+ else:
557
+ # TODO: attention_mask for flash attention 2
558
+ out = attention_ops.attention(
559
+ query_states,
560
+ key_states,
561
+ value_states,
562
+ attn_impl=self.attn_impl,
563
+ attn_mask=causal_mask,
564
+ )
567
565
  out = rearrange(out, "b s n d -> b s (n d)")
568
566
  out = self.o_proj(out)
569
567
  return out, past_key_values
@@ -647,29 +645,29 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
647
645
 
648
646
 
649
647
  class Qwen2_5_VLRotaryEmbedding(nn.Module):
650
- def __init__(self, dim: int = 128, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
648
+ def __init__(self, dim: int = 128):
651
649
  super().__init__()
652
- with torch.device(device):
653
- inv_freq = self.compute_rope(dim) # default rope without dynamic frequency
654
- self.register_buffer("inv_freq", inv_freq, persistent=False)
650
+ with torch.device("cpu"):
651
+ self.inv_freq = self.compute_rope(dim) # default rope without dynamic frequency
655
652
 
656
653
  def compute_rope(self, dim: int, theta: float = 1000000.0):
657
654
  inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
658
655
  return inv_freq
659
656
 
660
657
  @torch.no_grad()
661
- def forward(self, x, position_ids):
658
+ def forward(self, position_ids: torch.LongTensor, device: str, dtype: torch.dtype):
662
659
  # In contrast to other models, Qwen2_5_VL has different position ids for the grids
663
660
  # So we expand the inv_freq to shape (3, ...)
664
- inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
661
+ inv_freq = self.inv_freq.to(device=device)
662
+ inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
665
663
  position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
666
664
 
667
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
665
+ freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)
668
666
  emb = torch.cat((freqs, freqs), dim=-1)
669
667
  cos = emb.cos()
670
668
  sin = emb.sin()
671
669
 
672
- return cos.to(device=x.device, dtype=x.dtype), sin.to(device=x.device, dtype=x.dtype)
670
+ return cos.to(device=device, dtype=dtype), sin.to(device=device, dtype=dtype)
673
671
 
674
672
 
675
673
  class Qwen2_5_VLModel(nn.Module):
@@ -702,7 +700,7 @@ class Qwen2_5_VLModel(nn.Module):
702
700
  )
703
701
  self.norm = Qwen2_5_RMSNorm(config.hidden_size, config.rms_norm_eps, device=device, dtype=dtype)
704
702
  head_dim = config.hidden_size // config.num_attention_heads
705
- self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=head_dim, device=device, dtype=dtype)
703
+ self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=head_dim)
706
704
 
707
705
  def get_input_embeddings(self):
708
706
  return self.embed_tokens
@@ -749,7 +747,7 @@ class Qwen2_5_VLModel(nn.Module):
749
747
  hidden_states = inputs_embeds
750
748
 
751
749
  # create position embeddings to be shared across the decoder layers
752
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
750
+ position_embeddings = self.rotary_emb(position_ids, device=hidden_states.device, dtype=hidden_states.dtype)
753
751
 
754
752
  # decoder layers
755
753
  for decoder_layer in self.layers:
@@ -940,8 +938,7 @@ class Qwen2_5_VLForConditionalGeneration(PreTrainedModel):
940
938
  with torch.device("meta"), no_init_weights():
941
939
  model = cls(vision_config=vision_config, config=config, device=device, dtype=dtype)
942
940
  model.load_state_dict(state_dict, assign=True)
943
- for param in model.parameters(): # skip buffers
944
- param.data = param.data.to(device=device, dtype=dtype, non_blocking=True)
941
+ model.to(device=device, dtype=dtype, non_blocking=True)
945
942
  return model
946
943
 
947
944
  def get_input_embeddings(self):
@@ -1202,27 +1199,14 @@ class Qwen2_5_VLForConditionalGeneration(PreTrainedModel):
1202
1199
  if position_ids is None:
1203
1200
  assert attention_mask is None or attention_mask.ndim == 2, "attention mask must be 2D"
1204
1201
  # calculate RoPE index once per generation in the pre-fill stage only
1205
- if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
1206
- position_ids, rope_deltas = self.get_rope_index(
1207
- input_ids,
1208
- image_grid_thw,
1209
- video_grid_thw,
1210
- second_per_grid_ts,
1211
- attention_mask,
1212
- )
1213
- self.rope_deltas = rope_deltas
1214
- # then use the prev pre-calculated rope-deltas to get the correct position ids
1215
- else:
1216
- batch_size, seq_length, _ = inputs_embeds.shape
1217
- delta = (
1218
- (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
1219
- )
1220
- position_ids = torch.arange(seq_length, device=inputs_embeds.device)
1221
- position_ids = position_ids.view(1, -1).expand(batch_size, -1)
1222
- if cache_position is not None: # otherwise `deltas` is an int `0`
1223
- delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
1224
- position_ids = position_ids.add(delta)
1225
- position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
1202
+ position_ids, rope_deltas = self.get_rope_index(
1203
+ input_ids,
1204
+ image_grid_thw,
1205
+ video_grid_thw,
1206
+ second_per_grid_ts,
1207
+ attention_mask,
1208
+ )
1209
+ self.rope_deltas = rope_deltas
1226
1210
 
1227
1211
  hidden_states, present_key_values = self.model(
1228
1212
  input_ids=None,
@@ -81,41 +81,47 @@ class QwenEmbedRope(nn.Module):
81
81
 
82
82
  def forward(self, video_fhw, txt_length, device):
83
83
  """
84
- Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video
85
- Args: txt_length: an integer representing the length of text
84
+ Args:
85
+ video_fhw (List[Tuple[int, int, int]]): A list of (frame, height, width) tuples for each video/image
86
+ txt_length (int): The maximum length of the text sequences
86
87
  """
87
88
  if self.pos_freqs.device != device:
88
89
  self.pos_freqs = self.pos_freqs.to(device)
89
90
  self.neg_freqs = self.neg_freqs.to(device)
90
91
 
91
- frame, height, width = video_fhw
92
- rope_key = f"{frame}_{height}_{width}"
93
-
94
- if rope_key not in self.rope_cache:
95
- seq_lens = frame * height * width
96
- freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
97
- freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
98
- freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
92
+ vid_freqs = []
93
+ max_vid_index = 0
94
+ for idx, fhw in enumerate(video_fhw):
95
+ frame, height, width = fhw
96
+ rope_key = f"{idx}_{height}_{width}"
97
+
98
+ if rope_key not in self.rope_cache:
99
+ seq_lens = frame * height * width
100
+ freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
101
+ freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
102
+ freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
103
+ if self.scale_rope:
104
+ freqs_height = torch.cat(
105
+ [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
106
+ )
107
+ freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
108
+ freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
109
+ freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
110
+
111
+ else:
112
+ freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
113
+ freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
114
+
115
+ freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
116
+ self.rope_cache[rope_key] = freqs.clone().contiguous()
117
+ vid_freqs.append(self.rope_cache[rope_key])
99
118
  if self.scale_rope:
100
- freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
101
- freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
102
- freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
103
- freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
104
-
119
+ max_vid_index = max(height // 2, width // 2, max_vid_index)
105
120
  else:
106
- freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
107
- freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
108
-
109
- freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
110
- self.rope_cache[rope_key] = freqs.clone().contiguous()
111
- vid_freqs = self.rope_cache[rope_key]
112
-
113
- if self.scale_rope:
114
- max_vid_index = max(height // 2, width // 2)
115
- else:
116
- max_vid_index = max(height, width)
121
+ max_vid_index = max(height, width, max_vid_index)
117
122
 
118
123
  txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + txt_length, ...]
124
+ vid_freqs = torch.cat(vid_freqs, dim=0)
119
125
 
120
126
  return vid_freqs, txt_freqs
121
127
 
@@ -364,6 +370,7 @@ class QwenImageDiT(PreTrainedModel):
364
370
  def forward(
365
371
  self,
366
372
  image: torch.Tensor,
373
+ edit: torch.Tensor = None,
367
374
  text: torch.Tensor = None,
368
375
  timestep: torch.LongTensor = None,
369
376
  txt_seq_lens: torch.LongTensor = None,
@@ -377,6 +384,7 @@ class QwenImageDiT(PreTrainedModel):
377
384
  cfg_parallel(
378
385
  (
379
386
  image,
387
+ edit,
380
388
  text,
381
389
  timestep,
382
390
  txt_seq_lens,
@@ -385,11 +393,18 @@ class QwenImageDiT(PreTrainedModel):
385
393
  ),
386
394
  ):
387
395
  conditioning = self.time_text_embed(timestep, image.dtype)
388
- video_fhw = (1, h // 2, w // 2) # frame, height, width
396
+ video_fhw = [(1, h // 2, w // 2)] # frame, height, width
389
397
  max_length = txt_seq_lens.max().item()
398
+ image = self.patchify(image)
399
+ image_seq_len = image.shape[1]
400
+ if edit is not None:
401
+ edit = edit.to(dtype=image.dtype)
402
+ edit = self.patchify(edit)
403
+ image = torch.cat([image, edit], dim=1)
404
+ video_fhw += video_fhw
405
+
390
406
  image_rotary_emb = self.pos_embed(video_fhw, max_length, image.device)
391
407
 
392
- image = self.patchify(image)
393
408
  image = self.img_in(image)
394
409
  text = self.txt_in(self.txt_norm(text[:, :max_length]))
395
410
 
@@ -397,6 +412,8 @@ class QwenImageDiT(PreTrainedModel):
397
412
  text, image = block(image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb)
398
413
  image = self.norm_out(image, conditioning)
399
414
  image = self.proj_out(image)
415
+ if edit is not None:
416
+ image = image[:, :image_seq_len]
400
417
 
401
418
  image = self.unpatchify(image, h, w)
402
419
 
@@ -164,7 +164,7 @@ class BasePipeline:
164
164
  @staticmethod
165
165
  def generate_noise(shape, seed=None, device="cpu", dtype=torch.float16):
166
166
  generator = None if seed is None else torch.Generator(device).manual_seed(seed)
167
- noise = torch.randn(shape, generator=generator, device=device).to(dtype)
167
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
168
168
  return noise
169
169
 
170
170
  def encode_image(