diffsynth-engine 0.3.6.dev8__tar.gz → 0.3.6.dev10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/PKG-INFO +1 -1
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/__init__.py +10 -8
- diffsynth_engine-0.3.6.dev10/diffsynth_engine/configs/__init__.py +23 -0
- diffsynth_engine-0.3.6.dev10/diffsynth_engine/configs/controlnet.py +17 -0
- diffsynth_engine-0.3.6.dev10/diffsynth_engine/configs/pipeline.py +206 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/attention.py +43 -4
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_controlnet.py +8 -5
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_dit.py +22 -16
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_dit_fbcache.py +5 -5
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd/sd_controlnet.py +2 -4
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +1 -2
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/wan/wan_dit.py +15 -15
- diffsynth_engine-0.3.6.dev10/diffsynth_engine/pipelines/__init__.py +17 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/pipelines/base.py +14 -65
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/pipelines/flux_image.py +85 -158
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/pipelines/sd_image.py +30 -64
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/pipelines/sdxl_image.py +39 -71
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/pipelines/wan_video.py +66 -105
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tools/flux_inpainting_tool.py +7 -3
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tools/flux_outpainting_tool.py +7 -3
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tools/flux_reference_tool.py +21 -5
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tools/flux_replace_tool.py +15 -3
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/fp8_linear.py +14 -5
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/parallel.py +1 -1
- diffsynth_engine-0.3.6.dev10/diffsynth_engine/utils/platform.py +20 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine.egg-info/PKG-INFO +1 -1
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine.egg-info/SOURCES.txt +3 -0
- diffsynth_engine-0.3.6.dev8/diffsynth_engine/pipelines/__init__.py +0 -20
- diffsynth_engine-0.3.6.dev8/diffsynth_engine/utils/platform.py +0 -12
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/.gitignore +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/.pre-commit-config.yaml +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/LICENSE +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/MANIFEST.in +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/README.md +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/assets/dingtalk.png +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/assets/showcase.jpeg +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/components/vae.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/wan/dit/14b-flf2v.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/wan/dit/14b-i2v.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/models/wan/dit/14b-t2v.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/kernels/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/base.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/lora.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/timestep.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/basic/unet_helper.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_redux.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/flux/flux_vae.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd/sd_unet.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd/sd_vae.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd3/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sdxl/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/text_encoder/clip.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/text_encoder/t5.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/utils.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/vae/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/vae/vae.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/wan/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/models/wan/wan_vae.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/pipelines/controlnet_helper.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/processor/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/processor/canny_processor.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/processor/depth_processor.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tokenizers/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tokenizers/base.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tokenizers/clip.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tokenizers/t5.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tokenizers/wan.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/tools/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/__init__.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/constants.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/download.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/env.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/flag.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/gguf.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/image.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/loader.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/lock.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/logging.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/offload.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/onnx.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/prompt.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine/utils/video.py +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine.egg-info/requires.txt +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/diffsynth_engine.egg-info/top_level.txt +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/docs/tutorial.md +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/docs/tutorial_zh.md +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/pyproject.toml +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/setup.cfg +0 -0
- {diffsynth_engine-0.3.6.dev8 → diffsynth_engine-0.3.6.dev10}/setup.py +0 -0
|
@@ -1,12 +1,14 @@
|
|
|
1
|
+
from .configs import (
|
|
2
|
+
SDPipelineConfig,
|
|
3
|
+
SDXLPipelineConfig,
|
|
4
|
+
FluxPipelineConfig,
|
|
5
|
+
WanPipelineConfig,
|
|
6
|
+
)
|
|
1
7
|
from .pipelines import (
|
|
2
8
|
FluxImagePipeline,
|
|
3
9
|
SDXLImagePipeline,
|
|
4
10
|
SDImagePipeline,
|
|
5
11
|
WanVideoPipeline,
|
|
6
|
-
FluxModelConfig,
|
|
7
|
-
SDXLModelConfig,
|
|
8
|
-
SDModelConfig,
|
|
9
|
-
WanModelConfig,
|
|
10
12
|
ControlNetParams,
|
|
11
13
|
)
|
|
12
14
|
from .models.flux import FluxControlNet, FluxIPAdapter, FluxRedux
|
|
@@ -23,6 +25,10 @@ from .tools import (
|
|
|
23
25
|
)
|
|
24
26
|
|
|
25
27
|
__all__ = [
|
|
28
|
+
"SDPipelineConfig",
|
|
29
|
+
"SDXLPipelineConfig",
|
|
30
|
+
"FluxPipelineConfig",
|
|
31
|
+
"WanPipelineConfig",
|
|
26
32
|
"FluxImagePipeline",
|
|
27
33
|
"FluxControlNet",
|
|
28
34
|
"FluxIPAdapter",
|
|
@@ -32,10 +38,6 @@ __all__ = [
|
|
|
32
38
|
"SDXLImagePipeline",
|
|
33
39
|
"SDImagePipeline",
|
|
34
40
|
"WanVideoPipeline",
|
|
35
|
-
"FluxModelConfig",
|
|
36
|
-
"SDXLModelConfig",
|
|
37
|
-
"SDModelConfig",
|
|
38
|
-
"WanModelConfig",
|
|
39
41
|
"FluxInpaintingTool",
|
|
40
42
|
"FluxOutpaintingTool",
|
|
41
43
|
"FluxIPAdapterRefTool",
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from .pipeline import (
|
|
2
|
+
BaseConfig,
|
|
3
|
+
AttentionConfig,
|
|
4
|
+
OptimizationConfig,
|
|
5
|
+
ParallelConfig,
|
|
6
|
+
SDPipelineConfig,
|
|
7
|
+
SDXLPipelineConfig,
|
|
8
|
+
FluxPipelineConfig,
|
|
9
|
+
WanPipelineConfig,
|
|
10
|
+
)
|
|
11
|
+
from .controlnet import ControlType
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"BaseConfig",
|
|
15
|
+
"AttentionConfig",
|
|
16
|
+
"OptimizationConfig",
|
|
17
|
+
"ParallelConfig",
|
|
18
|
+
"SDPipelineConfig",
|
|
19
|
+
"SDXLPipelineConfig",
|
|
20
|
+
"FluxPipelineConfig",
|
|
21
|
+
"WanPipelineConfig",
|
|
22
|
+
"ControlType",
|
|
23
|
+
]
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# FLUX ControlType
|
|
5
|
+
class ControlType(Enum):
    """Control modes available for FLUX, each tied to an input channel count."""

    normal = "normal"
    bfl_control = "bfl_control"
    bfl_fill = "bfl_fill"
    bfl_kontext = "bfl_kontext"

    def get_in_channel(self):
        """Return the input channel count for this control type.

        ``bfl_control`` maps to 128, ``bfl_fill`` to 384, and both ``normal``
        and ``bfl_kontext`` map to 64.
        """
        if self is ControlType.bfl_control:
            return 128
        if self is ControlType.bfl_fill:
            return 384
        if self is ControlType.normal or self is ControlType.bfl_kontext:
            return 64
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import List, Tuple, Optional
|
|
5
|
+
|
|
6
|
+
from diffsynth_engine.configs.controlnet import ControlType
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class BaseConfig:
|
|
11
|
+
model_path: str | os.PathLike | List[str | os.PathLike]
|
|
12
|
+
model_dtype: torch.dtype
|
|
13
|
+
batch_cfg: bool = False
|
|
14
|
+
vae_tiled: bool = False
|
|
15
|
+
vae_tile_size: int | Tuple[int, int] = 256
|
|
16
|
+
vae_tile_stride: int | Tuple[int, int] = 256
|
|
17
|
+
device: str = "cuda"
|
|
18
|
+
offload_mode: Optional[str] = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class AttentionConfig:
    """Attention implementation choice plus Sparge Attention parameters."""

    # Attention implementation name; defaults to "auto".
    dit_attn_impl: str = "auto"

    # Sparge Attention
    # NOTE(review): the fallback values used in attention.py when these keys
    # are absent are simthreshd1=0.6 and cdfthreshd=0.98, i.e. the reverse of
    # the two defaults below. Confirm whether the swap is intentional before
    # relying on either set of numbers.
    sparge_smooth_k: bool = True
    sparge_cdfthreshd: float = 0.6
    sparge_simthreshd1: float = 0.98
    sparge_pvthreshd: float = 50.0
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class OptimizationConfig:
    """Opt-in optimization switches; everything is disabled by default."""

    # Enable the FP8 linear path.
    use_fp8_linear: bool = False
    # Enable FBCache, with the relative-L1 threshold it uses.
    use_fbcache: bool = False
    fbcache_relative_l1_threshold: float = 0.05
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class ParallelConfig:
    """Parallel execution settings.

    The ``*_degree`` fields default to ``None`` ("not specified"); in this
    file ``init_parallel_config`` is what fills them in and validates them.
    """

    # Overall parallelism; 1 is the default.
    parallelism: int = 1
    # CFG parallelism can be requested with the boolean switch or with an
    # explicit degree.
    use_cfg_parallel: bool = False
    cfg_degree: Optional[int] = None
    # Sequence-parallel degrees (Ulysses and ring).
    sp_ulysses_degree: Optional[int] = None
    sp_ring_degree: Optional[int] = None
    # Tensor-parallel degree.
    tp_degree: Optional[int] = None
    # Fully sharded data parallel switch.
    use_fsdp: bool = False
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class SDPipelineConfig(BaseConfig):
    """Configuration for the SD pipeline.

    Adds optional CLIP and VAE checkpoint paths with their dtypes on top of
    ``BaseConfig`` and gives ``model_dtype`` a float16 default.
    """

    model_path: str | os.PathLike | List[str | os.PathLike]
    # Optional separate checkpoints; None is the default.
    clip_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    # The VAE defaults to float32 while the model and CLIP default to float16.
    model_dtype: torch.dtype = torch.float16
    clip_dtype: torch.dtype = torch.float16
    vae_dtype: torch.dtype = torch.float32

    @classmethod
    def basic_config(
        cls,
        model_path: str | os.PathLike | List[str | os.PathLike],
        device: str = "cuda",
        offload_mode: Optional[str] = None,
    ) -> "SDPipelineConfig":
        """Build a config from the model path, device and offload mode only.

        Every other field keeps its declared default.
        """
        overrides = {
            "model_path": model_path,
            "device": device,
            "offload_mode": offload_mode,
        }
        return cls(**overrides)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class SDXLPipelineConfig(BaseConfig):
    """Configuration for the SDXL pipeline.

    Adds optional CLIP-L, CLIP-G and VAE checkpoint paths with their dtypes on
    top of ``BaseConfig`` and gives ``model_dtype`` a float16 default.
    """

    model_path: str | os.PathLike | List[str | os.PathLike]
    # Optional separate checkpoints; None is the default.
    clip_l_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    clip_g_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    # The VAE defaults to float32 while the other components default to float16.
    model_dtype: torch.dtype = torch.float16
    clip_l_dtype: torch.dtype = torch.float16
    clip_g_dtype: torch.dtype = torch.float16
    vae_dtype: torch.dtype = torch.float32

    @classmethod
    def basic_config(
        cls,
        model_path: str | os.PathLike | List[str | os.PathLike],
        device: str = "cuda",
        offload_mode: Optional[str] = None,
    ) -> "SDXLPipelineConfig":
        """Build a config from the model path, device and offload mode only.

        Every other field keeps its declared default.
        """
        overrides = {
            "model_path": model_path,
            "device": device,
            "offload_mode": offload_mode,
        }
        return cls(**overrides)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@dataclass
class FluxPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
    """Configuration for the FLUX pipeline.

    Combines the attention, optimization, parallel and base settings, adds
    optional CLIP / T5 / VAE checkpoints (all defaulting to bfloat16), and
    runs ``init_parallel_config`` after construction.
    """

    model_path: str | os.PathLike | List[str | os.PathLike]
    # Optional separate checkpoints; None is the default.
    clip_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    t5_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    model_dtype: torch.dtype = torch.bfloat16
    clip_dtype: torch.dtype = torch.bfloat16
    t5_dtype: torch.dtype = torch.bfloat16
    vae_dtype: torch.dtype = torch.bfloat16

    load_text_encoder: bool = True
    control_type: ControlType = ControlType.normal

    @classmethod
    def basic_config(
        cls,
        model_path: str | os.PathLike | List[str | os.PathLike],
        device: str = "cuda",
        parallelism: int = 1,
        offload_mode: Optional[str] = None,
    ) -> "FluxPipelineConfig":
        """Build a config from a handful of common options.

        ``use_fsdp`` is always set to True here; every field not listed in
        the signature keeps its declared default.
        """
        overrides = {
            "model_path": model_path,
            "device": device,
            "parallelism": parallelism,
            "use_fsdp": True,
            "offload_mode": offload_mode,
        }
        return cls(**overrides)

    def __post_init__(self):
        # Validate and complete the parallel-related fields in place.
        init_parallel_config(self)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@dataclass
class WanPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
    """Configuration for the Wan pipeline.

    Combines the attention, optimization, parallel and base settings, adds
    optional T5 / VAE / image-encoder checkpoints (all defaulting to
    bfloat16), overrides the VAE tiling defaults, and runs
    ``init_parallel_config`` after construction.
    """

    model_path: str | os.PathLike | List[str | os.PathLike]
    # Optional separate checkpoints; None is the default.
    t5_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    model_dtype: torch.dtype = torch.bfloat16
    t5_dtype: torch.dtype = torch.bfloat16
    vae_dtype: torch.dtype = torch.bfloat16
    image_encoder_dtype: torch.dtype = torch.bfloat16

    # RecifitedFlowScheduler shift factor, set by model type.
    # Not an __init__ argument (init=False); starts out as None.
    shift: Optional[float] = field(default=None, init=False)

    # override BaseConfig
    vae_tiled: bool = True
    vae_tile_size: Tuple[int, int] = (34, 34)
    vae_tile_stride: Tuple[int, int] = (18, 16)

    @classmethod
    def basic_config(
        cls,
        model_path: str | os.PathLike | List[str | os.PathLike],
        image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
        device: str = "cuda",
        parallelism: int = 1,
        offload_mode: Optional[str] = None,
    ) -> "WanPipelineConfig":
        """Build a config from a handful of common options.

        ``use_cfg_parallel`` and ``use_fsdp`` are always set to True here;
        every field not listed in the signature keeps its declared default.
        """
        overrides = {
            "model_path": model_path,
            "image_encoder_path": image_encoder_path,
            "device": device,
            "parallelism": parallelism,
            "use_cfg_parallel": True,
            "use_fsdp": True,
            "offload_mode": offload_mode,
        }
        return cls(**overrides)

    def __post_init__(self):
        # Validate and complete the parallel-related fields in place.
        init_parallel_config(self)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def init_parallel_config(config: "FluxPipelineConfig | WanPipelineConfig"):
    """Validate the parallel settings of *config* and fill in unset degrees.

    The config is modified in place; nothing is returned. After a successful
    call ``cfg_degree``, ``sp_ulysses_degree``, ``sp_ring_degree`` and
    ``tp_degree`` are all integers.

    Rules applied, in order:

    * ``parallelism`` must be 1, 2, 4 or 8.
    * ``batch_cfg`` is forced to True when ``parallelism > 1`` and
      ``use_cfg_parallel`` is set; otherwise it is left untouched.
    * ``use_cfg_parallel=True`` and an explicit ``cfg_degree`` are mutually
      exclusive. An unset ``cfg_degree`` becomes 2 with ``use_cfg_parallel``
      and 1 without.
    * With ``tp_degree`` set, sequence parallel degrees and FSDP must be off
      and ``parallelism == cfg_degree * tp_degree``; both sequence parallel
      degrees become 1.
    * With neither sequence parallel degree set, Ulysses receives
      ``parallelism // cfg_degree`` (at least 1) and ring / tensor parallel
      degrees become 1.
    * With both sequence parallel degrees set,
      ``parallelism == cfg_degree * sp_ulysses_degree * sp_ring_degree`` is
      required and ``tp_degree`` becomes 1.

    Raises:
        AssertionError: a consistency check listed above failed. These are
            raised explicitly rather than with ``assert`` so that they still
            fire under ``python -O``; the exception type is unchanged.
        ValueError: ``use_cfg_parallel`` and ``cfg_degree`` were both given,
            or only one of the two sequence parallel degrees was given.
    """
    if config.parallelism not in (1, 2, 4, 8):
        raise AssertionError("parallelism must be 1, 2, 4 or 8")
    if config.parallelism > 1 and config.use_cfg_parallel:
        config.batch_cfg = True

    if config.use_cfg_parallel is True and config.cfg_degree is not None:
        raise ValueError("use_cfg_parallel and cfg_degree should not be specified together")
    if config.cfg_degree is None:
        config.cfg_degree = 2 if config.use_cfg_parallel else 1

    ulysses_unset = config.sp_ulysses_degree is None
    ring_unset = config.sp_ring_degree is None

    if config.tp_degree is not None:
        if not (ulysses_unset and ring_unset):
            raise AssertionError(
                "not allowed to enable sequence parallel and tensor parallel together; "
                "either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
            )
        if config.use_fsdp is not False:
            raise AssertionError(
                "not allowed to enable fully sharded data parallel and tensor parallel together; "
                "either set use_fsdp=False or set tp_degree=None during pipeline initialization"
            )
        if config.parallelism != config.cfg_degree * config.tp_degree:
            raise AssertionError(
                f"parallelism ({config.parallelism}) must be equal to cfg_degree ({config.cfg_degree}) * tp_degree ({config.tp_degree})"
            )
        config.sp_ulysses_degree = 1
        config.sp_ring_degree = 1
    elif ulysses_unset and ring_unset:
        # use ulysses if not specified
        # Clamp to 1: with parallelism=1 and use_cfg_parallel=True the plain
        # floor division (1 // 2) would produce a degree of 0.
        config.sp_ulysses_degree = max(1, config.parallelism // config.cfg_degree)
        config.sp_ring_degree = 1
        config.tp_degree = 1
    elif not ulysses_unset and not ring_unset:
        if config.parallelism != config.cfg_degree * config.sp_ulysses_degree * config.sp_ring_degree:
            raise AssertionError(
                f"parallelism ({config.parallelism}) must be equal to cfg_degree ({config.cfg_degree}) * "
                f"sp_ulysses_degree ({config.sp_ulysses_degree}) * sp_ring_degree ({config.sp_ring_degree})"
            )
        config.tp_degree = 1
    else:
        raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
|
|
@@ -61,12 +61,33 @@ if SAGE_ATTN_AVAILABLE:
|
|
|
61
61
|
|
|
62
62
|
if SPARGE_ATTN_AVAILABLE:
|
|
63
63
|
from spas_sage_attn import spas_sage2_attn_meansim_cuda
|
|
64
|
+
from spas_sage_attn.autotune import SparseAttentionMeansim
|
|
64
65
|
|
|
65
|
-
def sparge_attn(
|
|
66
|
+
def sparge_attn(
|
|
67
|
+
q,
|
|
68
|
+
k,
|
|
69
|
+
v,
|
|
70
|
+
attn_mask=None,
|
|
71
|
+
scale=None,
|
|
72
|
+
smooth_k=True,
|
|
73
|
+
simthreshd1=0.6,
|
|
74
|
+
cdfthreshd=0.98,
|
|
75
|
+
pvthreshd=50,
|
|
76
|
+
):
|
|
66
77
|
q = q.transpose(1, 2)
|
|
67
78
|
k = k.transpose(1, 2)
|
|
68
79
|
v = v.transpose(1, 2)
|
|
69
|
-
out = spas_sage2_attn_meansim_cuda(
|
|
80
|
+
out = spas_sage2_attn_meansim_cuda(
|
|
81
|
+
q,
|
|
82
|
+
k,
|
|
83
|
+
v,
|
|
84
|
+
attn_mask=attn_mask,
|
|
85
|
+
scale=scale,
|
|
86
|
+
smooth_k=smooth_k,
|
|
87
|
+
simthreshd1=simthreshd1,
|
|
88
|
+
cdfthreshd=cdfthreshd,
|
|
89
|
+
pvthreshd=pvthreshd,
|
|
90
|
+
)
|
|
70
91
|
return out.transpose(1, 2)
|
|
71
92
|
|
|
72
93
|
|
|
@@ -91,6 +112,7 @@ def attention(
|
|
|
91
112
|
attn_impl: Optional[str] = None,
|
|
92
113
|
attn_mask: Optional[torch.Tensor] = None,
|
|
93
114
|
scale: Optional[float] = None,
|
|
115
|
+
**kwargs,
|
|
94
116
|
):
|
|
95
117
|
"""
|
|
96
118
|
q: [B, Lq, Nq, C1]
|
|
@@ -133,7 +155,17 @@ def attention(
|
|
|
133
155
|
elif attn_impl == "sage_attn":
|
|
134
156
|
return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
135
157
|
elif attn_impl == "sparge_attn":
|
|
136
|
-
return sparge_attn(
|
|
158
|
+
return sparge_attn(
|
|
159
|
+
q,
|
|
160
|
+
k,
|
|
161
|
+
v,
|
|
162
|
+
attn_mask=attn_mask,
|
|
163
|
+
scale=scale,
|
|
164
|
+
smooth_k=kwargs.get("sparge_smooth_k", True),
|
|
165
|
+
simthreshd1=kwargs.get("sparge_simthreshd1", 0.6),
|
|
166
|
+
cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
|
|
167
|
+
pvthreshd=kwargs.get("sparge_pvthreshd", 50),
|
|
168
|
+
)
|
|
137
169
|
else:
|
|
138
170
|
raise ValueError(f"Invalid attention implementation: {attn_impl}")
|
|
139
171
|
|
|
@@ -189,6 +221,7 @@ def long_context_attention(
|
|
|
189
221
|
attn_impl: Optional[str] = None,
|
|
190
222
|
attn_mask: Optional[torch.Tensor] = None,
|
|
191
223
|
scale: Optional[float] = None,
|
|
224
|
+
**kwargs,
|
|
192
225
|
):
|
|
193
226
|
"""
|
|
194
227
|
q: [B, Lq, Nq, C1]
|
|
@@ -226,7 +259,13 @@ def long_context_attention(
|
|
|
226
259
|
elif attn_impl == "sage_attn":
|
|
227
260
|
attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
|
|
228
261
|
elif attn_impl == "sparge_attn":
|
|
229
|
-
|
|
262
|
+
attn_processor = SparseAttentionMeansim()
|
|
263
|
+
# default args from spas_sage2_attn_meansim_cuda
|
|
264
|
+
attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
|
|
265
|
+
attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
|
|
266
|
+
attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
|
|
267
|
+
attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
|
|
268
|
+
attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)
|
|
230
269
|
else:
|
|
231
270
|
raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
|
|
232
271
|
return attn_func(q, k, v, softmax_scale=scale)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import torch.nn as nn
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Any, Dict, Optional
|
|
4
4
|
from einops import rearrange
|
|
5
5
|
from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
|
|
6
6
|
from diffsynth_engine.models.flux.flux_dit import (
|
|
@@ -87,7 +87,7 @@ class FluxControlNet(PreTrainedModel):
|
|
|
87
87
|
def __init__(
|
|
88
88
|
self,
|
|
89
89
|
condition_channels: int = 64,
|
|
90
|
-
|
|
90
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
91
91
|
device: str = "cuda:0",
|
|
92
92
|
dtype: torch.dtype = torch.bfloat16,
|
|
93
93
|
):
|
|
@@ -104,7 +104,10 @@ class FluxControlNet(PreTrainedModel):
|
|
|
104
104
|
self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
|
|
105
105
|
self.controlnet_x_embedder = nn.Linear(condition_channels, 3072)
|
|
106
106
|
self.blocks = nn.ModuleList(
|
|
107
|
-
[
|
|
107
|
+
[
|
|
108
|
+
FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
|
|
109
|
+
for _ in range(6)
|
|
110
|
+
]
|
|
108
111
|
)
|
|
109
112
|
# controlnet projection
|
|
110
113
|
self.blocks_proj = nn.ModuleList(
|
|
@@ -154,7 +157,7 @@ class FluxControlNet(PreTrainedModel):
|
|
|
154
157
|
state_dict: Dict[str, torch.Tensor],
|
|
155
158
|
device: str,
|
|
156
159
|
dtype: torch.dtype,
|
|
157
|
-
|
|
160
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
158
161
|
):
|
|
159
162
|
if "controlnet_x_embedder.weight" in state_dict:
|
|
160
163
|
condition_channels = state_dict["controlnet_x_embedder.weight"].shape[1]
|
|
@@ -163,7 +166,7 @@ class FluxControlNet(PreTrainedModel):
|
|
|
163
166
|
|
|
164
167
|
with no_init_weights():
|
|
165
168
|
model = torch.nn.utils.skip_init(
|
|
166
|
-
cls, condition_channels=condition_channels,
|
|
169
|
+
cls, condition_channels=condition_channels, attn_kwargs=attn_kwargs, device=device, dtype=dtype
|
|
167
170
|
)
|
|
168
171
|
model.load_state_dict(state_dict)
|
|
169
172
|
model.to(device=device, dtype=dtype, non_blocking=True)
|
|
@@ -2,7 +2,7 @@ import json
|
|
|
2
2
|
import torch
|
|
3
3
|
import torch.nn as nn
|
|
4
4
|
import numpy as np
|
|
5
|
-
from typing import Dict, Optional
|
|
5
|
+
from typing import Any, Dict, Optional
|
|
6
6
|
from einops import rearrange
|
|
7
7
|
|
|
8
8
|
from diffsynth_engine.models.basic.transformer_helper import (
|
|
@@ -177,7 +177,7 @@ class FluxDoubleAttention(nn.Module):
|
|
|
177
177
|
dim_b,
|
|
178
178
|
num_heads,
|
|
179
179
|
head_dim,
|
|
180
|
-
|
|
180
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
181
181
|
device: str = "cuda:0",
|
|
182
182
|
dtype: torch.dtype = torch.bfloat16,
|
|
183
183
|
):
|
|
@@ -195,7 +195,7 @@ class FluxDoubleAttention(nn.Module):
|
|
|
195
195
|
|
|
196
196
|
self.a_to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
|
|
197
197
|
self.b_to_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
|
|
198
|
-
self.
|
|
198
|
+
self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
|
|
199
199
|
|
|
200
200
|
def attention_callback(self, attn_out_a, attn_out_b, x_a, x_b, q_a, q_b, k_a, k_b, v_a, v_b, rope_emb, image_emb):
|
|
201
201
|
return attn_out_a, attn_out_b
|
|
@@ -207,7 +207,7 @@ class FluxDoubleAttention(nn.Module):
|
|
|
207
207
|
k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
|
|
208
208
|
v = torch.cat([v_b, v_a], dim=1)
|
|
209
209
|
q, k = apply_rope(q, k, rope_emb)
|
|
210
|
-
attn_out = attention_ops.attention(q, k, v,
|
|
210
|
+
attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
|
|
211
211
|
attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
|
|
212
212
|
text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
|
|
213
213
|
image_out, text_out = self.attention_callback(
|
|
@@ -232,13 +232,13 @@ class FluxDoubleTransformerBlock(nn.Module):
|
|
|
232
232
|
self,
|
|
233
233
|
dim,
|
|
234
234
|
num_heads,
|
|
235
|
-
|
|
235
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
236
236
|
device: str = "cuda:0",
|
|
237
237
|
dtype: torch.dtype = torch.bfloat16,
|
|
238
238
|
):
|
|
239
239
|
super().__init__()
|
|
240
240
|
self.attn = FluxDoubleAttention(
|
|
241
|
-
dim, dim, num_heads, dim // num_heads,
|
|
241
|
+
dim, dim, num_heads, dim // num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype
|
|
242
242
|
)
|
|
243
243
|
# Image
|
|
244
244
|
self.norm_msa_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
|
|
@@ -278,7 +278,7 @@ class FluxSingleAttention(nn.Module):
|
|
|
278
278
|
self,
|
|
279
279
|
dim,
|
|
280
280
|
num_heads,
|
|
281
|
-
|
|
281
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
282
282
|
device: str = "cuda:0",
|
|
283
283
|
dtype: torch.dtype = torch.bfloat16,
|
|
284
284
|
):
|
|
@@ -287,7 +287,7 @@ class FluxSingleAttention(nn.Module):
|
|
|
287
287
|
self.to_qkv = nn.Linear(dim, dim * 3, device=device, dtype=dtype)
|
|
288
288
|
self.norm_q_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
|
|
289
289
|
self.norm_k_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
|
|
290
|
-
self.
|
|
290
|
+
self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
|
|
291
291
|
|
|
292
292
|
def attention_callback(self, attn_out, x, q, k, v, rope_emb, image_emb):
|
|
293
293
|
return attn_out
|
|
@@ -295,7 +295,7 @@ class FluxSingleAttention(nn.Module):
|
|
|
295
295
|
def forward(self, x, rope_emb, image_emb):
|
|
296
296
|
q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
|
|
297
297
|
q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
|
|
298
|
-
attn_out = attention_ops.attention(q, k, v,
|
|
298
|
+
attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
|
|
299
299
|
attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
|
|
300
300
|
return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)
|
|
301
301
|
|
|
@@ -305,14 +305,14 @@ class FluxSingleTransformerBlock(nn.Module):
|
|
|
305
305
|
self,
|
|
306
306
|
dim,
|
|
307
307
|
num_heads,
|
|
308
|
-
|
|
308
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
309
309
|
device: str = "cuda:0",
|
|
310
310
|
dtype: torch.dtype = torch.bfloat16,
|
|
311
311
|
):
|
|
312
312
|
super().__init__()
|
|
313
313
|
self.dim = dim
|
|
314
314
|
self.norm = AdaLayerNormZero(dim, device=device, dtype=dtype)
|
|
315
|
-
self.attn = FluxSingleAttention(dim, num_heads,
|
|
315
|
+
self.attn = FluxSingleAttention(dim, num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
|
|
316
316
|
self.mlp = nn.Sequential(
|
|
317
317
|
nn.Linear(dim, dim * 4),
|
|
318
318
|
nn.GELU(approximate="tanh"),
|
|
@@ -333,7 +333,7 @@ class FluxDiT(PreTrainedModel):
|
|
|
333
333
|
def __init__(
|
|
334
334
|
self,
|
|
335
335
|
in_channel: int = 64,
|
|
336
|
-
|
|
336
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
337
337
|
device: str = "cuda:0",
|
|
338
338
|
dtype: torch.dtype = torch.bfloat16,
|
|
339
339
|
):
|
|
@@ -351,10 +351,16 @@ class FluxDiT(PreTrainedModel):
|
|
|
351
351
|
self.x_embedder = nn.Linear(in_channel, 3072, device=device, dtype=dtype)
|
|
352
352
|
|
|
353
353
|
self.blocks = nn.ModuleList(
|
|
354
|
-
[
|
|
354
|
+
[
|
|
355
|
+
FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
|
|
356
|
+
for _ in range(19)
|
|
357
|
+
]
|
|
355
358
|
)
|
|
356
359
|
self.single_blocks = nn.ModuleList(
|
|
357
|
-
[
|
|
360
|
+
[
|
|
361
|
+
FluxSingleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
|
|
362
|
+
for _ in range(38)
|
|
363
|
+
]
|
|
358
364
|
)
|
|
359
365
|
self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
|
|
360
366
|
self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
|
|
@@ -495,7 +501,7 @@ class FluxDiT(PreTrainedModel):
|
|
|
495
501
|
device: str,
|
|
496
502
|
dtype: torch.dtype,
|
|
497
503
|
in_channel: int = 64,
|
|
498
|
-
|
|
504
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
499
505
|
):
|
|
500
506
|
with no_init_weights():
|
|
501
507
|
model = torch.nn.utils.skip_init(
|
|
@@ -503,7 +509,7 @@ class FluxDiT(PreTrainedModel):
|
|
|
503
509
|
device=device,
|
|
504
510
|
dtype=dtype,
|
|
505
511
|
in_channel=in_channel,
|
|
506
|
-
|
|
512
|
+
attn_kwargs=attn_kwargs,
|
|
507
513
|
)
|
|
508
514
|
model = model.requires_grad_(False) # for loading gguf
|
|
509
515
|
model.load_state_dict(state_dict, assign=True)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import numpy as np
|
|
3
|
-
from typing import Dict, Optional
|
|
3
|
+
from typing import Any, Dict, Optional
|
|
4
4
|
|
|
5
5
|
from diffsynth_engine.models.utils import no_init_weights
|
|
6
6
|
from diffsynth_engine.utils.gguf import gguf_inference
|
|
@@ -21,12 +21,12 @@ class FluxDiTFBCache(FluxDiT):
|
|
|
21
21
|
def __init__(
|
|
22
22
|
self,
|
|
23
23
|
in_channel: int = 64,
|
|
24
|
-
|
|
24
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
25
25
|
device: str = "cuda:0",
|
|
26
26
|
dtype: torch.dtype = torch.bfloat16,
|
|
27
27
|
relative_l1_threshold: float = 0.05,
|
|
28
28
|
):
|
|
29
|
-
super().__init__(in_channel=in_channel,
|
|
29
|
+
super().__init__(in_channel=in_channel, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
|
|
30
30
|
self.relative_l1_threshold = relative_l1_threshold
|
|
31
31
|
self.step_count = 0
|
|
32
32
|
self.num_inference_steps = 0
|
|
@@ -187,7 +187,7 @@ class FluxDiTFBCache(FluxDiT):
|
|
|
187
187
|
device: str,
|
|
188
188
|
dtype: torch.dtype,
|
|
189
189
|
in_channel: int = 64,
|
|
190
|
-
|
|
190
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
191
191
|
fb_cache_relative_l1_threshold: float = 0.05,
|
|
192
192
|
):
|
|
193
193
|
with no_init_weights():
|
|
@@ -196,7 +196,7 @@ class FluxDiTFBCache(FluxDiT):
|
|
|
196
196
|
device=device,
|
|
197
197
|
dtype=dtype,
|
|
198
198
|
in_channel=in_channel,
|
|
199
|
-
|
|
199
|
+
attn_kwargs=attn_kwargs,
|
|
200
200
|
fb_cache_relative_l1_threshold=fb_cache_relative_l1_threshold,
|
|
201
201
|
)
|
|
202
202
|
model = model.requires_grad_(False) # for loading gguf
|
|
@@ -2,7 +2,7 @@ import torch
|
|
|
2
2
|
from einops import rearrange
|
|
3
3
|
from torch import nn
|
|
4
4
|
from PIL import Image
|
|
5
|
-
from typing import Dict, List
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
6
|
from functools import partial
|
|
7
7
|
from diffsynth_engine.models.utils import no_init_weights
|
|
8
8
|
from diffsynth_engine.models.text_encoder.siglip import SiglipImageEncoder
|
|
@@ -19,7 +19,7 @@ class FluxIPAdapterAttention(nn.Module):
|
|
|
19
19
|
dim: int = 3072,
|
|
20
20
|
head_num: int = 24,
|
|
21
21
|
scale: float = 1.0,
|
|
22
|
-
|
|
22
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
23
23
|
device: str = "cuda:0",
|
|
24
24
|
dtype: torch.dtype = torch.bfloat16,
|
|
25
25
|
):
|
|
@@ -29,12 +29,12 @@ class FluxIPAdapterAttention(nn.Module):
|
|
|
29
29
|
self.to_v_ip = nn.Linear(image_emb_dim, dim, device=device, dtype=dtype, bias=False)
|
|
30
30
|
self.head_num = head_num
|
|
31
31
|
self.scale = scale
|
|
32
|
-
self.
|
|
32
|
+
self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
|
|
33
33
|
|
|
34
34
|
def forward(self, query: torch.Tensor, image_emb: torch.Tensor):
|
|
35
35
|
key = rearrange(self.norm_k(self.to_k_ip(image_emb)), "b s (h d) -> b s h d", h=self.head_num)
|
|
36
36
|
value = rearrange(self.to_v_ip(image_emb), "b s (h d) -> b s h d", h=self.head_num)
|
|
37
|
-
attn_out = attention(query, key, value)
|
|
37
|
+
attn_out = attention(query, key, value, **self.attn_kwargs)
|
|
38
38
|
return self.scale * rearrange(attn_out, "b s h d -> b s (h d)")
|
|
39
39
|
|
|
40
40
|
@classmethod
|
|
@@ -142,7 +142,7 @@ class FluxIPAdapter(PreTrainedModel):
|
|
|
142
142
|
single_attention_callback, self=dit.single_blocks[i].attn
|
|
143
143
|
)
|
|
144
144
|
|
|
145
|
-
def
|
|
145
|
+
def encode_image(self, image: Image.Image) -> torch.Tensor:
|
|
146
146
|
image_emb = self.image_encoder(image)
|
|
147
147
|
return self.image_proj(image_emb)
|
|
148
148
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import torch.nn as nn
|
|
3
|
-
from typing import Dict
|
|
3
|
+
from typing import Dict
|
|
4
4
|
|
|
5
5
|
from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
|
|
6
6
|
from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
|
|
@@ -570,7 +570,6 @@ class SDControlNet(PreTrainedModel):
|
|
|
570
570
|
|
|
571
571
|
def __init__(
|
|
572
572
|
self,
|
|
573
|
-
attn_impl: Optional[str] = None,
|
|
574
573
|
device: str = "cuda:0",
|
|
575
574
|
dtype: torch.dtype = torch.bfloat16,
|
|
576
575
|
):
|
|
@@ -666,10 +665,9 @@ class SDControlNet(PreTrainedModel):
|
|
|
666
665
|
state_dict: Dict[str, torch.Tensor],
|
|
667
666
|
device: str,
|
|
668
667
|
dtype: torch.dtype,
|
|
669
|
-
attn_impl: Optional[str] = None,
|
|
670
668
|
):
|
|
671
669
|
with no_init_weights():
|
|
672
|
-
model = torch.nn.utils.skip_init(cls,
|
|
670
|
+
model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
|
|
673
671
|
model.load_state_dict(state_dict)
|
|
674
672
|
model.to(device=device, dtype=dtype, non_blocking=True)
|
|
675
673
|
return model
|