diffsynth-engine 0.6.1.dev21__tar.gz → 0.6.1.dev23__tar.gz
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/PKG-INFO +1 -1
- diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/configs/pipeline.py +35 -5
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/attention.py +59 -20
- diffsynth_engine-0.6.1.dev23/diffsynth_engine/models/basic/video_sparse_attention.py +235 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_controlnet.py +7 -19
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_dit.py +22 -36
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/qwen_image/qwen_image_dit.py +13 -15
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/wan_dit.py +62 -22
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/flux_image.py +11 -10
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/qwen_image.py +26 -28
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/wan_s2v.py +3 -8
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/wan_video.py +11 -13
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/constants.py +13 -12
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/flag.py +6 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/parallel.py +51 -6
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine.egg-info/PKG-INFO +1 -1
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine.egg-info/SOURCES.txt +13 -11
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/.gitattributes +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/.gitignore +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/.pre-commit-config.yaml +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/MANIFEST.in +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/README.md +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/assets/dingtalk.png +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/assets/showcase.jpeg +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/assets/tongyi.svg +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/components/vae.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_vision_config.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae_keymap.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
- /diffsynth_engine-0.6.1.dev21/diffsynth_engine/conf/models/wan/dit/wan2.1-flf2v-14b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.1_flf2v_14b.json +0 -0
- /diffsynth_engine-0.6.1.dev21/diffsynth_engine/conf/models/wan/dit/wan2.1-i2v-14b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.1_i2v_14b.json +0 -0
- /diffsynth_engine-0.6.1.dev21/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-1.3b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_1.3b.json +0 -0
- /diffsynth_engine-0.6.1.dev21/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-14b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_14b.json +0 -0
- /diffsynth_engine-0.6.1.dev21/diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.2_i2v_a14b.json +0 -0
- /diffsynth_engine-0.6.1.dev21/diffsynth_engine/conf/models/wan/dit/wan2.2-s2v-14b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.2_s2v_14b.json +0 -0
- /diffsynth_engine-0.6.1.dev21/diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.2_t2v_a14b.json +0 -0
- /diffsynth_engine-0.6.1.dev21/diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan2.2_ti2v_5b.json +0 -0
- /diffsynth_engine-0.6.1.dev21/diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/vae/wan2.1_vae.json +0 -0
- /diffsynth_engine-0.6.1.dev21/diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/vae/wan2.2_vae.json +0 -0
- /diffsynth_engine-0.6.1.dev21/diffsynth_engine/conf/models/wan/vae/wan-vae-keymap.json → /diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/vae/wan_vae_keymap.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/added_tokens.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/merges.txt +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/special_tokens_map.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer_config.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/vocab.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/configs/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/configs/controlnet.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/kernels/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/base.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/lora.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/timestep.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/unet_helper.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_redux.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/hunyuan3d_dit.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/hunyuan3d_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/moe.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/surface_extractor.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/hunyuan3d/volume_decoder.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/qwen_image/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/qwen_image/qwen2_5_vl.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/qwen_image/qwen_image_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd/sd_controlnet.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd/sd_unet.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd/sd_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd3/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sdxl/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/text_encoder/clip.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/text_encoder/t5.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/vae/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/vae/vae.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/wan_audio_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/wan_s2v_dit.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/wan/wan_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/base.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/hunyuan3d_shape.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/sd_image.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/sdxl_image.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/pipelines/utils.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/processor/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/processor/canny_processor.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/processor/depth_processor.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/base.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/clip.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/qwen2.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/qwen2_vl_image_processor.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/qwen2_vl_processor.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/t5.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tokenizers/wan.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tools/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tools/flux_inpainting_tool.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tools/flux_outpainting_tool.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tools/flux_reference_tool.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/tools/flux_replace_tool.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/cache.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/download.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/env.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/fp8_linear.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/gguf.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/image.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/loader.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/lock.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/logging.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/memory/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/memory/linear_regression.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/memory/memory_predcit_model.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/offload.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/onnx.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/platform.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/prompt.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/utils/video.py +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine.egg-info/requires.txt +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine.egg-info/top_level.txt +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/docs/tutorial.md +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/docs/tutorial_zh.md +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/pyproject.toml +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/setup.cfg +0 -0
- {diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/setup.py +0 -0
diffsynth_engine-0.6.1.dev23/diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json
ADDED
@@ -0,0 +1,41 @@
+{
+    "diffusers": {
+        "global_rename_dict": {
+            "patch_embedding": "patch_embedding",
+            "condition_embedder.text_embedder.linear_1": "text_embedding.0",
+            "condition_embedder.text_embedder.linear_2": "text_embedding.2",
+            "condition_embedder.time_embedder.linear_1": "time_embedding.0",
+            "condition_embedder.time_embedder.linear_2": "time_embedding.2",
+            "condition_embedder.time_proj": "time_projection.1",
+            "condition_embedder.image_embedder.norm1": "img_emb.proj.0",
+            "condition_embedder.image_embedder.ff.net.0.proj": "img_emb.proj.1",
+            "condition_embedder.image_embedder.ff.net.2": "img_emb.proj.3",
+            "condition_embedder.image_embedder.norm2": "img_emb.proj.4",
+            "condition_embedder.image_embedder.pos_embed": "img_emb.emb_pos",
+            "proj_out": "head.head",
+            "scale_shift_table": "head.modulation"
+        },
+        "rename_dict": {
+            "attn1.to_q": "self_attn.q",
+            "attn1.to_k": "self_attn.k",
+            "attn1.to_v": "self_attn.v",
+            "attn1.to_out.0": "self_attn.o",
+            "attn1.norm_q": "self_attn.norm_q",
+            "attn1.norm_k": "self_attn.norm_k",
+            "to_gate_compress": "self_attn.gate_compress",
+            "attn2.to_q": "cross_attn.q",
+            "attn2.to_k": "cross_attn.k",
+            "attn2.to_v": "cross_attn.v",
+            "attn2.to_out.0": "cross_attn.o",
+            "attn2.norm_q": "cross_attn.norm_q",
+            "attn2.norm_k": "cross_attn.norm_k",
+            "attn2.add_k_proj": "cross_attn.k_img",
+            "attn2.add_v_proj": "cross_attn.v_img",
+            "attn2.norm_added_k": "cross_attn.norm_k_img",
+            "norm2": "norm3",
+            "ffn.net.0.proj": "ffn.0",
+            "ffn.net.2": "ffn.2",
+            "scale_shift_table": "modulation"
+        }
+    }
+}
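The new keymap translates diffusers-style Wan DiT checkpoint keys into this engine's native module names. A minimal sketch of how such a map could be applied to a state dict (the engine's actual loader is not part of this diff; `convert_state_dict` is a hypothetical helper):

```python
import json

def convert_state_dict(state_dict: dict, keymap_path: str) -> dict:
    """Hypothetical helper: rename diffusers-style Wan DiT keys to engine names."""
    with open(keymap_path) as f:
        keymap = json.load(f)["diffusers"]
    converted = {}
    for key, tensor in state_dict.items():
        new_key = key
        # Top-level modules map via exact prefixes (e.g. "proj_out.weight" -> "head.head.weight").
        for src, dst in keymap["global_rename_dict"].items():
            if new_key == src or new_key.startswith(src + "."):
                new_key = dst + new_key[len(src):]
        # Per-block submodules (e.g. "blocks.3.attn1.to_q" -> "blocks.3.self_attn.q").
        for src, dst in keymap["rename_dict"].items():
            new_key = new_key.replace(src, dst)
        converted[new_key] = tensor
    return converted
```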
{diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/configs/pipeline.py
RENAMED
@@ -5,6 +5,7 @@ from dataclasses import dataclass, field
 from typing import List, Dict, Tuple, Optional
 
 from diffsynth_engine.configs.controlnet import ControlType
+from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
 
 
 @dataclass
@@ -30,16 +31,43 @@ class AttnImpl(Enum):
     SDPA = "sdpa"  # Scaled Dot Product Attention
     SAGE = "sage"  # Sage Attention
     SPARGE = "sparge"  # Sparge Attention
+    VSA = "vsa"  # Video Sparse Attention
+
+
+@dataclass
+class SpargeAttentionParams:
+    smooth_k: bool = True
+    cdfthreshd: float = 0.6
+    simthreshd1: float = 0.98
+    pvthreshd: float = 50.0
+
+
+@dataclass
+class VideoSparseAttentionParams:
+    sparsity: float = 0.9
 
 
 @dataclass
 class AttentionConfig:
     dit_attn_impl: AttnImpl = AttnImpl.AUTO
-
-
-
-
-
+    attn_params: Optional[SpargeAttentionParams | VideoSparseAttentionParams] = None
+
+    def get_attn_kwargs(self, latents: torch.Tensor, device: str) -> Dict:
+        attn_kwargs = {"attn_impl": self.dit_attn_impl.value}
+        if isinstance(self.attn_params, SpargeAttentionParams):
+            assert self.dit_attn_impl == AttnImpl.SPARGE
+            attn_kwargs.update(
+                {
+                    "smooth_k": self.attn_params.smooth_k,
+                    "simthreshd1": self.attn_params.simthreshd1,
+                    "cdfthreshd": self.attn_params.cdfthreshd,
+                    "pvthreshd": self.attn_params.pvthreshd,
+                }
+            )
+        elif isinstance(self.attn_params, VideoSparseAttentionParams):
+            assert self.dit_attn_impl == AttnImpl.VSA
+            attn_kwargs.update(get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.attn_params.sparsity, device=device))
+        return attn_kwargs
 
 
 @dataclass
@@ -242,6 +270,8 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig
     vae_tile_size: Tuple[int, int] = (34, 34)
     vae_tile_stride: Tuple[int, int] = (18, 16)
 
+    load_encoder: bool = True
+
     @classmethod
     def basic_config(
         cls,
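Taken together, the new dataclasses give pipelines a single entry point for backend-specific attention arguments. A sketch of the resulting usage, with assumed illustrative latent shapes; `latents.shape[2:]` is the (frames, height, width) grid handed to `get_vsa_kwargs` with patch size (1, 2, 2):

```python
import torch
from diffsynth_engine.configs.pipeline import AttnImpl, AttentionConfig, VideoSparseAttentionParams

config = AttentionConfig(
    dit_attn_impl=AttnImpl.VSA,
    attn_params=VideoSparseAttentionParams(sparsity=0.9),
)

# Hypothetical latents shaped [batch, channels, frames, height, width];
# (21, 60, 104) patchified by (1, 2, 2) gives a (21, 30, 52) DiT token grid.
latents = torch.randn(1, 16, 21, 60, 104)
attn_kwargs = config.get_attn_kwargs(latents, device="cuda")
# attn_kwargs now carries "attn_impl": "vsa" plus the precomputed tile indices.
```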
{diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/basic/attention.py
RENAMED
@@ -12,6 +12,7 @@ from diffsynth_engine.utils.flag import (
     SDPA_AVAILABLE,
     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
+    VIDEO_SPARSE_ATTN_AVAILABLE,
 )
 from diffsynth_engine.utils.platform import DTYPE_FP8
 
@@ -20,12 +21,6 @@ FA3_MAX_HEADDIM = 256
 logger = logging.get_logger(__name__)
 
 
-def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
-    padding_size = (alignment - x.shape[dim] % alignment) % alignment
-    padded_x = F.pad(x, (0, padding_size), "constant", 0)
-    return padded_x[..., : x.shape[dim]]
-
-
 if FLASH_ATTN_3_AVAILABLE:
     from flash_attn_interface import flash_attn_func as flash_attn3
 if FLASH_ATTN_2_AVAILABLE:
@@ -33,6 +28,11 @@ if FLASH_ATTN_2_AVAILABLE:
 if XFORMERS_AVAILABLE:
     from xformers.ops import memory_efficient_attention
 
+    def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
+        padding_size = (alignment - x.shape[dim] % alignment) % alignment
+        padded_x = F.pad(x, (0, padding_size), "constant", 0)
+        return padded_x[..., : x.shape[dim]]
+
     def xformers_attn(q, k, v, attn_mask=None, scale=None):
         if attn_mask is not None:
             if attn_mask.ndim == 2:
@@ -94,6 +94,13 @@ if SPARGE_ATTN_AVAILABLE:
         return out.transpose(1, 2)
 
 
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    from diffsynth_engine.models.basic.video_sparse_attention import (
+        video_sparse_attn,
+        distributed_video_sparse_attn,
+    )
+
+
 def eager_attn(q, k, v, attn_mask=None, scale=None):
     q = q.transpose(1, 2)
     k = k.transpose(1, 2)
@@ -109,9 +116,10 @@ def eager_attn(q, k, v, attn_mask=None, scale=None):
 
 
 def attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = "auto",
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
@@ -133,6 +141,7 @@ def attention(
         "sdpa",
         "sage",
         "sparge",
+        "vsa",
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
@@ -189,10 +198,24 @@ def attention(
             v,
             attn_mask=attn_mask,
             scale=scale,
-            smooth_k=kwargs.get("
-            simthreshd1=kwargs.get("
-            cdfthreshd=kwargs.get("
-            pvthreshd=kwargs.get("
+            smooth_k=kwargs.get("smooth_k", True),
+            simthreshd1=kwargs.get("simthreshd1", 0.6),
+            cdfthreshd=kwargs.get("cdfthreshd", 0.98),
+            pvthreshd=kwargs.get("pvthreshd", 50),
+        )
+    if attn_impl == "vsa":
+        return video_sparse_attn(
+            q,
+            k,
+            v,
+            g,
+            sparsity=kwargs.get("sparsity"),
+            num_tiles=kwargs.get("num_tiles"),
+            total_seq_length=kwargs.get("total_seq_length"),
+            tile_partition_indices=kwargs.get("tile_partition_indices"),
+            reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+            variable_block_sizes=kwargs.get("variable_block_sizes"),
+            non_pad_index=kwargs.get("non_pad_index"),
         )
     raise ValueError(f"Invalid attention implementation: {attn_impl}")
@@ -242,9 +265,10 @@ class Attention(nn.Module):
 
 
 def long_context_attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = None,
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
@@ -267,6 +291,7 @@ def long_context_attention(
         "sdpa",
         "sage",
         "sparge",
+        "vsa",
     ]
     assert attn_mask is None, "long context attention does not support attention mask"
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
@@ -307,11 +332,25 @@ def long_context_attention(
     if attn_impl == "sparge":
         attn_processor = SparseAttentionMeansim()
         # default args from spas_sage2_attn_meansim_cuda
-        attn_processor.smooth_k = torch.tensor(kwargs.get("
-        attn_processor.simthreshd1 = torch.tensor(kwargs.get("
-        attn_processor.cdfthreshd = torch.tensor(kwargs.get("
-        attn_processor.pvthreshd = torch.tensor(kwargs.get("
+        attn_processor.smooth_k = torch.tensor(kwargs.get("smooth_k", True))
+        attn_processor.simthreshd1 = torch.tensor(kwargs.get("simthreshd1", 0.6))
+        attn_processor.cdfthreshd = torch.tensor(kwargs.get("cdfthreshd", 0.98))
+        attn_processor.pvthreshd = torch.tensor(kwargs.get("pvthreshd", 50))
         return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
             q, k, v, softmax_scale=scale
         )
+    if attn_impl == "vsa":
+        return distributed_video_sparse_attn(
+            q,
+            k,
+            v,
+            g,
+            sparsity=kwargs.get("sparsity"),
+            num_tiles=kwargs.get("num_tiles"),
+            total_seq_length=kwargs.get("total_seq_length"),
+            tile_partition_indices=kwargs.get("tile_partition_indices"),
+            reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+            variable_block_sizes=kwargs.get("variable_block_sizes"),
+            non_pad_index=kwargs.get("non_pad_index"),
+        )
     raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
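With the new `g` gate tensor and the `vsa` branch, a call into the dispatcher could look like the sketch below; the tensor shapes are assumptions, and the `vsa` kernels plus a CUDA device are required at runtime:

```python
import torch
from diffsynth_engine.models.basic.attention import attention
from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs

# Hypothetical shapes: [batch, seq_len, num_heads, head_dim] over a
# (21, 30, 52) DiT token grid, so seq_len = 21 * 30 * 52 = 32760.
q = torch.randn(1, 32760, 12, 128, device="cuda", dtype=torch.bfloat16)
k, v, g = torch.randn_like(q), torch.randn_like(q), torch.randn_like(q)

# Latent grid (21, 60, 104) with patch size (1, 2, 2) -> (21, 30, 52) tokens.
vsa_kwargs = get_vsa_kwargs((21, 60, 104), (1, 2, 2), sparsity=0.9, device="cuda")
out = attention(q, k, v, g, attn_impl="vsa", **vsa_kwargs)
```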
diffsynth_engine-0.6.1.dev23/diffsynth_engine/models/basic/video_sparse_attention.py
ADDED
@@ -0,0 +1,235 @@
+import torch
+import math
+import functools
+
+from vsa import video_sparse_attn as vsa_core
+from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
+
+VSA_TILE_SIZE = (4, 4, 4)
+
+
+@functools.lru_cache(maxsize=10)
+def get_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    T, H, W = dit_seq_shape
+    ts, hs, ws = tile_size
+    indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
+    ls = []
+    for t in range(math.ceil(T / ts)):
+        for h in range(math.ceil(H / hs)):
+            for w in range(math.ceil(W / ws)):
+                ls.append(
+                    indices[
+                        t * ts : min(t * ts + ts, T), h * hs : min(h * hs + hs, H), w * ws : min(w * ws + ws, W)
+                    ].flatten()
+                )
+    index = torch.cat(ls, dim=0)
+    return index
+
+
+@functools.lru_cache(maxsize=10)
+def get_reverse_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))
+
+
+@functools.lru_cache(maxsize=10)
+def construct_variable_block_sizes(
+    dit_seq_shape: tuple[int, int, int],
+    num_tiles: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    """
+    Compute the number of valid (non-padded) tokens inside every
+    (ts_t x ts_h x ts_w) tile after padding -- flattened in the order
+    (t-tile, h-tile, w-tile) that `rearrange` uses.
+
+    Returns
+    -------
+    torch.LongTensor  # shape: [∏ full_window_size]
+    """
+    # unpack
+    t, h, w = dit_seq_shape
+    ts_t, ts_h, ts_w = VSA_TILE_SIZE
+    n_t, n_h, n_w = num_tiles
+
+    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
+        """Vector with the size of each tile along one dimension."""
+        sizes = torch.full((n_tiles,), tile, dtype=torch.int, device=device)
+        # size of last (possibly partial) tile
+        remainder = dim_len - (n_tiles - 1) * tile
+        sizes[-1] = remainder if remainder > 0 else tile
+        return sizes
+
+    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]
+    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]
+    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]
+
+    # broadcast-multiply to get voxels per tile, then flatten
+    block_sizes = (
+        t_sizes[:, None, None]  # [n_t, 1, 1]
+        * h_sizes[None, :, None]  # [1, n_h, 1]
+        * w_sizes[None, None, :]  # [1, 1, n_w]
+    ).reshape(-1)  # [n_t * n_h * n_w]
+
+    return block_sizes
+
+
+@functools.lru_cache(maxsize=10)
+def get_non_pad_index(
+    variable_block_sizes: torch.LongTensor,
+    max_block_size: int,
+):
+    n_win = variable_block_sizes.shape[0]
+    device = variable_block_sizes.device
+    starts_pad = torch.arange(n_win, device=device) * max_block_size
+    index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]
+    index_mask = torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
+    return index_pad[index_mask]
+
+
+def get_vsa_kwargs(
+    latent_shape: tuple[int, int, int],
+    patch_size: tuple[int, int, int],
+    sparsity: float,
+    device: torch.device,
+):
+    dit_seq_shape = (
+        latent_shape[0] // patch_size[0],
+        latent_shape[1] // patch_size[1],
+        latent_shape[2] // patch_size[2],
+    )
+
+    num_tiles = (
+        math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
+        math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
+        math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]),
+    )
+    total_seq_length = math.prod(dit_seq_shape)
+
+    tile_partition_indices = get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)
+    non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE))
+
+    return {
+        "sparsity": sparsity,
+        "num_tiles": num_tiles,
+        "total_seq_length": total_seq_length,
+        "tile_partition_indices": tile_partition_indices,
+        "reverse_tile_partition_indices": reverse_tile_partition_indices,
+        "variable_block_sizes": variable_block_sizes,
+        "non_pad_index": non_pad_index,
+    }
+
+
+def tile(
+    x: torch.Tensor,
+    num_tiles: tuple[int, int, int],
+    tile_partition_indices: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+) -> torch.Tensor:
+    t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
+    h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
+    w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
+
+    x_padded = torch.zeros(
+        (x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]),
+        device=x.device,
+        dtype=x.dtype,
+    )
+    x_padded[:, non_pad_index] = x[:, tile_partition_indices]
+    return x_padded
+
+
+def untile(
+    x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor, non_pad_index: torch.LongTensor
+) -> torch.Tensor:
+    x = x[:, non_pad_index][:, reverse_tile_partition_indices]
+    return x
+
+
+def video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+):
+    q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
+    k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
+    v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
+    g = tile(g, num_tiles, tile_partition_indices, non_pad_index)
+
+    q = q.transpose(1, 2).contiguous()
+    k = k.transpose(1, 2).contiguous()
+    v = v.transpose(1, 2).contiguous()
+    g = g.transpose(1, 2).contiguous()
+
+    topk = math.ceil((1 - sparsity) * (total_seq_length / math.prod(VSA_TILE_SIZE)))
+    out = vsa_core(
+        q,
+        k,
+        v,
+        variable_block_sizes=variable_block_sizes,
+        topk=topk,
+        block_size=VSA_TILE_SIZE,
+        compress_attn_weight=g,
+    ).transpose(1, 2)
+    out = untile(out, reverse_tile_partition_indices, non_pad_index)
+    return out
+
+
+def distributed_video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
+):
+    from yunchang.comm.all_to_all import SeqAllToAll4D
+
+    assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
+    sp_ulysses_group = get_sp_ulysses_group()
+
+    q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
+    k = SeqAllToAll4D.apply(sp_ulysses_group, k, scatter_idx, gather_idx)
+    v = SeqAllToAll4D.apply(sp_ulysses_group, v, scatter_idx, gather_idx)
+    g = SeqAllToAll4D.apply(sp_ulysses_group, g, scatter_idx, gather_idx)
+
+    out = video_sparse_attn(
+        q,
+        k,
+        v,
+        g,
+        sparsity,
+        num_tiles,
+        total_seq_length,
+        tile_partition_indices,
+        reverse_tile_partition_indices,
+        variable_block_sizes,
+        non_pad_index,
+    )
+
+    out = SeqAllToAll4D.apply(sp_ulysses_group, out, gather_idx, scatter_idx)
+    return out
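Since `tile` only permutes tokens into (4, 4, 4) tiles and zero-pads partial tiles, `untile` inverts it exactly. A small self-check (a sketch with assumed small shapes; importing the module requires the `vsa` package, although the check itself runs on CPU):

```python
import torch
from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs, tile, untile

# Assumed latent grid (9, 12, 20) with patch size (1, 2, 2) -> a (9, 6, 10) token grid.
kw = get_vsa_kwargs((9, 12, 20), (1, 2, 2), sparsity=0.9, device="cpu")
x = torch.randn(2, kw["total_seq_length"], 4, 16)  # [batch, seq, heads, head_dim]

x_tiled = tile(x, kw["num_tiles"], kw["tile_partition_indices"], kw["non_pad_index"])
x_back = untile(x_tiled, kw["reverse_tile_partition_indices"], kw["non_pad_index"])
assert torch.equal(x_back, x)  # tile/untile is a lossless permutation round trip
```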
{diffsynth_engine-0.6.1.dev21 → diffsynth_engine-0.6.1.dev23}/diffsynth_engine/models/flux/flux_controlnet.py
RENAMED
@@ -86,7 +86,6 @@ class FluxControlNet(PreTrainedModel):
     def __init__(
         self,
         condition_channels: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -103,10 +102,7 @@ class FluxControlNet(PreTrainedModel):
         self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
         self.controlnet_x_embedder = nn.Linear(condition_channels, 3072)
         self.blocks = nn.ModuleList(
-            [
-                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(6)
-            ]
+            [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(6)]
         )
         # controlnet projection
         self.blocks_proj = nn.ModuleList(
@@ -128,6 +124,7 @@ class FluxControlNet(PreTrainedModel):
         image_ids: torch.Tensor,
         text_ids: torch.Tensor,
         guidance: torch.Tensor,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition)
         condition = (
@@ -141,7 +138,9 @@ class FluxControlNet(PreTrainedModel):
         # double block
         double_block_outputs = []
         for i, block in enumerate(self.blocks):
-            hidden_states, prompt_emb = block(
+            hidden_states, prompt_emb = block(
+                hidden_states, prompt_emb, condition, image_rotary_emb, attn_kwargs=attn_kwargs
+            )
             double_block_outputs.append(self.blocks_proj[i](hidden_states))
 
         # apply control scale
@@ -149,24 +148,13 @@ class FluxControlNet(PreTrainedModel):
         return double_block_outputs, None
 
     @classmethod
-    def from_state_dict(
-        cls,
-        state_dict: Dict[str, torch.Tensor],
-        device: str,
-        dtype: torch.dtype,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
         if "controlnet_x_embedder.weight" in state_dict:
             condition_channels = state_dict["controlnet_x_embedder.weight"].shape[1]
         else:
             condition_channels = 64
 
-        model = cls(
-            condition_channels=condition_channels,
-            attn_kwargs=attn_kwargs,
-            device="meta",
-            dtype=dtype,
-        )
+        model = cls(condition_channels=condition_channels, device="meta", dtype=dtype)
         model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)