diffsynth-engine 0.2.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/.gitignore +3 -1
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/PKG-INFO +3 -3
- diffsynth_engine-0.2.1/assets/dingtalk.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/__init__.py +7 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/base.py +5 -5
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/attention.py +22 -6
- diffsynth_engine-0.2.1/diffsynth_engine/models/components/siglip.py +169 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/__init__.py +2 -0
- diffsynth_engine-0.2.1/diffsynth_engine/models/flux/flux_controlnet.py +160 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/flux_dit.py +16 -17
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/sd3_dit.py +1 -7
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/sdxl_unet.py +1 -7
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_dit.py +1 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/__init__.py +2 -1
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/base.py +26 -28
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/flux_image.py +179 -32
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/sd_image.py +32 -7
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/sdxl_image.py +32 -7
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/pipelines/wan_video.py +51 -27
- diffsynth_engine-0.2.1/diffsynth_engine/tools/__init__.py +4 -0
- diffsynth_engine-0.2.1/diffsynth_engine/tools/flux_inpainting.py +50 -0
- diffsynth_engine-0.2.1/diffsynth_engine/tools/flux_outpainting.py +58 -0
- diffsynth_engine-0.2.1/diffsynth_engine/utils/env.py +10 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/flag.py +2 -2
- diffsynth_engine-0.2.1/diffsynth_engine/utils/image.py +25 -0
- diffsynth_engine-0.2.1/diffsynth_engine/utils/loader.py +32 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/parallel.py +15 -4
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/PKG-INFO +3 -3
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/SOURCES.txt +17 -1
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/requires.txt +2 -2
- diffsynth_engine-0.2.1/examples/i2v_input.jpg +0 -0
- diffsynth_engine-0.2.1/examples/wan_image_to_video.py +35 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/examples/wan_text_to_video.py +1 -1
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/pyproject.toml +2 -2
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/common/test_case.py +1 -1
- diffsynth_engine-0.2.1/tests/data/expect/flux/flux_inpainting.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/flux/flux_outpainting.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/flux/flux_union_pro_canny.png +0 -0
- diffsynth_engine-0.2.1/tests/data/expect/test_siglip_image_encoder.safetensors +0 -0
- diffsynth_engine-0.2.1/tests/data/input/canny.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/flux/test_flux_dit.py +1 -1
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/flux/test_flux_text_encoder.py +1 -2
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/flux/test_flux_vae.py +1 -2
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sd/test_sd_text_encoder.py +1 -2
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sd/test_sd_vae.py +1 -2
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sdxl/test_sdxl_text_encoder.py +1 -1
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sdxl/test_sdxl_vae.py +1 -2
- diffsynth_engine-0.2.1/tests/test_models/test_siglip.py +17 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/wan/test_wan_vae.py +1 -2
- diffsynth_engine-0.2.1/tests/test_pipelines/test_flux_controlnet.py +32 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/test_flux_image.py +0 -13
- diffsynth_engine-0.2.1/tests/test_tools/__init__.py +0 -0
- diffsynth_engine-0.2.1/tests/test_tools/test_flux_tools.py +31 -0
- diffsynth_engine-0.2.0/assets/dingtalk.png +0 -0
- diffsynth_engine-0.2.0/diffsynth_engine/utils/env.py +0 -7
- diffsynth_engine-0.2.0/diffsynth_engine/utils/loader.py +0 -17
- diffsynth_engine-0.2.0/tests/data/expect/flux/flux_inpainting.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/.github/workflows/python-publish.yml +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/.pre-commit-config.yaml +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/LICENSE +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/README.md +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/assets/showcase.jpeg +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/components/vae.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/wan/dit/14b-i2v.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/models/wan/dit/14b-t2v.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/kernels/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/lora.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/timestep.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/unet_helper.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/clip.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/t5.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/components/vae.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/flux/flux_vae.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/sd_unet.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd/sd_vae.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/utils.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/wan/wan_vae.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/base.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/clip.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/t5.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/tokenizers/wan.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/constants.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/download.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/fp8_linear.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/gguf.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/lock.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/logging.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/offload.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/prompt.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/utils/video.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine.egg-info/top_level.txt +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/docs/tutorial.md +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/docs/tutorial_zh.md +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/examples/flux_lora.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/examples/flux_text_to_image.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/examples/sdxl_text_to_image.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/examples/wan_lora.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/setup.cfg +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/setup.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/common/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/common/utils.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/beta_20steps.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/ddim_20steps.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/euler_i10.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/exponential_20steps.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/flow_match_euler_i10.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/karras_20steps.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/output.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/recifited_flow_20steps_flux.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/scaled_linear_20steps.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/algorithm/sgm_uniform_20steps.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/flux/flux_dit.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/flux/flux_lora.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/flux/flux_text_encoder_1.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/flux/flux_text_encoder_2.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/flux/flux_txt2img.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/flux/flux_vae.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sd/sd_inpainting.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sd/sd_lora.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sd/sd_text_encoder.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sd/sd_txt2img.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sd/sd_unet.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sd/sd_vae.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_inpainting.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_lora.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_text_encoder_1.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_text_encoder_2.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_txt2img.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_unet.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/sdxl/sdxl_vae.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/expect/wan/wan_vae.safetensors +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/input/astronaut_320_320.mp4 +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/input/mask_image.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/input/test_image.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/input/wukong_1024_1024.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/data/input/wukong_480_480.png +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_algorithm/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_algorithm/test_sampler.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_algorithm/test_scheduler.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/flux/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sd/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sd/test_sd_unet.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sdxl/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_models/sdxl/test_sdxl_unet.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/test_sd_image.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/test_sdxl_image.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/test_wan_video.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/test_wan_video_gguf.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_pipelines/test_wan_video_tp.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_tokenizers/__init__.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_tokenizers/test_clip.py +0 -0
- {diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/tests/test_tokenizers/test_t5.py +0 -0
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: diffsynth_engine
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Author: MuseAI x ModelScope
|
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
|
6
6
|
Classifier: Operating System :: OS Independent
|
|
7
7
|
Requires-Python: >=3.10
|
|
8
8
|
License-File: LICENSE
|
|
9
|
-
Requires-Dist: torch>=2.
|
|
9
|
+
Requires-Dist: torch>=2.6
|
|
10
10
|
Requires-Dist: torchvision
|
|
11
11
|
Requires-Dist: xformers; sys_platform == "linux"
|
|
12
12
|
Requires-Dist: safetensors
|
|
@@ -22,7 +22,7 @@ Requires-Dist: scipy
|
|
|
22
22
|
Requires-Dist: torchsde
|
|
23
23
|
Requires-Dist: pillow
|
|
24
24
|
Requires-Dist: imageio[ffmpeg]
|
|
25
|
-
Requires-Dist: yunchang
|
|
25
|
+
Requires-Dist: yunchang; sys_platform == "linux"
|
|
26
26
|
Provides-Extra: dev
|
|
27
27
|
Requires-Dist: diffusers==0.31.0; extra == "dev"
|
|
28
28
|
Requires-Dist: transformers==4.45.2; extra == "dev"
|
|
Binary file
|
|
@@ -7,12 +7,16 @@ from .pipelines import (
|
|
|
7
7
|
SDXLModelConfig,
|
|
8
8
|
SDModelConfig,
|
|
9
9
|
WanModelConfig,
|
|
10
|
+
ControlNetParams,
|
|
10
11
|
)
|
|
12
|
+
from .models.flux import FluxControlNet
|
|
11
13
|
from .utils.download import fetch_model, fetch_modelscope_model, fetch_civitai_model
|
|
12
14
|
from .utils.video import load_video, save_video
|
|
15
|
+
from .tools import FluxInpaintingTool, FluxOutpaintingTool
|
|
13
16
|
|
|
14
17
|
__all__ = [
|
|
15
18
|
"FluxImagePipeline",
|
|
19
|
+
"FluxControlNet",
|
|
16
20
|
"SDXLImagePipeline",
|
|
17
21
|
"SDImagePipeline",
|
|
18
22
|
"WanVideoPipeline",
|
|
@@ -20,6 +24,9 @@ __all__ = [
|
|
|
20
24
|
"SDXLModelConfig",
|
|
21
25
|
"SDModelConfig",
|
|
22
26
|
"WanModelConfig",
|
|
27
|
+
"FluxInpaintingTool",
|
|
28
|
+
"FluxOutpaintingTool",
|
|
29
|
+
"ControlNetParams",
|
|
23
30
|
"fetch_model",
|
|
24
31
|
"fetch_modelscope_model",
|
|
25
32
|
"fetch_civitai_model",
|
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import torch
|
|
3
3
|
import torch.nn as nn
|
|
4
|
-
from typing import Dict, List,
|
|
5
|
-
from
|
|
6
|
-
|
|
4
|
+
from typing import Dict, Union, List, Any
|
|
5
|
+
from diffsynth_engine.utils.loader import load_file
|
|
7
6
|
from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
|
|
8
7
|
from diffsynth_engine.models.utils import no_init_weights
|
|
9
8
|
|
|
@@ -22,18 +21,19 @@ class PreTrainedModel(nn.Module):
|
|
|
22
21
|
|
|
23
22
|
@classmethod
|
|
24
23
|
def from_pretrained(cls, pretrained_model_path: Union[str, os.PathLike], device: str, dtype: torch.dtype, **kwargs):
|
|
25
|
-
state_dict = load_file(pretrained_model_path
|
|
24
|
+
state_dict = load_file(pretrained_model_path)
|
|
26
25
|
return cls.from_state_dict(state_dict, device=device, dtype=dtype, **kwargs)
|
|
27
26
|
|
|
28
27
|
@classmethod
|
|
29
28
|
def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, **kwargs):
|
|
30
29
|
with no_init_weights():
|
|
31
30
|
model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, **kwargs)
|
|
31
|
+
model.to_empty(device=device)
|
|
32
32
|
model.load_state_dict(state_dict)
|
|
33
33
|
model.to(device=device, dtype=dtype, non_blocking=True)
|
|
34
34
|
return model
|
|
35
35
|
|
|
36
|
-
def load_loras(self, lora_args: List[Dict[str,
|
|
36
|
+
def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
|
|
37
37
|
for args in lora_args:
|
|
38
38
|
key = args["name"]
|
|
39
39
|
module = self.get_submodule(key)
|
{diffsynth_engine-0.2.0 → diffsynth_engine-0.2.1}/diffsynth_engine/models/basic/attention.py
RENAMED
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import torch.nn as nn
|
|
3
|
-
from einops import rearrange
|
|
3
|
+
from einops import rearrange, repeat
|
|
4
4
|
from typing import Optional
|
|
5
|
-
from yunchang import LongContextAttention
|
|
6
|
-
from yunchang.kernels import AttnType
|
|
7
5
|
|
|
6
|
+
import torch.nn.functional as F
|
|
8
7
|
from diffsynth_engine.utils import logging
|
|
9
8
|
from diffsynth_engine.utils.flag import (
|
|
10
9
|
FLASH_ATTN_3_AVAILABLE,
|
|
@@ -18,12 +17,26 @@ from diffsynth_engine.utils.flag import (
|
|
|
18
17
|
logger = logging.get_logger(__name__)
|
|
19
18
|
|
|
20
19
|
|
|
20
|
+
def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
|
|
21
|
+
padding_size = (alignment - x.shape[dim] % alignment) % alignment
|
|
22
|
+
padded_x = F.pad(x, (0, padding_size), "constant", 0)
|
|
23
|
+
return padded_x[..., : x.shape[dim]]
|
|
24
|
+
|
|
25
|
+
|
|
21
26
|
if FLASH_ATTN_3_AVAILABLE:
|
|
22
27
|
from flash_attn_interface import flash_attn_func as flash_attn3
|
|
23
28
|
if FLASH_ATTN_2_AVAILABLE:
|
|
24
29
|
from flash_attn import flash_attn_func as flash_attn2
|
|
25
30
|
if XFORMERS_AVAILABLE:
|
|
26
|
-
from xformers.ops import memory_efficient_attention
|
|
31
|
+
from xformers.ops import memory_efficient_attention
|
|
32
|
+
|
|
33
|
+
def xformers_attn(q, k, v, attn_mask=None, scale=None):
|
|
34
|
+
if attn_mask is not None:
|
|
35
|
+
attn_mask = repeat(attn_mask, "S L -> B H S L", B=q.shape[0], H=q.shape[2])
|
|
36
|
+
attn_mask = memory_align(attn_mask)
|
|
37
|
+
return memory_efficient_attention(q, k, v, attn_bias=attn_mask, scale=scale)
|
|
38
|
+
|
|
39
|
+
|
|
27
40
|
if SDPA_AVAILABLE:
|
|
28
41
|
|
|
29
42
|
def sdpa_attn(q, k, v, attn_mask=None, scale=None):
|
|
@@ -100,7 +113,7 @@ def attention(
|
|
|
100
113
|
elif FLASH_ATTN_2_AVAILABLE:
|
|
101
114
|
return flash_attn2(q, k, v, softmax_scale=scale)
|
|
102
115
|
elif XFORMERS_AVAILABLE:
|
|
103
|
-
return xformers_attn(q, k, v,
|
|
116
|
+
return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
104
117
|
elif SDPA_AVAILABLE:
|
|
105
118
|
return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
106
119
|
else:
|
|
@@ -113,7 +126,7 @@ def attention(
|
|
|
113
126
|
elif attn_impl == "flash_attn_2":
|
|
114
127
|
return flash_attn2(q, k, v, softmax_scale=scale)
|
|
115
128
|
elif attn_impl == "xformers":
|
|
116
|
-
return xformers_attn(q, k, v,
|
|
129
|
+
return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
117
130
|
elif attn_impl == "sdpa":
|
|
118
131
|
return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
119
132
|
elif attn_impl == "sage_attn":
|
|
@@ -181,6 +194,9 @@ def long_context_attention(
|
|
|
181
194
|
k: [B, Lk, Nk, C1]
|
|
182
195
|
v: [B, Lk, Nk, C2]
|
|
183
196
|
"""
|
|
197
|
+
from yunchang import LongContextAttention
|
|
198
|
+
from yunchang.kernels import AttnType
|
|
199
|
+
|
|
184
200
|
assert attn_impl in [
|
|
185
201
|
None,
|
|
186
202
|
"auto",
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
import numpy as np
|
|
6
|
+
from einops import rearrange
|
|
7
|
+
from typing import Union, List
|
|
8
|
+
from PIL import Image
|
|
9
|
+
from diffsynth_engine.models.basic.attention import Attention
|
|
10
|
+
from diffsynth_engine.utils.loader import load_file
|
|
11
|
+
from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SiglipVisionEmbeddings(nn.Module):
|
|
15
|
+
def __init__(
|
|
16
|
+
self, num_channels: int, num_positions: int, hidden_size: int, patch_size: int, device: str, dtype: torch.dtype
|
|
17
|
+
):
|
|
18
|
+
super().__init__()
|
|
19
|
+
self.patch_embedding = nn.Conv2d(
|
|
20
|
+
in_channels=num_channels,
|
|
21
|
+
out_channels=hidden_size,
|
|
22
|
+
kernel_size=patch_size,
|
|
23
|
+
stride=patch_size,
|
|
24
|
+
padding="valid",
|
|
25
|
+
device=device,
|
|
26
|
+
dtype=dtype,
|
|
27
|
+
)
|
|
28
|
+
self.position_embedding = nn.Embedding(num_positions, hidden_size, device=device, dtype=dtype)
|
|
29
|
+
self.position_ids = torch.arange(num_positions).expand((1, -1))
|
|
30
|
+
|
|
31
|
+
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
|
|
32
|
+
target_dtype = self.patch_embedding.weight.dtype
|
|
33
|
+
target_device = self.patch_embedding.weight.device
|
|
34
|
+
self.position_ids = self.position_ids.to(target_device)
|
|
35
|
+
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
|
36
|
+
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
|
37
|
+
embeddings = embeddings + self.position_embedding(self.position_ids)
|
|
38
|
+
return embeddings
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block with tanh-approximated GELU."""

    def __init__(self, hidden_size, inner_dim, device, dtype):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, inner_dim, device=device, dtype=dtype)
        self.fc2 = nn.Linear(inner_dim, hidden_size, device=device, dtype=dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand -> non-linearity -> project back to the residual width.
        return self.fc2(F.gelu(self.fc1(hidden_states), approximate="tanh"))
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention then MLP, each with a residual."""

    def __init__(self, hidden_size: int, inner_dim: int, num_heads: int, eps: float, device: str, dtype: torch.dtype):
        super().__init__()
        # Pass device/dtype so the norms are built like every sibling submodule
        # (SiglipMLP, post_layernorm); the original left them on the default
        # fp32/CPU until a later Module.to(...).
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=eps, device=device, dtype=dtype)
        # NOTE(review): Attention is not given device/dtype here (matching the
        # original call); presumably it manages its own placement — confirm.
        self.self_attn = Attention(
            q_dim=hidden_size,
            num_heads=num_heads,
            head_dim=hidden_size // num_heads,
            bias_q=True,
            bias_kv=True,
            bias_out=True,
        )
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=eps, device=device, dtype=dtype)
        self.mlp = SiglipMLP(hidden_size=hidden_size, inner_dim=inner_dim, device=device, dtype=dtype)

    def forward(self, x):
        """Apply pre-norm self-attention and MLP sub-blocks with residual connections."""
        x = self.self_attn(self.layer_norm1(x)) + x
        x = self.mlp(self.layer_norm2(x)) + x
        return x
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling: a learned probe token attends over the
    sequence and the result is reduced to one vector per image."""

    def __init__(self, hidden_size, inner_dim, num_heads, eps, device, dtype) -> None:
        super().__init__()
        # Create the probe directly on the target device/dtype; the original left
        # it at the default fp32/CPU, inconsistent with the other submodules.
        self.probe = nn.Parameter(data=torch.randn(1, 1, hidden_size, device=device, dtype=dtype))
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, device=device, dtype=dtype
        )
        self.layernorm = nn.LayerNorm(normalized_shape=hidden_size, eps=eps, device=device, dtype=dtype)
        self.mlp = SiglipMLP(hidden_size=hidden_size, inner_dim=inner_dim, device=device, dtype=dtype)

    def forward(self, hidden_state) -> torch.Tensor:
        """Pool [B, S, H] hidden states to a single [B, H] vector."""
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)  # one probe query per batch element
        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)
        # The probe contributes exactly one token -> drop the sequence dimension.
        return hidden_state[:, 0]
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class SiglipVisionTransformer(nn.Module):
    """SigLIP vision backbone: patch embedding, encoder stack, attention pooling head.

    Defaults match the SigLIP-SO400M/14 @ 384px configuration.
    """

    def __init__(
        self,
        hidden_size: int = 1152,
        num_channels: int = 3,
        image_size: int = 384,
        patch_size: int = 14,
        layer_num: int = 27,
        inner_dim: int = 4304,
        num_heads: int = 16,
        eps: float = 1e-06,
        device: str = "cpu",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        num_positions = (image_size // patch_size) ** 2
        self.embeddings = SiglipVisionEmbeddings(
            num_channels=num_channels,
            num_positions=num_positions,
            hidden_size=hidden_size,
            patch_size=patch_size,
            device=device,
            dtype=dtype,
        )
        self.layers = nn.ModuleList(
            SiglipEncoderLayer(hidden_size, inner_dim, num_heads, eps, device, dtype) for _ in range(layer_num)
        )
        self.post_layernorm = nn.LayerNorm(hidden_size, eps=eps, device=device, dtype=dtype)
        self.head = SiglipMultiheadAttentionPoolingHead(
            hidden_size, inner_dim=inner_dim, num_heads=num_heads, eps=eps, device=device, dtype=dtype
        )

    def forward(self, x):
        """Embed patches, run the encoder stack, then pool to one vector per image."""
        hidden = self.embeddings(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
        hidden = self.post_layernorm(hidden)
        return self.head(hidden)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class SiglipImageEncoderConverter(StateDictConverter):
    """Identity converter: SigLIP checkpoints already use our key layout."""

    def convert(self, state_dict: dict) -> dict:
        # Nothing to rename — hand the mapping back untouched.
        return state_dict
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class SiglipImageEncoder(PreTrainedModel):
    """SigLIP vision tower wrapper: PIL images in, pooled embeddings out."""

    converter = SiglipImageEncoderConverter()

    def __init__(self, device: str, dtype: torch.dtype) -> None:
        super().__init__()
        self.image_encoder = SiglipVisionTransformer(device=device, dtype=dtype)

    def image_preprocess(self, images: List[Image.Image]) -> torch.Tensor:
        """Resize to 384x384, scale to [-1, 1] and stack into a BCHW tensor on the model's device/dtype."""
        # Force 3-channel RGB so RGBA/greyscale/palette inputs don't change the
        # channel count expected by the patch-embedding convolution.
        images = [image.convert("RGB").resize(size=(384, 384), resample=3) for image in images]  # 3 == PIL bicubic
        rescaled_images = [np.array(image) / 255 for image in images]
        # (x - 0.5) / 0.5 maps [0, 1] -> [-1, 1] (mean=0.5, std=0.5 per channel).
        normalized_images = [(image - 0.5) / 0.5 for image in rescaled_images]
        image_tensor = torch.stack([torch.tensor(image) for image in normalized_images])
        param = next(self.parameters())
        image_tensor = image_tensor.to(param.device, param.dtype)
        return rearrange(image_tensor, "b h w c -> b c h w")

    @torch.no_grad()
    def forward(self, images: List[Image.Image] | Image.Image):
        """Encode one image or a list of images; returns pooled features from the vision tower."""
        if isinstance(images, Image.Image):
            images = [images]
        image_input = self.image_preprocess(images)
        return self.image_encoder(image_input)

    @classmethod
    def from_pretrained(cls, pretrained_model_path: Union[str, os.PathLike], device: str, dtype: torch.dtype, **kwargs):
        """Load a safetensors checkpoint from disk and build the encoder on `device`/`dtype`."""
        state_dict = load_file(str(pretrained_model_path))
        return cls.from_state_dict(state_dict, device=device, dtype=dtype, **kwargs)
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
from .flux_dit import FluxDiT, config as flux_dit_config
|
|
2
2
|
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2, config as flux_text_encoder_config
|
|
3
3
|
from .flux_vae import FluxVAEDecoder, FluxVAEEncoder, config as flux_vae_config
|
|
4
|
+
from .flux_controlnet import FluxControlNet
|
|
4
5
|
|
|
5
6
|
__all__ = [
|
|
6
7
|
"FluxDiT",
|
|
8
|
+
"FluxControlNet",
|
|
7
9
|
"FluxTextEncoder1",
|
|
8
10
|
"FluxTextEncoder2",
|
|
9
11
|
"FluxVAEDecoder",
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from typing import Optional, Dict
|
|
4
|
+
from einops import rearrange
|
|
5
|
+
from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
|
|
6
|
+
from diffsynth_engine.models.flux.flux_dit import (
|
|
7
|
+
FluxJointTransformerBlock,
|
|
8
|
+
RoPEEmbedding,
|
|
9
|
+
TimestepEmbeddings,
|
|
10
|
+
)
|
|
11
|
+
from diffsynth_engine.models.utils import no_init_weights
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class FluxControlNetStateDictConverter(StateDictConverter):
    """Renames diffusers-format FLUX ControlNet weights to this repo's layout,
    fusing the separate Q/K/V projections into single QKV tensors."""

    def __init__(self):
        super().__init__()

    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Map diffusers keys to ours; Q/K/V weights are concatenated along dim 0."""
        new_state_dict = {}
        for key, value in state_dict.items():
            if "attn.to_q" in key:
                # Fuse the image-stream Q/K/V into one projection.
                q = value
                k = state_dict[key.replace("attn.to_q", "attn.to_k")]
                v = state_dict[key.replace("attn.to_q", "attn.to_v")]
                new_key = key.replace("transformer_blocks", "blocks").replace("attn.to_q", "attn.a_to_qkv")
                new_state_dict[new_key] = torch.cat((q, k, v), dim=0)
            elif "attn.add_q_proj" in key:
                # Fuse the text-stream Q/K/V into one projection.
                q = value
                k = state_dict[key.replace("attn.add_q_proj", "attn.add_k_proj")]
                v = state_dict[key.replace("attn.add_q_proj", "attn.add_v_proj")]
                new_key = key.replace("transformer_blocks", "blocks").replace("attn.add_q_proj", "attn.b_to_qkv")
                new_state_dict[new_key] = torch.cat((q, k, v), dim=0)
            elif (
                "attn.to_k" in key
                or "attn.to_v" in key
                or "attn.add_k_proj" in key
                or "attn.add_v_proj" in key
            ):
                # Already folded into the fused QKV tensors above.
                continue
            else:
                new_key = key
                new_key = new_key.replace("transformer_blocks", "blocks")
                new_key = new_key.replace("controlnet_blocks", "blocks_proj")
                new_key = new_key.replace("time_text_embed.guidance_embedder", "guidance_embedder")
                new_key = new_key.replace("time_text_embed.timestep_embedder", "time_embedder")
                new_key = new_key.replace("time_text_embed.text_embedder.linear_1", "pooled_text_embedder.0")
                new_key = new_key.replace("time_text_embed.text_embedder.linear_2", "pooled_text_embedder.2")
                # These run after the time_text_embed renames above, so they match
                # the already-shortened embedder prefixes.
                new_key = new_key.replace("time_embedder.linear_1", "time_embedder.timestep_embedder.0")
                new_key = new_key.replace("time_embedder.linear_2", "time_embedder.timestep_embedder.2")
                new_key = new_key.replace("guidance_embedder.linear_1", "guidance_embedder.timestep_embedder.0")
                new_key = new_key.replace("guidance_embedder.linear_2", "guidance_embedder.timestep_embedder.2")
                # joint block
                new_key = new_key.replace("norm1.linear", "norm1_a.linear")
                new_key = new_key.replace("norm1_context.linear", "norm1_b.linear")
                new_key = new_key.replace("attn.to_out.0", "attn.a_to_out")
                new_key = new_key.replace("attn.to_add_out", "attn.b_to_out")
                new_key = new_key.replace("attn.norm_q", "attn.norm_q_a")
                new_key = new_key.replace("attn.norm_k", "attn.norm_k_a")
                new_key = new_key.replace("attn.norm_added_q", "attn.norm_q_b")
                new_key = new_key.replace("attn.norm_added_k", "attn.norm_k_b")
                new_key = new_key.replace("ff.net", "ff_a")
                new_key = new_key.replace("ff_context.net", "ff_b")
                new_key = new_key.replace("0.proj", "0")
                new_state_dict[new_key] = value
        return new_state_dict

    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Entry point: diffusers layout is the only supported input format."""
        return self._from_diffusers(state_dict)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class FluxControlNet(PreTrainedModel):
    """FLUX ControlNet: a shallow stack of joint transformer blocks that turns a
    control condition image into per-double-block residuals for the main DiT."""

    converter = FluxControlNetStateDictConverter()

    def __init__(
        self,
        condition_channels: int = 64,
        attn_impl: Optional[str] = None,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.time_embedder = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
        self.guidance_embedder = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
        self.pooled_text_embedder = nn.Sequential(
            nn.Linear(768, 3072, device=device, dtype=dtype),
            nn.SiLU(),
            nn.Linear(3072, 3072, device=device, dtype=dtype),
        )
        self.context_embedder = nn.Linear(4096, 3072, device=device, dtype=dtype)
        self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
        # Pass device/dtype here too; the original left this single layer on the
        # default device/dtype, relying on a later .to() in from_state_dict.
        self.controlnet_x_embedder = nn.Linear(condition_channels, 3072, device=device, dtype=dtype)
        self.blocks = nn.ModuleList(
            [FluxJointTransformerBlock(3072, 24, attn_impl=attn_impl, device=device, dtype=dtype) for _ in range(6)]
        )
        # controlnet projection: one linear head per double block output
        self.blocks_proj = nn.ModuleList(
            [nn.Linear(3072, 3072, device=device, dtype=dtype) for _ in range(len(self.blocks))]
        )

    def patchify(self, hidden_states):
        # 2x2 pixel patches -> sequence of patch tokens.
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        return hidden_states

    def forward(
        self,
        hidden_states,
        control_condition,
        control_scale,
        timestep,
        prompt_emb,
        pooled_prompt_emb,
        guidance,
        image_ids,
        text_ids,
    ):
        """Return (double_block_outputs, single_block_outputs).

        double_block_outputs is a list of per-block residual tensors already
        multiplied by `control_scale`; single-block outputs are not produced by
        this architecture, so the second element is always None.
        """
        hidden_states = self.patchify(hidden_states)
        control_condition = self.patchify(control_condition)
        # Inject the control condition additively at the embedding stage.
        hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition)
        condition = (
            self.time_embedder(timestep, hidden_states.dtype)
            + self.guidance_embedder(guidance * 1000, hidden_states.dtype)
            + self.pooled_text_embedder(pooled_prompt_emb)
        )
        prompt_emb = self.context_embedder(prompt_emb)
        # RoPE over the concatenated text+image token ids, matching the DiT.
        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

        # double block
        double_block_outputs = []
        for i, block in enumerate(self.blocks):
            hidden_states, prompt_emb = block(hidden_states, prompt_emb, condition, image_rotary_emb)
            double_block_outputs.append(self.blocks_proj[i](hidden_states))

        # apply control scale
        double_block_outputs = [control_scale * output for output in double_block_outputs]
        return double_block_outputs, None

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
        attn_impl: Optional[str] = None,
    ):
        """Build the model from an already-converted state dict.

        The condition channel count is inferred from the checkpoint so both
        64-channel (latent) and other condition layouts load correctly.
        """
        if "controlnet_x_embedder.weight" in state_dict:
            condition_channels = state_dict["controlnet_x_embedder.weight"].shape[1]
        else:
            condition_channels = 64

        # skip_init avoids wasting time on random init that load_state_dict overwrites.
        with no_init_weights():
            model = torch.nn.utils.skip_init(
                cls, condition_channels=condition_channels, attn_impl=attn_impl, device=device, dtype=dtype
            )
        model.load_state_dict(state_dict)
        model.to(device=device, dtype=dtype, non_blocking=True)
        return model
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import torch
|
|
3
3
|
import torch.nn as nn
|
|
4
|
+
import numpy as np
|
|
4
5
|
from typing import Dict, Optional
|
|
5
6
|
from einops import rearrange
|
|
6
7
|
|
|
@@ -327,7 +328,6 @@ class FluxDiT(PreTrainedModel):
|
|
|
327
328
|
|
|
328
329
|
def __init__(
|
|
329
330
|
self,
|
|
330
|
-
disable_guidance_embedder=False,
|
|
331
331
|
attn_impl: Optional[str] = None,
|
|
332
332
|
device: str = "cuda:0",
|
|
333
333
|
dtype: torch.dtype = torch.bfloat16,
|
|
@@ -335,9 +335,7 @@ class FluxDiT(PreTrainedModel):
|
|
|
335
335
|
super().__init__()
|
|
336
336
|
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
|
337
337
|
self.time_embedder = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
|
|
338
|
-
self.guidance_embedder = (
|
|
339
|
-
None if disable_guidance_embedder else TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
|
|
340
|
-
)
|
|
338
|
+
self.guidance_embedder = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
|
|
341
339
|
self.pooled_text_embedder = nn.Sequential(
|
|
342
340
|
nn.Linear(768, 3072, device=device, dtype=dtype),
|
|
343
341
|
nn.SiLU(),
|
|
@@ -392,6 +390,8 @@ class FluxDiT(PreTrainedModel):
|
|
|
392
390
|
text_ids,
|
|
393
391
|
image_ids=None,
|
|
394
392
|
use_gradient_checkpointing=False,
|
|
393
|
+
controlnet_double_block_output=None,
|
|
394
|
+
controlnet_single_block_output=None,
|
|
395
395
|
**kwargs,
|
|
396
396
|
):
|
|
397
397
|
fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
|
|
@@ -413,16 +413,10 @@ class FluxDiT(PreTrainedModel):
|
|
|
413
413
|
hidden_states = self.patchify(hidden_states)
|
|
414
414
|
hidden_states = self.x_embedder(hidden_states)
|
|
415
415
|
|
|
416
|
-
|
|
417
|
-
def custom_forward(*inputs):
|
|
418
|
-
return module(*inputs)
|
|
419
|
-
|
|
420
|
-
return custom_forward
|
|
421
|
-
|
|
422
|
-
for block in self.blocks:
|
|
416
|
+
for i, block in enumerate(self.blocks):
|
|
423
417
|
if self.training and use_gradient_checkpointing:
|
|
424
418
|
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
|
425
|
-
|
|
419
|
+
block,
|
|
426
420
|
hidden_states,
|
|
427
421
|
prompt_emb,
|
|
428
422
|
conditioning,
|
|
@@ -431,12 +425,16 @@ class FluxDiT(PreTrainedModel):
|
|
|
431
425
|
)
|
|
432
426
|
else:
|
|
433
427
|
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
|
428
|
+
if controlnet_double_block_output is not None:
|
|
429
|
+
interval_control = len(self.blocks) / len(controlnet_double_block_output)
|
|
430
|
+
interval_control = int(np.ceil(interval_control))
|
|
431
|
+
hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
|
|
434
432
|
|
|
435
433
|
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
|
436
434
|
for block in self.single_blocks:
|
|
437
435
|
if self.training and use_gradient_checkpointing:
|
|
438
436
|
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
|
439
|
-
|
|
437
|
+
block,
|
|
440
438
|
hidden_states,
|
|
441
439
|
prompt_emb,
|
|
442
440
|
conditioning,
|
|
@@ -445,12 +443,15 @@ class FluxDiT(PreTrainedModel):
|
|
|
445
443
|
)
|
|
446
444
|
else:
|
|
447
445
|
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
|
448
|
-
|
|
446
|
+
if controlnet_single_block_output is not None:
|
|
447
|
+
interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
|
|
448
|
+
interval_control = int(np.ceil(interval_control))
|
|
449
|
+
hidden_states = hidden_states + controlnet_single_block_output[i // interval_control]
|
|
449
450
|
|
|
451
|
+
hidden_states = hidden_states[:, prompt_emb.shape[1] :]
|
|
450
452
|
hidden_states = self.final_norm_out(hidden_states, conditioning)
|
|
451
453
|
hidden_states = self.final_proj_out(hidden_states)
|
|
452
454
|
hidden_states = self.unpatchify(hidden_states, height, width)
|
|
453
|
-
|
|
454
455
|
return hidden_states
|
|
455
456
|
|
|
456
457
|
@classmethod
|
|
@@ -459,7 +460,6 @@ class FluxDiT(PreTrainedModel):
|
|
|
459
460
|
state_dict: Dict[str, torch.Tensor],
|
|
460
461
|
device: str,
|
|
461
462
|
dtype: torch.dtype,
|
|
462
|
-
disable_guidance_embedder: bool = False,
|
|
463
463
|
attn_impl: Optional[str] = None,
|
|
464
464
|
):
|
|
465
465
|
with no_init_weights():
|
|
@@ -467,7 +467,6 @@ class FluxDiT(PreTrainedModel):
|
|
|
467
467
|
cls,
|
|
468
468
|
device=device,
|
|
469
469
|
dtype=dtype,
|
|
470
|
-
disable_guidance_embedder=disable_guidance_embedder,
|
|
471
470
|
attn_impl=attn_impl,
|
|
472
471
|
)
|
|
473
472
|
model = model.requires_grad_(False) # for loading gguf
|
|
@@ -268,16 +268,10 @@ class SD3DiT(PreTrainedModel):
|
|
|
268
268
|
height, width = hidden_states.shape[-2:]
|
|
269
269
|
hidden_states = self.pos_embedder(hidden_states)
|
|
270
270
|
|
|
271
|
-
def create_custom_forward(module):
|
|
272
|
-
def custom_forward(*inputs):
|
|
273
|
-
return module(*inputs)
|
|
274
|
-
|
|
275
|
-
return custom_forward
|
|
276
|
-
|
|
277
271
|
for block in self.blocks:
|
|
278
272
|
if self.training and use_gradient_checkpointing:
|
|
279
273
|
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
|
280
|
-
|
|
274
|
+
block,
|
|
281
275
|
hidden_states,
|
|
282
276
|
prompt_emb,
|
|
283
277
|
conditioning,
|
|
@@ -260,12 +260,6 @@ class SDXLUNet(PreTrainedModel):
|
|
|
260
260
|
res_stack = [hidden_states]
|
|
261
261
|
|
|
262
262
|
# 3. blocks
|
|
263
|
-
def create_custom_forward(module):
|
|
264
|
-
def custom_forward(*inputs):
|
|
265
|
-
return module(*inputs)
|
|
266
|
-
|
|
267
|
-
return custom_forward
|
|
268
|
-
|
|
269
263
|
for i, block in enumerate(self.blocks):
|
|
270
264
|
if (
|
|
271
265
|
self.training
|
|
@@ -273,7 +267,7 @@ class SDXLUNet(PreTrainedModel):
|
|
|
273
267
|
and not (isinstance(block, PushBlock) or isinstance(block, PopBlock))
|
|
274
268
|
):
|
|
275
269
|
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
|
|
276
|
-
|
|
270
|
+
block,
|
|
277
271
|
hidden_states,
|
|
278
272
|
time_emb,
|
|
279
273
|
text_emb,
|
|
@@ -166,6 +166,7 @@ class CrossAttention(nn.Module):
|
|
|
166
166
|
if self.has_image_input:
|
|
167
167
|
k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
|
|
168
168
|
k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
|
|
169
|
+
v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
|
|
169
170
|
y = attention(q, k_img, v_img, attn_impl=self.attn_impl).flatten(2)
|
|
170
171
|
x = x + y
|
|
171
172
|
return self.o(x)
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from .base import BasePipeline, LoRAStateDictConverter
|
|
2
|
-
from .flux_image import FluxImagePipeline, FluxModelConfig
|
|
2
|
+
from .flux_image import FluxImagePipeline, FluxModelConfig, ControlNetParams
|
|
3
3
|
from .sdxl_image import SDXLImagePipeline, SDXLModelConfig
|
|
4
4
|
from .sd_image import SDImagePipeline, SDModelConfig
|
|
5
5
|
from .wan_video import WanVideoPipeline, WanModelConfig
|
|
@@ -15,4 +15,5 @@ __all__ = [
|
|
|
15
15
|
"SDModelConfig",
|
|
16
16
|
"WanVideoPipeline",
|
|
17
17
|
"WanModelConfig",
|
|
18
|
+
"ControlNetParams",
|
|
18
19
|
]
|