diffsynth-engine 0.2.4__tar.gz → 0.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/PKG-INFO +2 -1
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/README.md +9 -2
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/attention.py +0 -2
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/flux_dit.py +70 -26
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/utils.py +0 -1
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/vae/vae.py +4 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/wan/wan_dit.py +17 -42
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/pipelines/base.py +65 -10
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/pipelines/flux_image.py +137 -25
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/pipelines/sd_image.py +3 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/pipelines/sdxl_image.py +3 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/pipelines/wan_video.py +11 -39
- diffsynth_engine-0.2.6/diffsynth_engine/processor/canny_processor.py +21 -0
- diffsynth_engine-0.2.6/diffsynth_engine/processor/depth_processor.py +42 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tools/flux_inpainting_tool.py +3 -1
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tools/flux_outpainting_tool.py +3 -1
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tools/flux_reference_tool.py +1 -1
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tools/flux_replace_tool.py +1 -1
- diffsynth_engine-0.2.6/diffsynth_engine/utils/__init__.py +0 -0
- diffsynth_engine-0.2.6/diffsynth_engine/utils/onnx.py +33 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/parallel.py +64 -2
- diffsynth_engine-0.2.6/diffsynth_engine/utils/platform.py +12 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine.egg-info/PKG-INFO +2 -1
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine.egg-info/SOURCES.txt +5 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine.egg-info/requires.txt +1 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/pyproject.toml +2 -1
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/.gitignore +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/.pre-commit-config.yaml +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/LICENSE +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/MANIFEST.in +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/assets/dingtalk.png +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/assets/showcase.jpeg +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/components/vae.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/wan/dit/14b-flf2v.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/wan/dit/14b-i2v.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/models/wan/dit/14b-t2v.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/kernels/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/base.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/lora.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/timestep.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/basic/unet_helper.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/flux_controlnet.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/flux_ipadapter.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/flux_redux.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/flux/flux_vae.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd/sd_unet.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd/sd_vae.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd3/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sdxl/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/text_encoder/clip.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/text_encoder/t5.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/vae/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/wan/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/models/wan/wan_vae.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/pipelines/__init__.py +0 -0
- {diffsynth_engine-0.2.4/diffsynth_engine/utils → diffsynth_engine-0.2.6/diffsynth_engine/processor}/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tokenizers/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tokenizers/base.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tokenizers/clip.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tokenizers/t5.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tokenizers/wan.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/tools/__init__.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/constants.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/download.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/env.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/flag.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/fp8_linear.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/gguf.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/image.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/loader.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/lock.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/logging.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/offload.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/prompt.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine/utils/video.py +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/diffsynth_engine.egg-info/top_level.txt +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/docs/tutorial.md +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/docs/tutorial_zh.md +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/setup.cfg +0 -0
- {diffsynth_engine-0.2.4 → diffsynth_engine-0.2.6}/setup.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: diffsynth_engine
|
|
3
|
-
Version: 0.2.4
|
|
3
|
+
Version: 0.2.6
|
|
4
4
|
Author: MuseAI x ModelScope
|
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
|
6
6
|
Classifier: Operating System :: OS Independent
|
|
@@ -23,6 +23,7 @@ Requires-Dist: torchsde
|
|
|
23
23
|
Requires-Dist: pillow
|
|
24
24
|
Requires-Dist: imageio[ffmpeg]
|
|
25
25
|
Requires-Dist: yunchang; sys_platform == "linux"
|
|
26
|
+
Requires-Dist: onnxruntime
|
|
26
27
|
Provides-Extra: dev
|
|
27
28
|
Requires-Dist: diffusers==0.31.0; extra == "dev"
|
|
28
29
|
Requires-Dist: transformers==4.45.2; extra == "dev"
|
|
@@ -45,7 +45,7 @@ Text to image
|
|
|
45
45
|
```python
|
|
46
46
|
from diffsynth_engine import fetch_model, FluxImagePipeline
|
|
47
47
|
|
|
48
|
-
model_path = fetch_model("muse/flux-with-vae", path="
|
|
48
|
+
model_path = fetch_model("muse/flux-with-vae", path="flux1-dev-with-vae.safetensors")
|
|
49
49
|
pipe = FluxImagePipeline.from_pretrained(model_path, device='cuda:0')
|
|
50
50
|
image = pipe(prompt="a cat")
|
|
51
51
|
image.save("image.png")
|
|
@@ -54,7 +54,7 @@ Text to image with LoRA
|
|
|
54
54
|
```python
|
|
55
55
|
from diffsynth_engine import fetch_model, FluxImagePipeline
|
|
56
56
|
|
|
57
|
-
model_path = fetch_model("muse/flux-with-vae", path="
|
|
57
|
+
model_path = fetch_model("muse/flux-with-vae", path="flux1-dev-with-vae.safetensors")
|
|
58
58
|
lora_path = fetch_model("DonRat/MAJICFLUS_SuperChinesestyleheongsam", path="麦橘超国风旗袍.safetensors")
|
|
59
59
|
|
|
60
60
|
pipe = FluxImagePipeline.from_pretrained(model_path, device='cuda:0')
|
|
@@ -77,6 +77,13 @@ If you have any questions or feedback, please scan the QR code below, or send em
|
|
|
77
77
|
<img src="assets/dingtalk.png" alt="dingtalk" width="400" />
|
|
78
78
|
</div>
|
|
79
79
|
|
|
80
|
+
## Contributing
|
|
81
|
+
We welcome contributions to DiffSynth-Engine. After installing from source, we recommend that developers install this project with the following command to set up the development environment.
|
|
82
|
+
```bash
|
|
83
|
+
pip install -e '.[dev]'
|
|
84
|
+
```
|
|
85
|
+
TODO: Please refer to [CONTRIBUTING.md](./CONTRIBUTING.md) for more details.
|
|
86
|
+
|
|
80
87
|
## License
|
|
81
88
|
This project is licensed under the Apache License 2.0. See the LICENSE file for details.
|
|
82
89
|
|
|
@@ -13,11 +13,12 @@ from diffsynth_engine.models.basic.transformer_helper import (
|
|
|
13
13
|
)
|
|
14
14
|
from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
|
|
15
15
|
from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
|
|
16
|
+
from diffsynth_engine.models.basic import attention as attention_ops
|
|
16
17
|
from diffsynth_engine.models.utils import no_init_weights
|
|
17
18
|
from diffsynth_engine.utils.gguf import gguf_inference
|
|
18
19
|
from diffsynth_engine.utils.fp8_linear import fp8_inference
|
|
19
20
|
from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
|
|
20
|
-
from diffsynth_engine.
|
|
21
|
+
from diffsynth_engine.utils.parallel import sequence_parallel, sequence_parallel_unshard
|
|
21
22
|
from diffsynth_engine.utils import logging
|
|
22
23
|
|
|
23
24
|
|
|
@@ -198,7 +199,7 @@ class FluxDoubleAttention(nn.Module):
|
|
|
198
199
|
k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
|
|
199
200
|
v = torch.cat([v_b, v_a], dim=1)
|
|
200
201
|
q, k = apply_rope(q, k, rope_emb)
|
|
201
|
-
attn_out = attention(q, k, v, attn_impl=self.attn_impl)
|
|
202
|
+
attn_out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl)
|
|
202
203
|
attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
|
|
203
204
|
text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
|
|
204
205
|
image_out, text_out = self.attention_callback(
|
|
@@ -286,7 +287,7 @@ class FluxSingleAttention(nn.Module):
|
|
|
286
287
|
def forward(self, x, rope_emb, image_emb):
|
|
287
288
|
q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
|
|
288
289
|
q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
|
|
289
|
-
attn_out = attention(q, k, v, attn_impl=self.attn_impl)
|
|
290
|
+
attn_out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl)
|
|
290
291
|
attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
|
|
291
292
|
return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)
|
|
292
293
|
|
|
@@ -322,7 +323,9 @@ class FluxDiT(PreTrainedModel):
|
|
|
322
323
|
|
|
323
324
|
def __init__(
|
|
324
325
|
self,
|
|
326
|
+
in_channel: int = 64,
|
|
325
327
|
attn_impl: Optional[str] = None,
|
|
328
|
+
use_usp: bool = False,
|
|
326
329
|
device: str = "cuda:0",
|
|
327
330
|
dtype: torch.dtype = torch.bfloat16,
|
|
328
331
|
):
|
|
@@ -336,7 +339,8 @@ class FluxDiT(PreTrainedModel):
|
|
|
336
339
|
nn.Linear(3072, 3072, device=device, dtype=dtype),
|
|
337
340
|
)
|
|
338
341
|
self.context_embedder = nn.Linear(4096, 3072, device=device, dtype=dtype)
|
|
339
|
-
|
|
342
|
+
# normal flux has 64 channels, bfl canny and depth has 128 channels, bfl fill has 384 channels, bfl redux has 64 channels
|
|
343
|
+
self.x_embedder = nn.Linear(in_channel, 3072, device=device, dtype=dtype)
|
|
340
344
|
|
|
341
345
|
self.blocks = nn.ModuleList(
|
|
342
346
|
[FluxDoubleTransformerBlock(3072, 24, attn_impl=attn_impl, device=device, dtype=dtype) for _ in range(19)]
|
|
@@ -347,6 +351,8 @@ class FluxDiT(PreTrainedModel):
|
|
|
347
351
|
self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
|
|
348
352
|
self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
|
|
349
353
|
|
|
354
|
+
self.use_usp = use_usp
|
|
355
|
+
|
|
350
356
|
def patchify(self, hidden_states):
|
|
351
357
|
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
|
352
358
|
return hidden_states
|
|
@@ -357,7 +363,8 @@ class FluxDiT(PreTrainedModel):
|
|
|
357
363
|
)
|
|
358
364
|
return hidden_states
|
|
359
365
|
|
|
360
|
-
|
|
366
|
+
@staticmethod
|
|
367
|
+
def prepare_image_ids(latents: torch.Tensor):
|
|
361
368
|
batch_size, _, height, width = latents.shape
|
|
362
369
|
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
|
363
370
|
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
|
@@ -387,7 +394,14 @@ class FluxDiT(PreTrainedModel):
|
|
|
387
394
|
controlnet_single_block_output=None,
|
|
388
395
|
**kwargs,
|
|
389
396
|
):
|
|
390
|
-
|
|
397
|
+
h, w = hidden_states.shape[-2:]
|
|
398
|
+
controlnet_double_block_output = (
|
|
399
|
+
controlnet_double_block_output if controlnet_double_block_output is not None else ()
|
|
400
|
+
)
|
|
401
|
+
controlnet_single_block_output = (
|
|
402
|
+
controlnet_single_block_output if controlnet_single_block_output is not None else ()
|
|
403
|
+
)
|
|
404
|
+
|
|
391
405
|
fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
|
|
392
406
|
with fp8_inference(fp8_linear_enabled), gguf_inference():
|
|
393
407
|
if image_ids is None:
|
|
@@ -400,28 +414,54 @@ class FluxDiT(PreTrainedModel):
|
|
|
400
414
|
guidance = guidance * 1000
|
|
401
415
|
conditioning += self.guidance_embedder(guidance, hidden_states.dtype)
|
|
402
416
|
conditioning += self.pooled_text_embedder(pooled_prompt_emb)
|
|
403
|
-
prompt_emb = self.context_embedder(prompt_emb)
|
|
404
417
|
rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
|
418
|
+
text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
|
|
419
|
+
image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
|
|
405
420
|
hidden_states = self.patchify(hidden_states)
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
421
|
+
|
|
422
|
+
with sequence_parallel(
|
|
423
|
+
(
|
|
424
|
+
hidden_states,
|
|
425
|
+
prompt_emb,
|
|
426
|
+
text_rope_emb,
|
|
427
|
+
image_rope_emb,
|
|
428
|
+
*controlnet_double_block_output,
|
|
429
|
+
*controlnet_single_block_output,
|
|
430
|
+
),
|
|
431
|
+
seq_dims=(
|
|
432
|
+
1,
|
|
433
|
+
1,
|
|
434
|
+
2,
|
|
435
|
+
2,
|
|
436
|
+
*(1 for _ in controlnet_double_block_output),
|
|
437
|
+
*(1 for _ in controlnet_single_block_output),
|
|
438
|
+
),
|
|
439
|
+
enabled=self.use_usp,
|
|
440
|
+
):
|
|
441
|
+
hidden_states = self.x_embedder(hidden_states)
|
|
442
|
+
prompt_emb = self.context_embedder(prompt_emb)
|
|
443
|
+
rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)
|
|
444
|
+
|
|
445
|
+
for i, block in enumerate(self.blocks):
|
|
446
|
+
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
|
|
447
|
+
if len(controlnet_double_block_output) > 0:
|
|
448
|
+
interval_control = len(self.blocks) / len(controlnet_double_block_output)
|
|
449
|
+
interval_control = int(np.ceil(interval_control))
|
|
450
|
+
hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
|
|
451
|
+
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
|
452
|
+
for i, block in enumerate(self.single_blocks):
|
|
453
|
+
hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
|
|
454
|
+
if len(controlnet_single_block_output) > 0:
|
|
455
|
+
interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
|
|
456
|
+
interval_control = int(np.ceil(interval_control))
|
|
457
|
+
hidden_states = hidden_states + controlnet_single_block_output[i // interval_control]
|
|
458
|
+
|
|
459
|
+
hidden_states = hidden_states[:, prompt_emb.shape[1] :]
|
|
460
|
+
hidden_states = self.final_norm_out(hidden_states, conditioning)
|
|
461
|
+
hidden_states = self.final_proj_out(hidden_states)
|
|
462
|
+
(hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))
|
|
463
|
+
|
|
464
|
+
hidden_states = self.unpatchify(hidden_states, h, w)
|
|
425
465
|
return hidden_states
|
|
426
466
|
|
|
427
467
|
@classmethod
|
|
@@ -430,14 +470,18 @@ class FluxDiT(PreTrainedModel):
|
|
|
430
470
|
state_dict: Dict[str, torch.Tensor],
|
|
431
471
|
device: str,
|
|
432
472
|
dtype: torch.dtype,
|
|
473
|
+
in_channel: int = 64,
|
|
433
474
|
attn_impl: Optional[str] = None,
|
|
475
|
+
use_usp: bool = False,
|
|
434
476
|
):
|
|
435
477
|
with no_init_weights():
|
|
436
478
|
model = torch.nn.utils.skip_init(
|
|
437
479
|
cls,
|
|
438
480
|
device=device,
|
|
439
481
|
dtype=dtype,
|
|
482
|
+
in_channel=in_channel,
|
|
440
483
|
attn_impl=attn_impl,
|
|
484
|
+
use_usp=use_usp,
|
|
441
485
|
)
|
|
442
486
|
model = model.requires_grad_(False) # for loading gguf
|
|
443
487
|
model.load_state_dict(state_dict, assign=True)
|
|
@@ -167,6 +167,8 @@ class VAEDecoder(PreTrainedModel):
|
|
|
167
167
|
self.conv_norm_out = nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6, device=device, dtype=dtype)
|
|
168
168
|
self.conv_act = nn.SiLU()
|
|
169
169
|
self.conv_out = nn.Conv2d(128, 3, kernel_size=3, padding=1, device=device, dtype=dtype)
|
|
170
|
+
self.device = device
|
|
171
|
+
self.dtype = dtype
|
|
170
172
|
|
|
171
173
|
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
|
172
174
|
original_dtype = sample.dtype
|
|
@@ -277,6 +279,8 @@ class VAEEncoder(PreTrainedModel):
|
|
|
277
279
|
self.conv_norm_out = nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6, device=device, dtype=dtype)
|
|
278
280
|
self.conv_act = nn.SiLU()
|
|
279
281
|
self.conv_out = nn.Conv2d(512, 2 * latent_channels, kernel_size=3, padding=1, device=device, dtype=dtype)
|
|
282
|
+
self.device = device
|
|
283
|
+
self.dtype = dtype
|
|
280
284
|
|
|
281
285
|
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
|
282
286
|
original_dtype = sample.dtype
|
|
@@ -2,12 +2,11 @@ import math
|
|
|
2
2
|
import json
|
|
3
3
|
import torch
|
|
4
4
|
import torch.nn as nn
|
|
5
|
-
import torch.distributed as dist
|
|
6
5
|
from typing import Tuple, Optional
|
|
7
6
|
from einops import rearrange
|
|
8
7
|
|
|
9
8
|
from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
|
|
10
|
-
from diffsynth_engine.models.basic
|
|
9
|
+
from diffsynth_engine.models.basic import attention as attention_ops
|
|
11
10
|
from diffsynth_engine.models.basic.transformer_helper import RMSNorm
|
|
12
11
|
from diffsynth_engine.models.utils import no_init_weights
|
|
13
12
|
from diffsynth_engine.utils.constants import (
|
|
@@ -17,11 +16,7 @@ from diffsynth_engine.utils.constants import (
|
|
|
17
16
|
WAN_DIT_14B_FLF2V_CONFIG_FILE,
|
|
18
17
|
)
|
|
19
18
|
from diffsynth_engine.utils.gguf import gguf_inference
|
|
20
|
-
from diffsynth_engine.utils.parallel import
|
|
21
|
-
get_sp_group,
|
|
22
|
-
get_sp_world_size,
|
|
23
|
-
get_sp_rank,
|
|
24
|
-
)
|
|
19
|
+
from diffsynth_engine.utils.parallel import sequence_parallel, sequence_parallel_unshard
|
|
25
20
|
|
|
26
21
|
T5_TOKEN_NUM = 512
|
|
27
22
|
FLF_TOKEN_NUM = 257 * 2
|
|
@@ -90,20 +85,12 @@ class SelfAttention(nn.Module):
|
|
|
90
85
|
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
|
91
86
|
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
|
92
87
|
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
)
|
|
100
|
-
else:
|
|
101
|
-
x = attention(
|
|
102
|
-
q=rope_apply(q, freqs),
|
|
103
|
-
k=rope_apply(k, freqs),
|
|
104
|
-
v=v,
|
|
105
|
-
attn_impl=self.attn_impl,
|
|
106
|
-
)
|
|
88
|
+
x = attention_ops.attention(
|
|
89
|
+
q=rope_apply(q, freqs),
|
|
90
|
+
k=rope_apply(k, freqs),
|
|
91
|
+
v=v,
|
|
92
|
+
attn_impl=self.attn_impl,
|
|
93
|
+
)
|
|
107
94
|
x = x.flatten(2)
|
|
108
95
|
return self.o(x)
|
|
109
96
|
|
|
@@ -148,12 +135,12 @@ class CrossAttention(nn.Module):
|
|
|
148
135
|
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
|
149
136
|
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
|
150
137
|
|
|
151
|
-
x = attention(q, k, v, attn_impl=self.attn_impl).flatten(2)
|
|
138
|
+
x = attention_ops.attention(q, k, v, attn_impl=self.attn_impl).flatten(2)
|
|
152
139
|
if self.has_image_input:
|
|
153
140
|
k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
|
|
154
141
|
k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
|
|
155
142
|
v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
|
|
156
|
-
y = attention(q, k_img, v_img, attn_impl=self.attn_impl).flatten(2)
|
|
143
|
+
y = attention_ops.attention(q, k_img, v_img, attn_impl=self.attn_impl).flatten(2)
|
|
157
144
|
x = x + y
|
|
158
145
|
return self.o(x)
|
|
159
146
|
|
|
@@ -316,10 +303,7 @@ class WanDiT(PreTrainedModel):
|
|
|
316
303
|
if has_image_input:
|
|
317
304
|
self.img_emb = MLP(1280, dim, flf_pos_emb, device=device, dtype=dtype) # clip_feature_dim = 1280
|
|
318
305
|
|
|
319
|
-
|
|
320
|
-
setattr(self, "use_usp", True)
|
|
321
|
-
for block in self.blocks:
|
|
322
|
-
setattr(block.self_attn, "use_usp", True)
|
|
306
|
+
self.use_usp = use_usp
|
|
323
307
|
|
|
324
308
|
def patchify(self, x: torch.Tensor):
|
|
325
309
|
x = self.patch_embedding(x) # b c f h w -> b 4c f h/2 w/2
|
|
@@ -368,21 +352,12 @@ class WanDiT(PreTrainedModel):
|
|
|
368
352
|
.reshape(f * h * w, 1, -1)
|
|
369
353
|
.to(x.device)
|
|
370
354
|
)
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
for block in self.blocks:
|
|
378
|
-
x = block(x, context, t_mod, freqs)
|
|
379
|
-
x = self.head(x, t)
|
|
380
|
-
|
|
381
|
-
if getattr(self, "use_usp", False):
|
|
382
|
-
b, d = x.size(0), x.size(2) # (batch_size, out_dim)
|
|
383
|
-
xs = [torch.zeros((b, s, d), dtype=x.dtype, device=x.device) for s in split_size]
|
|
384
|
-
dist.all_gather(xs, x, group=get_sp_group())
|
|
385
|
-
x = torch.concat(xs, dim=1)
|
|
355
|
+
|
|
356
|
+
with sequence_parallel([x, freqs], seq_dims=(1, 0), enabled=self.use_usp):
|
|
357
|
+
for block in self.blocks:
|
|
358
|
+
x = block(x, context, t_mod, freqs)
|
|
359
|
+
x = self.head(x, t)
|
|
360
|
+
(x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
|
|
386
361
|
x = self.unpatchify(x, (f, h, w))
|
|
387
362
|
return x
|
|
388
363
|
|
|
@@ -4,10 +4,11 @@ import numpy as np
|
|
|
4
4
|
from typing import Dict, List, Tuple
|
|
5
5
|
from PIL import Image
|
|
6
6
|
from dataclasses import dataclass
|
|
7
|
-
from diffsynth_engine.utils.loader import load_file
|
|
8
7
|
from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
|
|
9
8
|
from diffsynth_engine.utils.gguf import load_gguf_checkpoint
|
|
10
9
|
from diffsynth_engine.utils import logging
|
|
10
|
+
from diffsynth_engine.utils.loader import load_file
|
|
11
|
+
from diffsynth_engine.utils.platform import empty_cache
|
|
11
12
|
|
|
12
13
|
logger = logging.get_logger(__name__)
|
|
13
14
|
|
|
@@ -25,14 +26,21 @@ class LoRAStateDictConverter:
|
|
|
25
26
|
class BasePipeline:
|
|
26
27
|
lora_converter = LoRAStateDictConverter()
|
|
27
28
|
|
|
28
|
-
def __init__(
|
|
29
|
+
def __init__(
|
|
30
|
+
self,
|
|
31
|
+
vae_tiled: bool = False,
|
|
32
|
+
vae_tile_size: int = -1,
|
|
33
|
+
vae_tile_stride: int = -1,
|
|
34
|
+
device="cuda:0",
|
|
35
|
+
dtype=torch.float16,
|
|
36
|
+
):
|
|
29
37
|
super().__init__()
|
|
30
|
-
self.device = device
|
|
31
|
-
self.dtype = dtype
|
|
32
|
-
self.offload_mode = None
|
|
33
38
|
self.vae_tiled = vae_tiled
|
|
34
39
|
self.vae_tile_size = vae_tile_size
|
|
35
40
|
self.vae_tile_stride = vae_tile_stride
|
|
41
|
+
self.device = device
|
|
42
|
+
self.dtype = dtype
|
|
43
|
+
self.offload_mode = None
|
|
36
44
|
self.model_names = []
|
|
37
45
|
|
|
38
46
|
@classmethod
|
|
@@ -144,6 +152,7 @@ class BasePipeline:
|
|
|
144
152
|
return noise
|
|
145
153
|
|
|
146
154
|
def encode_image(self, image: torch.Tensor) -> torch.Tensor:
|
|
155
|
+
image = image.to(self.device, self.vae_encoder.dtype)
|
|
147
156
|
latents = self.vae_encoder(
|
|
148
157
|
image, tiled=self.vae_tiled, tile_size=self.vae_tile_size, tile_stride=self.vae_tile_stride
|
|
149
158
|
)
|
|
@@ -151,8 +160,9 @@ class BasePipeline:
|
|
|
151
160
|
|
|
152
161
|
def decode_image(self, latent: torch.Tensor) -> torch.Tensor:
|
|
153
162
|
vae_dtype = self.vae_decoder.conv_in.weight.dtype
|
|
163
|
+
latent = latent.to(self.device, vae_dtype)
|
|
154
164
|
image = self.vae_decoder(
|
|
155
|
-
latent
|
|
165
|
+
latent, tiled=self.vae_tiled, tile_size=self.vae_tile_size, tile_stride=self.vae_tile_stride
|
|
156
166
|
)
|
|
157
167
|
return image
|
|
158
168
|
|
|
@@ -196,8 +206,53 @@ class BasePipeline:
|
|
|
196
206
|
model.eval()
|
|
197
207
|
return self
|
|
198
208
|
|
|
199
|
-
|
|
200
|
-
|
|
209
|
+
@staticmethod
|
|
210
|
+
def init_parallel_config(
|
|
211
|
+
parallelism: int,
|
|
212
|
+
use_cfg_parallel: bool,
|
|
213
|
+
model_config: ModelConfig,
|
|
214
|
+
):
|
|
215
|
+
assert parallelism in (2, 4, 8), "parallelism must be 2, 4 or 8"
|
|
216
|
+
cfg_degree = 2 if use_cfg_parallel else 1
|
|
217
|
+
sp_ulysses_degree = getattr(model_config, "sp_ulysses_degree", None)
|
|
218
|
+
sp_ring_degree = getattr(model_config, "sp_ring_degree", None)
|
|
219
|
+
tp_degree = getattr(model_config, "tp_degree", None)
|
|
220
|
+
use_fsdp = getattr(model_config, "use_fsdp", False)
|
|
221
|
+
|
|
222
|
+
if tp_degree is not None:
|
|
223
|
+
assert sp_ulysses_degree is None and sp_ring_degree is None, (
|
|
224
|
+
"not allowed to enable sequence parallel and tensor parallel together; "
|
|
225
|
+
"either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
|
|
226
|
+
)
|
|
227
|
+
assert use_fsdp is False, (
|
|
228
|
+
"not allowed to enable fully sharded data parallel and tensor parallel together; "
|
|
229
|
+
"either set use_fsdp=False or set tp_degree=None during pipeline initialization"
|
|
230
|
+
)
|
|
231
|
+
assert parallelism == cfg_degree * tp_degree, (
|
|
232
|
+
f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * tp_degree ({tp_degree})"
|
|
233
|
+
)
|
|
234
|
+
sp_ulysses_degree = 1
|
|
235
|
+
sp_ring_degree = 1
|
|
236
|
+
elif sp_ulysses_degree is None and sp_ring_degree is None:
|
|
237
|
+
# use ulysses if not specified
|
|
238
|
+
sp_ulysses_degree = parallelism // cfg_degree
|
|
239
|
+
sp_ring_degree = 1
|
|
240
|
+
tp_degree = 1
|
|
241
|
+
elif sp_ulysses_degree is not None and sp_ring_degree is not None:
|
|
242
|
+
assert parallelism == cfg_degree * sp_ulysses_degree * sp_ring_degree, (
|
|
243
|
+
f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * "
|
|
244
|
+
f"sp_ulysses_degree ({sp_ulysses_degree}) * sp_ring_degree ({sp_ring_degree})"
|
|
245
|
+
)
|
|
246
|
+
tp_degree = 1
|
|
247
|
+
else:
|
|
248
|
+
raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
|
|
249
|
+
return {
|
|
250
|
+
"cfg_degree": cfg_degree,
|
|
251
|
+
"sp_ulysses_degree": sp_ulysses_degree,
|
|
252
|
+
"sp_ring_degree": sp_ring_degree,
|
|
253
|
+
"tp_degree": tp_degree,
|
|
254
|
+
"use_fsdp": use_fsdp,
|
|
255
|
+
}
|
|
201
256
|
|
|
202
257
|
@staticmethod
|
|
203
258
|
def validate_offload_mode(offload_mode: str | None):
|
|
@@ -233,7 +288,7 @@ class BasePipeline:
|
|
|
233
288
|
return
|
|
234
289
|
if self.offload_mode == "sequential_cpu_offload":
|
|
235
290
|
# fresh the cuda cache
|
|
236
|
-
|
|
291
|
+
empty_cache()
|
|
237
292
|
return
|
|
238
293
|
|
|
239
294
|
# offload unnecessary models to cpu
|
|
@@ -248,4 +303,4 @@ class BasePipeline:
|
|
|
248
303
|
if model is not None and (p := next(model.parameters(), None)) is not None and p.device != self.device:
|
|
249
304
|
model.to(self.device)
|
|
250
305
|
# fresh the cuda cache
|
|
251
|
-
|
|
306
|
+
empty_cache()
|