diffsynth-engine 0.1.1__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine-0.2.1/.github/workflows/python-publish.yml +41 -0
- diffsynth_engine-0.2.1/.gitignore +11 -0
- diffsynth_engine-0.2.1/.pre-commit-config.yaml +11 -0
- diffsynth_engine-0.2.1/PKG-INFO +34 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/README.md +14 -14
- diffsynth_engine-0.2.1/assets/dingtalk.png +0 -0
- diffsynth_engine-0.2.1/assets/showcase.jpeg +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/__init__.py +10 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +16 -14
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -3
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -3
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +1 -1
- diffsynth_engine-0.2.1/diffsynth_engine/models/__init__.py +7 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/base.py +22 -13
- diffsynth_engine-0.2.1/diffsynth_engine/models/basic/attention.py +233 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/unet_helper.py +2 -2
- diffsynth_engine-0.2.1/diffsynth_engine/models/components/siglip.py +169 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/vae.py +0 -1
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/__init__.py +2 -0
- diffsynth_engine-0.2.1/diffsynth_engine/models/flux/flux_controlnet.py +160 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/flux_dit.py +67 -96
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/flux_text_encoder.py +1 -3
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/flux_vae.py +1 -1
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/sd3_dit.py +1 -7
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/sdxl_unet.py +1 -7
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_dit.py +146 -79
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_image_encoder.py +2 -3
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_text_encoder.py +46 -13
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/__init__.py +4 -2
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/base.py +66 -31
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/flux_image.py +190 -79
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/sd_image.py +38 -47
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/sdxl_image.py +40 -50
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/wan_video.py +156 -89
- diffsynth_engine-0.2.1/diffsynth_engine/tokenizers/__init__.py +6 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/wan.py +17 -22
- diffsynth_engine-0.2.1/diffsynth_engine/tools/__init__.py +4 -0
- diffsynth_engine-0.2.1/diffsynth_engine/tools/flux_inpainting.py +50 -0
- diffsynth_engine-0.2.1/diffsynth_engine/tools/flux_outpainting.py +58 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/download.py +1 -5
- diffsynth_engine-0.2.1/diffsynth_engine/utils/env.py +10 -0
- diffsynth_engine-0.2.1/diffsynth_engine/utils/flag.py +46 -0
- diffsynth_engine-0.2.1/diffsynth_engine/utils/image.py +25 -0
- diffsynth_engine-0.2.1/diffsynth_engine/utils/loader.py +32 -0
- diffsynth_engine-0.2.1/diffsynth_engine/utils/parallel.py +401 -0
- diffsynth_engine-0.2.1/diffsynth_engine.egg-info/PKG-INFO +34 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/SOURCES.txt +98 -2
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/requires.txt +3 -4
- diffsynth_engine-0.2.1/docs/tutorial.md +1 -0
- diffsynth_engine-0.2.1/docs/tutorial_zh.md +207 -0
- diffsynth_engine-0.2.1/examples/flux_lora.py +11 -0
- diffsynth_engine-0.2.1/examples/flux_text_to_image.py +8 -0
- diffsynth_engine-0.2.1/examples/i2v_input.jpg +0 -0
- diffsynth_engine-0.2.1/examples/sdxl_text_to_image.py +14 -0
- diffsynth_engine-0.2.1/examples/wan_image_to_video.py +35 -0
- diffsynth_engine-0.2.1/examples/wan_lora.py +33 -0
- diffsynth_engine-0.2.1/examples/wan_text_to_video.py +28 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/pyproject.toml +8 -8
- diffsynth_engine-0.2.1/tests/__init__.py +0 -0
- diffsynth_engine-0.2.1/tests/common/__init__.py +0 -0
- diffsynth_engine-0.2.1/tests/common/test_case.py +123 -0
- diffsynth_engine-0.2.1/tests/common/utils.py +29 -0
- diffsynth_engine-0.2.1/tests/data/expect/algorithm/beta_20steps.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/algorithm/ddim_20steps.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/algorithm/euler_i10.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/algorithm/exponential_20steps.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/algorithm/flow_match_euler_i10.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/algorithm/karras_20steps.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/algorithm/output.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/algorithm/recifited_flow_20steps_flux.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/algorithm/scaled_linear_20steps.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/algorithm/sgm_uniform_20steps.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/flux/flux_dit.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/flux/flux_inpainting.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/flux/flux_lora.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/flux/flux_outpainting.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/flux/flux_text_encoder_1.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/flux/flux_text_encoder_2.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/flux/flux_txt2img.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/flux/flux_union_pro_canny.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/flux/flux_vae.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sd/sd_inpainting.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sd/sd_lora.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sd/sd_text_encoder.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sd/sd_txt2img.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sd/sd_unet.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sd/sd_vae.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_inpainting.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_lora.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_text_encoder_1.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_text_encoder_2.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_txt2img.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_unet.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/sdxl/sdxl_vae.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/test_siglip_image_encoder.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/wan/wan_vae.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/input/astronaut_320_320.mp4 +0 -0
- diffsynth_engine-0.2.1/tests/data/input/canny.png +0 -0
- diffsynth_engine-0.2.1/tests/data/input/mask_image.png +0 -0
- diffsynth_engine-0.2.1/tests/data/input/test_image.png +0 -0
- diffsynth_engine-0.2.1/tests/data/input/wukong_1024_1024.png +0 -0
- diffsynth_engine-0.2.1/tests/data/input/wukong_480_480.png +0 -0
- diffsynth_engine-0.2.1/tests/test_algorithm/__init__.py +0 -0
- diffsynth_engine-0.2.1/tests/test_algorithm/test_sampler.py +42 -0
- diffsynth_engine-0.2.1/tests/test_algorithm/test_scheduler.py +77 -0
- diffsynth_engine-0.2.1/tests/test_models/__init__.py +0 -0
- diffsynth_engine-0.2.1/tests/test_models/flux/__init__.py +0 -0
- diffsynth_engine-0.2.1/tests/test_models/flux/test_flux_dit.py +208 -0
- diffsynth_engine-0.2.1/tests/test_models/flux/test_flux_text_encoder.py +114 -0
- diffsynth_engine-0.2.1/tests/test_models/flux/test_flux_vae.py +344 -0
- diffsynth_engine-0.2.1/tests/test_models/sd/__init__.py +0 -0
- diffsynth_engine-0.2.1/tests/test_models/sd/test_sd_text_encoder.py +72 -0
- diffsynth_engine-0.2.1/tests/test_models/sd/test_sd_unet.py +22 -0
- diffsynth_engine-0.2.1/tests/test_models/sd/test_sd_vae.py +353 -0
- diffsynth_engine-0.2.1/tests/test_models/sdxl/__init__.py +0 -0
- diffsynth_engine-0.2.1/tests/test_models/sdxl/test_sdxl_text_encoder.py +163 -0
- diffsynth_engine-0.2.1/tests/test_models/sdxl/test_sdxl_unet.py +21 -0
- diffsynth_engine-0.2.1/tests/test_models/sdxl/test_sdxl_vae.py +351 -0
- diffsynth_engine-0.2.1/tests/test_models/test_siglip.py +17 -0
- diffsynth_engine-0.2.1/tests/test_models/wan/test_wan_vae.py +34 -0
- diffsynth_engine-0.2.1/tests/test_pipelines/__init__.py +0 -0
- diffsynth_engine-0.2.1/tests/test_pipelines/test_flux_controlnet.py +32 -0
- diffsynth_engine-0.2.1/tests/test_pipelines/test_flux_image.py +68 -0
- diffsynth_engine-0.2.1/tests/test_pipelines/test_sd_image.py +55 -0
- diffsynth_engine-0.2.1/tests/test_pipelines/test_sdxl_image.py +59 -0
- diffsynth_engine-0.2.1/tests/test_pipelines/test_wan_video.py +24 -0
- diffsynth_engine-0.2.1/tests/test_pipelines/test_wan_video_gguf.py +24 -0
- diffsynth_engine-0.2.1/tests/test_pipelines/test_wan_video_tp.py +25 -0
- diffsynth_engine-0.2.1/tests/test_tokenizers/__init__.py +0 -0
- diffsynth_engine-0.2.1/tests/test_tokenizers/test_clip.py +135 -0
- diffsynth_engine-0.2.1/tests/test_tokenizers/test_t5.py +138 -0
- diffsynth_engine-0.2.1/tests/test_tools/__init__.py +0 -0
- diffsynth_engine-0.2.1/tests/test_tools/test_flux_tools.py +31 -0
- diffsynth_engine-0.1.1/PKG-INFO +0 -213
- diffsynth_engine-0.1.1/diffsynth_engine/models/basic/attention.py +0 -137
- diffsynth_engine-0.1.1/diffsynth_engine/models/wan/attention.py +0 -200
- diffsynth_engine-0.1.1/diffsynth_engine/tokenizers/__init__.py +0 -4
- diffsynth_engine-0.1.1/diffsynth_engine/utils/env.py +0 -7
- diffsynth_engine-0.1.1/diffsynth_engine/utils/loader.py +0 -14
- diffsynth_engine-0.1.1/diffsynth_engine/utils/parallel.py +0 -191
- diffsynth_engine-0.1.1/diffsynth_engine.egg-info/PKG-INFO +0 -213
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/LICENSE +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/components/vae.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/wan/dit/14b-i2v.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/wan/dit/14b-t2v.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
- {diffsynth_engine-0.1.1/diffsynth_engine/models → diffsynth_engine-0.2.1/diffsynth_engine/kernels}/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/lora.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/timestep.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/clip.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/t5.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/sd_unet.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/sd_vae.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/utils.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_vae.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/base.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/clip.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/t5.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/__init__.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/constants.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/fp8_linear.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/gguf.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/lock.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/logging.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/offload.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/prompt.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/video.py +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/top_level.txt +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/setup.cfg +0 -0
- {diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/setup.py +0 -0
diffsynth_engine-0.2.1/.github/workflows/python-publish.yml
ADDED
@@ -0,0 +1,41 @@
+name: release
+
+on:
+  push:
+    tags:
+      - 'v**'
+
+  workflow_dispatch:
+    inputs:
+      branch:
+        required: true
+        default: 'main'
+
+permissions:
+  contents: read
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  build-and-publish:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+
+      - name: Install build
+        run: pip install build
+
+      - name: Build dist
+        run: python -m build
+
+      - name: Publish to PyPI
+        run: |
+          pip install twine
+          twine upload dist/* --skip-existing -p ${{ secrets.PYPI_API_TOKEN }}
diffsynth_engine-0.2.1/PKG-INFO
ADDED
@@ -0,0 +1,34 @@
+Metadata-Version: 2.4
+Name: diffsynth_engine
+Version: 0.2.1
+Author: MuseAI x ModelScope
+Classifier: Programming Language :: Python :: 3
+Classifier: Operating System :: OS Independent
+Requires-Python: >=3.10
+License-File: LICENSE
+Requires-Dist: torch>=2.6
+Requires-Dist: torchvision
+Requires-Dist: xformers; sys_platform == "linux"
+Requires-Dist: safetensors
+Requires-Dist: gguf
+Requires-Dist: einops
+Requires-Dist: ftfy
+Requires-Dist: regex
+Requires-Dist: sentencepiece
+Requires-Dist: tokenizers
+Requires-Dist: modelscope
+Requires-Dist: flufl.lock
+Requires-Dist: scipy
+Requires-Dist: torchsde
+Requires-Dist: pillow
+Requires-Dist: imageio[ffmpeg]
+Requires-Dist: yunchang; sys_platform == "linux"
+Provides-Extra: dev
+Requires-Dist: diffusers==0.31.0; extra == "dev"
+Requires-Dist: transformers==4.45.2; extra == "dev"
+Requires-Dist: build; extra == "dev"
+Requires-Dist: ruff; extra == "dev"
+Requires-Dist: scikit-image; extra == "dev"
+Requires-Dist: pytest; extra == "dev"
+Requires-Dist: pre-commit; extra == "dev"
+Dynamic: license-file
{diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/README.md
RENAMED
@@ -6,20 +6,20 @@
 [![PRs](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Engine?style=flat)](https://GitHub.com/modelscope/DiffSynth-Engine/pull/)
 [![Commits](https://img.shields.io/github/commit-activity/m/modelscope/DiffSynth-Engine?style=flat)](https://GitHub.com/modelscope/DiffSynth-Engine/commit/)
 
-
+DiffSynth-Engine is a high-performance engine geared towards buidling efficient inference pipelines for diffusion models.
 
 **Key Features:**
 
-- **
+- **Thoughtfully-Designed Implementation:** We carefully re-implemented key components in Diffusion pipelines, such as sampler and scheduler, without introducing external dependencies on libraries like k-diffusion, ldm, or sgm.
 
-- **Extensive Model Support:** Compatible with
+- **Extensive Model Support:** Compatible with popular formats (e.g., CivitAI) of base models and LoRA models , catering to diverse use cases.
 
-- **
-and
+- **Versatile Resource Management:** Comprehensive support for varous model quantization (e.g., FP8, INT8)
+and offloading strategies, enabling loading of larger diffusion models (e.g., Flux.1 Dev) on limited hardware budget of GPU memory.
 
-- **
+- **Optimized Performance:** Carefully-crafted inference pipeline to achieve fast generation across various hardware environments.
 
-- **Platform
+- **Cross-Platform Support:** Runnable on Windows, macOS (Apple Silicon), and Linux, ensuring a smooth experience across different operating systems.
 
 ## Quick Start
 ### Requirements
@@ -29,13 +29,13 @@ and offload strategies, enabling users to run large models (e.g., Flux.1 Dev) on
 
 ### Installation
 
-Install
-```
+Install released version (from PyPI):
+```shell
 pip3 install diffsynth-engine
 ```
 
-Install
-```
+Install from source:
+```shell
 git clone https://github.com/modelscope/diffsynth-engine.git && cd diffsynth-engine
 pip3 install -e .
 ```
@@ -71,10 +71,10 @@ For more details, please refer to our tutorials ([English](./docs/tutorial.md),
 
 ## Contact
 
-If you have any questions or feedback, please scan the QR code or send email to muse@alibaba-inc.com.
+If you have any questions or feedback, please scan the QR code below, or send email to muse@alibaba-inc.com.
 
 <div style="display: flex; justify-content: space-between;">
-<img src="assets/dingtalk.png" alt="dingtalk"
+<img src="assets/dingtalk.png" alt="dingtalk" width="400" />
 </div>
 
 ## License
@@ -82,7 +82,7 @@ This project is licensed under the Apache License 2.0. See the LICENSE file for
 
 ## Citation
 
-If you use this codebase, or otherwise found our work
+If you use this codebase, or otherwise found our work helpful, please cite:
 
 ```bibtex
 @misc{diffsynth-engine2025,
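The README's Quick Start code is elided from this diff, though the file listing above shows an `examples/flux_text_to_image.py` (+8 lines) shipping with 0.2.1. A minimal sketch of what such a script might look like against the public API; the model ID, file name, and pipeline signature here are assumptions, not verbatim from the diff:

```python
# Hypothetical quick-start sketch; only FluxImagePipeline and fetch_model are
# confirmed exports (see the diffsynth_engine/__init__.py hunk below).
from diffsynth_engine import FluxImagePipeline, fetch_model

model_path = fetch_model("muse/flux-with-vae", path="flux1-dev-with-vae.safetensors")  # assumed model ID/path
pipe = FluxImagePipeline.from_pretrained(model_path, device="cuda:0")  # assumed signature
image = pipe(prompt="a cat sitting on a windowsill at sunset")
image.save("image.png")
```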
diffsynth_engine-0.2.1/assets/dingtalk.png
ADDED
Binary file

diffsynth_engine-0.2.1/assets/showcase.jpeg
ADDED
Binary file
{diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/__init__.py
RENAMED
@@ -7,11 +7,16 @@ from .pipelines import (
     SDXLModelConfig,
     SDModelConfig,
     WanModelConfig,
+    ControlNetParams,
 )
+from .models.flux import FluxControlNet
 from .utils.download import fetch_model, fetch_modelscope_model, fetch_civitai_model
 from .utils.video import load_video, save_video
+from .tools import FluxInpaintingTool, FluxOutpaintingTool
+
 __all__ = [
     "FluxImagePipeline",
+    "FluxControlNet",
     "SDXLImagePipeline",
     "SDImagePipeline",
     "WanVideoPipeline",
@@ -19,7 +24,12 @@ __all__ = [
     "SDXLModelConfig",
     "SDModelConfig",
     "WanModelConfig",
+    "FluxInpaintingTool",
+    "FluxOutpaintingTool",
+    "ControlNetParams",
     "fetch_model",
     "fetch_modelscope_model",
     "fetch_civitai_model",
+    "load_video",
+    "save_video",
 ]
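The two hunks above define the full 0.2.1 top-level namespace. The import surface below is confirmed by the new `__all__` entries; only the names are shown, since call signatures are not part of this diff:

```python
# New top-level exports in 0.2.1 (exactly the names added to __all__ above).
from diffsynth_engine import (
    FluxControlNet,       # ControlNet model for Flux pipelines
    ControlNetParams,     # ControlNet conditioning parameters
    FluxInpaintingTool,   # high-level inpainting helper (tools/flux_inpainting.py)
    FluxOutpaintingTool,  # high-level outpainting helper (tools/flux_outpainting.py)
    load_video,           # video I/O helpers, now re-exported at package level
    save_video,
)
```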
{diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py
RENAMED
@@ -5,18 +5,19 @@ from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
 
 
 class RecifitedFlowScheduler(BaseScheduler):
-    def __init__(
-
-
+    def __init__(
+        self,
+        shift=1.0,
+        sigma_min=0.001,
         sigma_max=1.0,
-        num_train_timesteps=1000,
+        num_train_timesteps=1000,
         use_dynamic_shifting=False,
     ):
         self.shift = shift
         self.sigma_min = sigma_min
         self.sigma_max = sigma_max
-        self.num_train_timesteps = num_train_timesteps
-        self.use_dynamic_shifting = use_dynamic_shifting
+        self.num_train_timesteps = num_train_timesteps
+        self.use_dynamic_shifting = use_dynamic_shifting
 
     def _sigma_to_t(self, sigma):
         return sigma * self.num_train_timesteps
@@ -30,19 +31,20 @@ class RecifitedFlowScheduler(BaseScheduler):
     def _shift_sigma(self, sigma: torch.Tensor, shift: float):
         return shift * sigma / (1 + (shift - 1) * sigma)
 
-    def schedule(
-
-
-
-
+    def schedule(
+        self,
+        num_inference_steps: int,
+        mu: float | None = None,
+        sigma_min: float | None = None,
+        sigma_max: float | None = None,
     ):
         sigma_min = self.sigma_min if sigma_min is None else sigma_min
-        sigma_max = self.sigma_max if sigma_max is None else sigma_max
+        sigma_max = self.sigma_max if sigma_max is None else sigma_max
         sigmas = torch.linspace(sigma_max, sigma_min, num_inference_steps)
         if self.use_dynamic_shifting:
-            sigmas = self._time_shift(mu, 1.0, sigmas)
+            sigmas = self._time_shift(mu, 1.0, sigmas)  # FLUX
         else:
             sigmas = self._shift_sigma(sigmas, self.shift)
         timesteps = sigmas * self.num_train_timesteps
         sigmas = append_zero(sigmas)
-        return sigmas, timesteps
+        return sigmas, timesteps
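In the static branch, `_shift_sigma` applies the rectified-flow time shift σ′ = shift·σ / (1 + (shift − 1)·σ), which biases the schedule toward high-noise steps when shift > 1. A small sketch of driving the scheduler, using the constructor shown in the hunk above:

```python
import torch
from diffsynth_engine.algorithm.noise_scheduler.flow_match.recifited_flow import RecifitedFlowScheduler

scheduler = RecifitedFlowScheduler(shift=3.0)  # static shift; use_dynamic_shifting=False
sigmas, timesteps = scheduler.schedule(num_inference_steps=20)

# timesteps is computed before append_zero, so sigmas carries one extra trailing 0.0:
assert sigmas.shape[0] == 21 and timesteps.shape[0] == 20
# e.g. sigma = 0.5 maps to 3 * 0.5 / (1 + 2 * 0.5) = 0.75 under shift=3.0
```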
{diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py
RENAMED
@@ -1,7 +1,4 @@
 import torch
-from .linear import ScaledLinearScheduler
-from ..base_scheduler import append_zero
-import numpy as np
 
 from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
 from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero

{diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py
RENAMED
@@ -1,7 +1,4 @@
 import torch
-from .linear import ScaledLinearScheduler
-from ..base_scheduler import append_zero
-import numpy as np
 
 from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
 from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
{diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py
RENAMED
@@ -2,7 +2,7 @@ import torch
 
 
 class FlowMatchEulerSampler:
-    def initialize(self, init_latents, timesteps, sigmas, mask=None):
+    def initialize(self, init_latents, timesteps, sigmas, mask=None):
         self.init_latents = init_latents
         self.timesteps = timesteps
         self.sigmas = sigmas
{diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/base.py
RENAMED
@@ -1,22 +1,14 @@
 import os
 import torch
 import torch.nn as nn
-from typing import Dict, Union
-from
-
+from typing import Dict, Union, List, Any
+from diffsynth_engine.utils.loader import load_file
+from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
 from diffsynth_engine.models.utils import no_init_weights
 
 
-class LoRAStateDictConverter:
-    def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
-        return {"lora": lora_state_dict}
-
-
-StateDictType = Dict[str, torch.Tensor]
-
-
 class StateDictConverter:
-    def convert(self, state_dict:
+    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         return state_dict
 
 
@@ -29,17 +21,34 @@ class PreTrainedModel(nn.Module):
 
     @classmethod
     def from_pretrained(cls, pretrained_model_path: Union[str, os.PathLike], device: str, dtype: torch.dtype, **kwargs):
-        state_dict = load_file(pretrained_model_path
+        state_dict = load_file(pretrained_model_path)
         return cls.from_state_dict(state_dict, device=device, dtype=dtype, **kwargs)
 
     @classmethod
     def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, **kwargs):
         with no_init_weights():
             model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, **kwargs)
+        model.to_empty(device=device)
         model.load_state_dict(state_dict)
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model
 
+    def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
+        for args in lora_args:
+            key = args["name"]
+            module = self.get_submodule(key)
+            if not isinstance(module, (LoRALinear, LoRAConv2d)):
+                raise ValueError(f"Unsupported lora key: {key}")
+            if fused:
+                module.add_frozen_lora(**args)
+            else:
+                module.add_lora(**args)
+
+    def unload_loras(self):
+        for module in self.modules():
+            if isinstance(module, (LoRALinear, LoRAConv2d)):
+                module.clear()
+
 
 def split_suffix(name: str):
     suffix_list = [
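The new `load_loras` resolves each entry's `"name"` with `get_submodule` and dispatches to `add_frozen_lora` (weights merged in place) or `add_lora` (kept as a separate branch). A hedged sketch of the expected `lora_args` shape, for any `PreTrainedModel` subclass instance `model`; only the `"name"` key is confirmed by this diff, the remaining kwargs stand in for whatever `LoRALinear.add_lora` actually accepts:

```python
import torch

lora_args = [
    {
        "name": "blocks.0.attn.to_q",   # submodule path, resolved via model.get_submodule
        "up": torch.randn(3072, 16),    # assumed kwarg: LoRA up-projection weight
        "down": torch.randn(16, 3072),  # assumed kwarg: LoRA down-projection weight
        "scale": 0.8,                   # assumed kwarg: blending strength
    },
]
model.load_loras(lora_args, fused=True)  # fused=True merges weights via add_frozen_lora
model.unload_loras()                     # clear() every LoRALinear / LoRAConv2d
```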
diffsynth_engine-0.2.1/diffsynth_engine/models/basic/attention.py
ADDED
@@ -0,0 +1,233 @@
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from typing import Optional
+
+import torch.nn.functional as F
+from diffsynth_engine.utils import logging
+from diffsynth_engine.utils.flag import (
+    FLASH_ATTN_3_AVAILABLE,
+    FLASH_ATTN_2_AVAILABLE,
+    XFORMERS_AVAILABLE,
+    SDPA_AVAILABLE,
+    SAGE_ATTN_AVAILABLE,
+    SPARGE_ATTN_AVAILABLE,
+)
+
+logger = logging.get_logger(__name__)
+
+
+def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
+    padding_size = (alignment - x.shape[dim] % alignment) % alignment
+    padded_x = F.pad(x, (0, padding_size), "constant", 0)
+    return padded_x[..., : x.shape[dim]]
+
+
+if FLASH_ATTN_3_AVAILABLE:
+    from flash_attn_interface import flash_attn_func as flash_attn3
+if FLASH_ATTN_2_AVAILABLE:
+    from flash_attn import flash_attn_func as flash_attn2
+if XFORMERS_AVAILABLE:
+    from xformers.ops import memory_efficient_attention
+
+    def xformers_attn(q, k, v, attn_mask=None, scale=None):
+        if attn_mask is not None:
+            attn_mask = repeat(attn_mask, "S L -> B H S L", B=q.shape[0], H=q.shape[2])
+            attn_mask = memory_align(attn_mask)
+        return memory_efficient_attention(q, k, v, attn_bias=attn_mask, scale=scale)
+
+
+if SDPA_AVAILABLE:
+
+    def sdpa_attn(q, k, v, attn_mask=None, scale=None):
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
+        return out.transpose(1, 2)
+
+
+if SAGE_ATTN_AVAILABLE:
+    from sageattention import sageattn
+
+    def sage_attn(q, k, v, attn_mask=None, scale=None):
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        out = sageattn(q, k, v, attn_mask=attn_mask, sm_scale=scale)
+        return out.transpose(1, 2)
+
+
+if SPARGE_ATTN_AVAILABLE:
+    from spas_sage_attn import spas_sage2_attn_meansim_cuda
+
+    def sparge_attn(self, q, k, v, attn_mask=None, scale=None):
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        out = spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=attn_mask, scale=scale)
+        return out.transpose(1, 2)
+
+
+def eager_attn(q, k, v, attn_mask=None, scale=None):
+    q = q.transpose(1, 2)
+    k = k.transpose(1, 2)
+    v = v.transpose(1, 2)
+    scale = 1 / q.shape[-1] ** 0.5 if scale is None else scale
+    q = q * scale
+    attn = torch.matmul(q, k.transpose(-2, -1))
+    if attn_mask is not None:
+        attn = attn + attn_mask
+    attn = attn.softmax(-1)
+    out = attn @ v
+    return out.transpose(1, 2)
+
+
+def attention(
+    q,
+    k,
+    v,
+    attn_impl: Optional[str] = None,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+):
+    """
+    q: [B, Lq, Nq, C1]
+    k: [B, Lk, Nk, C1]
+    v: [B, Lk, Nk, C2]
+    """
+    assert attn_impl in [
+        None,
+        "auto",
+        "eager",
+        "flash_attn_2",
+        "flash_attn_3",
+        "xformers",
+        "sdpa",
+        "sage_attn",
+        "sparge_attn",
+    ]
+    if attn_impl is None or attn_impl == "auto":
+        if FLASH_ATTN_3_AVAILABLE:
+            return flash_attn3(q, k, v, softmax_scale=scale)
+        elif FLASH_ATTN_2_AVAILABLE:
+            return flash_attn2(q, k, v, softmax_scale=scale)
+        elif XFORMERS_AVAILABLE:
+            return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        elif SDPA_AVAILABLE:
+            return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        else:
+            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+    else:
+        if attn_impl == "eager":
+            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        elif attn_impl == "flash_attn_3":
+            return flash_attn3(q, k, v, softmax_scale=scale)
+        elif attn_impl == "flash_attn_2":
+            return flash_attn2(q, k, v, softmax_scale=scale)
+        elif attn_impl == "xformers":
+            return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        elif attn_impl == "sdpa":
+            return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        elif attn_impl == "sage_attn":
+            return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        elif attn_impl == "sparge_attn":
+            return sparge_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        else:
+            raise ValueError(f"Invalid attention implementation: {attn_impl}")
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        q_dim,
+        num_heads,
+        head_dim,
+        kv_dim=None,
+        bias_q=False,
+        bias_kv=False,
+        bias_out=False,
+        scale=None,
+        attn_impl: Optional[str] = None,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+        dim_inner = head_dim * num_heads
+        kv_dim = kv_dim if kv_dim is not None else q_dim
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+
+        self.to_q = nn.Linear(q_dim, dim_inner, bias=bias_q, device=device, dtype=dtype)
+        self.to_k = nn.Linear(kv_dim, dim_inner, bias=bias_kv, device=device, dtype=dtype)
+        self.to_v = nn.Linear(kv_dim, dim_inner, bias=bias_kv, device=device, dtype=dtype)
+        self.to_out = nn.Linear(dim_inner, q_dim, bias=bias_out, device=device, dtype=dtype)
+        self.attn_impl = attn_impl
+        self.scale = scale
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+    ):
+        if y is None:
+            y = x
+        q = rearrange(self.to_q(x), "b s (n d) -> b s n d", n=self.num_heads)
+        k = rearrange(self.to_k(y), "b s (n d) -> b s n d", n=self.num_heads)
+        v = rearrange(self.to_v(y), "b s (n d) -> b s n d", n=self.num_heads)
+        out = attention(q, k, v, attn_mask=attn_mask, attn_impl=self.attn_impl, scale=self.scale)
+        out = rearrange(out, "b s n d -> b s (n d)", n=self.num_heads)
+        return self.to_out(out)
+
+
+def long_context_attention(
+    q,
+    k,
+    v,
+    attn_impl: Optional[str] = None,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+):
+    """
+    q: [B, Lq, Nq, C1]
+    k: [B, Lk, Nk, C1]
+    v: [B, Lk, Nk, C2]
+    """
+    from yunchang import LongContextAttention
+    from yunchang.kernels import AttnType
+
+    assert attn_impl in [
+        None,
+        "auto",
+        "eager",
+        "flash_attn_2",
+        "flash_attn_3",
+        "xformers",
+        "sdpa",
+        "sage_attn",
+        "sparge_attn",
+    ]
+    if attn_impl is None or attn_impl == "auto":
+        if FLASH_ATTN_3_AVAILABLE:
+            attn_func = LongContextAttention(attn_type=AttnType.FA3)
+        elif FLASH_ATTN_2_AVAILABLE:
+            attn_func = LongContextAttention(attn_type=AttnType.FA)
+        elif SDPA_AVAILABLE:
+            attn_func = LongContextAttention(attn_type=AttnType.TORCH)
+        else:
+            raise ValueError("No available long context attention implementation")
+    else:
+        if attn_impl == "flash_attn_3":
+            attn_func = LongContextAttention(attn_type=AttnType.FA3)
+        elif attn_impl == "flash_attn_2":
+            attn_func = LongContextAttention(attn_type=AttnType.FA)
+        elif attn_impl == "sdpa":
+            attn_func = LongContextAttention(attn_type=AttnType.TORCH)
+        elif attn_impl == "sage_attn":
+            attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
+        elif attn_impl == "sparge_attn":
+            attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE)
+        else:
+            raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
+    return attn_func(q, k, v, softmax_scale=scale)
{diffsynth_engine-0.1.1 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/unet_helper.py
RENAMED
@@ -51,12 +51,12 @@ class BasicTransformerBlock(nn.Module):
     def forward(self, hidden_states, encoder_hidden_states):
         # 1. Self-Attention
         norm_hidden_states = self.norm1(hidden_states)
-        attn_output = self.attn1(norm_hidden_states
+        attn_output = self.attn1(norm_hidden_states)
         hidden_states = attn_output + hidden_states
 
         # 2. Cross-Attention
         norm_hidden_states = self.norm2(hidden_states)
-        attn_output = self.attn2(norm_hidden_states,
+        attn_output = self.attn2(norm_hidden_states, y=encoder_hidden_states)
         hidden_states = attn_output + hidden_states
 
         # 3. Feed-forward