diffsynth-engine 0.6.1.dev27__tar.gz → 0.6.1.dev29__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/PKG-INFO +1 -1
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/configs/pipeline.py +5 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/base.py +1 -1
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/lora.py +1 -0
- diffsynth_engine-0.6.1.dev29/diffsynth_engine/models/basic/lora_nunchaku.py +221 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/video_sparse_attention.py +15 -3
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/qwen_image/__init__.py +8 -0
- diffsynth_engine-0.6.1.dev29/diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py +341 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/base.py +11 -4
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/qwen_image.py +64 -2
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/wan_video.py +25 -1
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/flag.py +24 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/parallel.py +23 -107
- diffsynth_engine-0.6.1.dev29/diffsynth_engine/utils/process_group.py +149 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine.egg-info/PKG-INFO +1 -1
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine.egg-info/SOURCES.txt +3 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/.gitattributes +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/.gitignore +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/.pre-commit-config.yaml +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/MANIFEST.in +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/README.md +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/assets/dingtalk.png +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/assets/showcase.jpeg +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/assets/tongyi.svg +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/components/vae.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_vision_config.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/qwen_image/qwen_image_vae_keymap.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.1_flf2v_14b.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.1_i2v_14b.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_1.3b.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_14b.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.2_i2v_a14b.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.2_s2v_14b.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.2_t2v_a14b.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan2.2_ti2v_5b.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/vae/wan2.1_vae.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/vae/wan2.2_vae.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/models/wan/vae/wan_vae_keymap.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/added_tokens.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/merges.txt +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/special_tokens_map.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/tokenizer_config.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/qwen_image/tokenizer/vocab.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/configs/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/configs/controlnet.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/kernels/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/attention.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/timestep.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/unet_helper.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_controlnet.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_dit.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_dit_fbcache.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_ipadapter.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_redux.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/flux/flux_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/hunyuan3d_dit.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/hunyuan3d_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/moe.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/surface_extractor.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/hunyuan3d/volume_decoder.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/qwen_image/qwen2_5_vl.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/qwen_image/qwen_image_dit.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/qwen_image/qwen_image_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd/sd_controlnet.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd/sd_unet.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd/sd_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd3/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sdxl/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/text_encoder/clip.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/text_encoder/t5.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/vae/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/vae/vae.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/wan_audio_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/wan_dit.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/wan_s2v_dit.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/wan/wan_vae.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/flux_image.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/hunyuan3d_shape.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/sd_image.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/sdxl_image.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/utils.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/wan_s2v.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/processor/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/processor/canny_processor.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/processor/depth_processor.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/base.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/clip.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/qwen2.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/qwen2_vl_image_processor.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/qwen2_vl_processor.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/t5.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tokenizers/wan.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tools/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tools/flux_inpainting_tool.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tools/flux_outpainting_tool.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tools/flux_reference_tool.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/tools/flux_replace_tool.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/cache.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/constants.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/download.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/env.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/fp8_linear.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/gguf.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/image.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/loader.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/lock.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/logging.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/memory/__init__.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/memory/linear_regression.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/memory/memory_predcit_model.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/offload.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/onnx.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/platform.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/prompt.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/utils/video.py +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine.egg-info/requires.txt +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine.egg-info/top_level.txt +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/docs/tutorial.md +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/docs/tutorial_zh.md +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/pyproject.toml +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/setup.cfg +0 -0
- {diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/setup.py +0 -0
{diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/configs/pipeline.py
RENAMED
@@ -251,6 +251,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig
     # override OptimizationConfig
     fbcache_relative_l1_threshold = 0.009
 
+    # svd
+    use_nunchaku: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_awq: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_attn: Optional[bool] = field(default=None, init=False)
+
     @classmethod
     def basic_config(
         cls,
{diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/base.py
RENAMED
@@ -40,7 +40,7 @@ class PreTrainedModel(nn.Module):
 
     def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
         for args in lora_args:
-            key = args["
+            key = args["key"]
             module = self.get_submodule(key)
             if not isinstance(module, (LoRALinear, LoRAConv2d)):
                 raise ValueError(f"Unsupported lora key: {key}")
diffsynth_engine-0.6.1.dev29/diffsynth_engine/models/basic/lora_nunchaku.py
ADDED
@@ -0,0 +1,221 @@
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+from .lora import LoRA
+from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
+from nunchaku.lora.flux.nunchaku_converter import (
+    pack_lowrank_weight,
+    unpack_lowrank_weight,
+)
+
+
+class LoRASVDQW4A4Linear(nn.Module):
+    def __init__(
+        self,
+        origin_linear: SVDQW4A4Linear,
+    ):
+        super().__init__()
+
+        self.origin_linear = origin_linear
+        self.base_rank = self.origin_linear.rank
+        self._lora_dict = OrderedDict()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.origin_linear(x)
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.origin_linear, name)
+
+    def _apply_lora_weights(self, name: str, down: torch.Tensor, up: torch.Tensor, alpha: int, scale: float, rank: int):
+        final_scale = scale * (alpha / rank)
+
+        up_scaled = (up * final_scale).to(
+            dtype=self.origin_linear.proj_up.dtype, device=self.origin_linear.proj_up.device
+        )
+        down_final = down.to(dtype=self.origin_linear.proj_down.dtype, device=self.origin_linear.proj_down.device)
+
+        with torch.no_grad():
+            pd_packed = self.origin_linear.proj_down.data
+            pu_packed = self.origin_linear.proj_up.data
+            pd = unpack_lowrank_weight(pd_packed, down=True)
+            pu = unpack_lowrank_weight(pu_packed, down=False)
+
+            new_proj_down = torch.cat([pd, down_final], dim=0)
+            new_proj_up = torch.cat([pu, up_scaled], dim=1)
+
+            self.origin_linear.proj_down.data = pack_lowrank_weight(new_proj_down, down=True)
+            self.origin_linear.proj_up.data = pack_lowrank_weight(new_proj_up, down=False)
+
+        current_total_rank = self.origin_linear.rank
+        self.origin_linear.rank += rank
+        self._lora_dict[name] = {"rank": rank, "alpha": alpha, "scale": scale, "start_idx": current_total_rank}
+
+    def add_frozen_lora(
+        self,
+        name: str,
+        scale: float,
+        rank: int,
+        alpha: int,
+        up: torch.Tensor,
+        down: torch.Tensor,
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        if name in self._lora_dict:
+            raise ValueError(f"LoRA with name '{name}' already exists.")
+
+        self._apply_lora_weights(name, down, up, alpha, scale, rank)
+
+    def add_qkv_lora(
+        self,
+        name: str,
+        scale: float,
+        rank: int,
+        alpha: int,
+        q_up: torch.Tensor,
+        q_down: torch.Tensor,
+        k_up: torch.Tensor,
+        k_down: torch.Tensor,
+        v_up: torch.Tensor,
+        v_down: torch.Tensor,
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        if name in self._lora_dict:
+            raise ValueError(f"LoRA with name '{name}' already exists.")
+
+        fused_down = torch.cat([q_down, k_down, v_down], dim=0)
+
+        fused_rank = 3 * rank
+        out_q, out_k = q_up.shape[0], k_up.shape[0]
+        fused_up = torch.zeros((self.out_features, fused_rank), device=q_up.device, dtype=q_up.dtype)
+        fused_up[:out_q, :rank] = q_up
+        fused_up[out_q : out_q + out_k, rank : 2 * rank] = k_up
+        fused_up[out_q + out_k :, 2 * rank :] = v_up
+
+        self._apply_lora_weights(name, fused_down, fused_up, alpha, scale, rank)
+
+    def modify_scale(self, name: str, scale: float):
+        if name not in self._lora_dict:
+            raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}")
+
+        info = self._lora_dict[name]
+        old_scale = info["scale"]
+
+        if old_scale == scale:
+            return
+
+        if old_scale == 0:
+            scale_factor = 0.0
+        else:
+            scale_factor = scale / old_scale
+
+        with torch.no_grad():
+            lora_rank = info["rank"]
+            start_idx = info["start_idx"]
+            end_idx = start_idx + lora_rank
+
+            pu_packed = self.origin_linear.proj_up.data
+            pu = unpack_lowrank_weight(pu_packed, down=False)
+            pu[:, start_idx:end_idx] *= scale_factor
+
+            self.origin_linear.proj_up.data = pack_lowrank_weight(pu, down=False)
+
+        self._lora_dict[name]["scale"] = scale
+
+    def clear(self, release_all_cpu_memory: bool = False):
+        if not self._lora_dict:
+            return
+
+        with torch.no_grad():
+            pd_packed = self.origin_linear.proj_down.data
+            pu_packed = self.origin_linear.proj_up.data
+
+            pd = unpack_lowrank_weight(pd_packed, down=True)
+            pu = unpack_lowrank_weight(pu_packed, down=False)
+
+            pd_reset = pd[: self.base_rank, :].clone()
+            pu_reset = pu[:, : self.base_rank].clone()
+
+            self.origin_linear.proj_down.data = pack_lowrank_weight(pd_reset, down=True)
+            self.origin_linear.proj_up.data = pack_lowrank_weight(pu_reset, down=False)
+
+            self.origin_linear.rank = self.base_rank
+
+        self._lora_dict.clear()
+
+
+class LoRAAWQW4A16Linear(nn.Module):
+    def __init__(self, origin_linear: AWQW4A16Linear):
+        super().__init__()
+        self.origin_linear = origin_linear
+        self._lora_dict = OrderedDict()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        quantized_output = self.origin_linear(x)
+
+        for name, lora in self._lora_dict.items():
+            quantized_output += lora(x.to(lora.dtype)).to(quantized_output.dtype)
+
+        return quantized_output
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.origin_linear, name)
+
+    def add_lora(
+        self,
+        name: str,
+        scale: float,
+        rank: int,
+        alpha: int,
+        up: torch.Tensor,
+        down: torch.Tensor,
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        up_linear = nn.Linear(rank, self.out_features, bias=False, device="meta", dtype=dtype).to_empty(device=device)
+        down_linear = nn.Linear(self.in_features, rank, bias=False, device="meta", dtype=dtype).to_empty(device=device)
+
+        up_linear.weight.data = up.reshape(self.out_features, rank)
+        down_linear.weight.data = down.reshape(rank, self.in_features)
+
+        lora = LoRA(scale, rank, alpha, up_linear, down_linear, device, dtype)
+        self._lora_dict[name] = lora
+
+    def modify_scale(self, name: str, scale: float):
+        if name not in self._lora_dict:
+            raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}")
+        self._lora_dict[name].scale = scale
+
+    def add_frozen_lora(self, *args, **kwargs):
+        raise NotImplementedError("Frozen LoRA (merging weights) is not supported for AWQW4A16Linear.")
+
+    def clear(self, *args, **kwargs):
+        self._lora_dict.clear()
+
+
+def patch_nunchaku_model_for_lora(model: nn.Module):
+    def _recursive_patch(module: nn.Module):
+        for name, child_module in module.named_children():
+            replacement = None
+            if isinstance(child_module, AWQW4A16Linear):
+                replacement = LoRAAWQW4A16Linear(child_module)
+            elif isinstance(child_module, SVDQW4A4Linear):
+                replacement = LoRASVDQW4A4Linear(child_module)
+
+            if replacement:
+                setattr(module, name, replacement)
+            else:
+                _recursive_patch(child_module)
+
+    _recursive_patch(model)
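The `patch_nunchaku_model_for_lora` helper added above walks a module tree and replaces every Nunchaku `AWQW4A16Linear` / `SVDQW4A4Linear` child with a LoRA-aware wrapper. A minimal usage sketch, not taken from the package itself; `quantized_dit` is a placeholder for an already constructed Nunchaku-quantized model such as the `QwenImageDiTNunchaku` introduced further down in this diff:

```python
# Hedged sketch: `quantized_dit` is a placeholder for a Nunchaku-quantized module;
# only names that appear in this diff are used.
from diffsynth_engine.models.basic.lora_nunchaku import (
    LoRAAWQW4A16Linear,
    LoRASVDQW4A4Linear,
    patch_nunchaku_model_for_lora,
)

patch_nunchaku_model_for_lora(quantized_dit)  # swaps quantized linears in-place

# After patching, each wrapped layer exposes the LoRA management API defined above:
#   LoRASVDQW4A4Linear: add_frozen_lora / add_qkv_lora / modify_scale / clear
#   LoRAAWQW4A16Linear: add_lora / modify_scale / clear
```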
{diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/models/basic/video_sparse_attention.py
RENAMED
@@ -3,10 +3,15 @@ import math
 import functools
 
 from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
-from diffsynth_engine.utils.
+from diffsynth_engine.utils.process_group import get_sp_ulysses_group, get_sp_ring_world_size
 
+
+vsa_core = None
 if VIDEO_SPARSE_ATTN_AVAILABLE:
-
+    try:
+        from vsa import video_sparse_attn as vsa_core
+    except Exception:
+        vsa_core = None
 
 VSA_TILE_SIZE = (4, 4, 4)
 
@@ -171,6 +176,12 @@ def video_sparse_attn(
     variable_block_sizes: torch.LongTensor,
     non_pad_index: torch.LongTensor,
 ):
+    if vsa_core is None:
+        raise RuntimeError(
+            "Video sparse attention (VSA) is not available. "
+            "Please install the 'vsa' package and ensure all its dependencies (including pytest) are installed."
+        )
+
     q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
     k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
     v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
@@ -212,7 +223,8 @@ def distributed_video_sparse_attn(
 ):
     from yunchang.comm.all_to_all import SeqAllToAll4D
 
-
+    ring_world_size = get_sp_ring_world_size()
+    assert ring_world_size == 1, "distributed video sparse attention requires ring degree to be 1"
     sp_ulysses_group = get_sp_ulysses_group()
 
     q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
diffsynth_engine-0.6.1.dev29/diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py
ADDED
@@ -0,0 +1,341 @@
+import torch
+import torch.nn as nn
+from typing import Any, Dict, List, Tuple, Optional
+from einops import rearrange
+
+from diffsynth_engine.models.basic import attention as attention_ops
+from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
+from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, RMSNorm
+from diffsynth_engine.models.qwen_image.qwen_image_dit import (
+    QwenFeedForward,
+    apply_rotary_emb_qwen,
+    QwenDoubleStreamAttention,
+    QwenImageTransformerBlock,
+    QwenImageDiT,
+    QwenEmbedRope,
+)
+
+from nunchaku.models.utils import fuse_linears
+from nunchaku.ops.fused import fused_gelu_mlp
+from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
+from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
+from diffsynth_engine.models.basic.lora_nunchaku import LoRASVDQW4A4Linear, LoRAAWQW4A16Linear
+
+
+class QwenDoubleStreamAttentionNunchaku(QwenDoubleStreamAttention):
+    def __init__(
+        self,
+        dim_a,
+        dim_b,
+        num_heads,
+        head_dim,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        nunchaku_rank: int = 32,
+    ):
+        super().__init__(dim_a, dim_b, num_heads, head_dim, device=device, dtype=dtype)
+
+        to_qkv = fuse_linears([self.to_q, self.to_k, self.to_v])
+        self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, rank=nunchaku_rank)
+        self.to_out = SVDQW4A4Linear.from_linear(self.to_out, rank=nunchaku_rank)
+
+        del self.to_q, self.to_k, self.to_v
+
+        add_qkv_proj = fuse_linears([self.add_q_proj, self.add_k_proj, self.add_v_proj])
+        self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, rank=nunchaku_rank)
+        self.to_add_out = SVDQW4A4Linear.from_linear(self.to_add_out, rank=nunchaku_rank)
+
+        del self.add_q_proj, self.add_k_proj, self.add_v_proj
+
+    def forward(
+        self,
+        image: torch.FloatTensor,
+        text: torch.FloatTensor,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        img_q, img_k, img_v = self.to_qkv(image).chunk(3, dim=-1)
+        txt_q, txt_k, txt_v = self.add_qkv_proj(text).chunk(3, dim=-1)
+
+        img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads)
+        img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads)
+        img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads)
+
+        txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads)
+
+        img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
+        txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
+
+        if rotary_emb is not None:
+            img_freqs, txt_freqs = rotary_emb
+            img_q = apply_rotary_emb_qwen(img_q, img_freqs)
+            img_k = apply_rotary_emb_qwen(img_k, img_freqs)
+            txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
+            txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)
+
+        joint_q = torch.cat([txt_q, img_q], dim=1)
+        joint_k = torch.cat([txt_k, img_k], dim=1)
+        joint_v = torch.cat([txt_v, img_v], dim=1)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)
+
+        joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)
+
+        txt_attn_output = joint_attn_out[:, : text.shape[1], :]
+        img_attn_output = joint_attn_out[:, text.shape[1] :, :]
+
+        img_attn_output = self.to_out(img_attn_output)
+        txt_attn_output = self.to_add_out(txt_attn_output)
+
+        return img_attn_output, txt_attn_output
+
+
+class QwenFeedForwardNunchaku(QwenFeedForward):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        dropout: float = 0.0,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        rank: int = 32,
+    ):
+        super().__init__(dim, dim_out, dropout, device=device, dtype=dtype)
+        self.net[0].proj = SVDQW4A4Linear.from_linear(self.net[0].proj, rank=rank)
+        self.net[2] = SVDQW4A4Linear.from_linear(self.net[2], rank=rank)
+        self.net[2].act_unsigned = self.net[2].precision != "nvfp4"
+
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        return fused_gelu_mlp(hidden_states, self.net[0].proj, self.net[2])
+
+
+class QwenImageTransformerBlockNunchaku(QwenImageTransformerBlock):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        eps: float = 1e-6,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        scale_shift: float = 1.0,
+        use_nunchaku_awq: bool = True,
+        use_nunchaku_attn: bool = True,
+        nunchaku_rank: int = 32,
+    ):
+        super().__init__(dim, num_attention_heads, attention_head_dim, eps, device=device, dtype=dtype)
+
+        self.use_nunchaku_awq = use_nunchaku_awq
+        if use_nunchaku_awq:
+            self.img_mod[1] = AWQW4A16Linear.from_linear(self.img_mod[1], rank=nunchaku_rank)
+
+        if use_nunchaku_attn:
+            self.attn = QwenDoubleStreamAttentionNunchaku(
+                dim_a=dim,
+                dim_b=dim,
+                num_heads=num_attention_heads,
+                head_dim=attention_head_dim,
+                device=device,
+                dtype=dtype,
+                nunchaku_rank=nunchaku_rank,
+            )
+        else:
+            self.attn = QwenDoubleStreamAttention(
+                dim_a=dim,
+                dim_b=dim,
+                num_heads=num_attention_heads,
+                head_dim=attention_head_dim,
+                device=device,
+                dtype=dtype,
+            )
+
+        self.img_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank)
+
+        if use_nunchaku_awq:
+            self.txt_mod[1] = AWQW4A16Linear.from_linear(self.txt_mod[1], rank=nunchaku_rank)
+
+        self.txt_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank)
+
+        self.scale_shift = scale_shift
+
+    def _modulate(self, x, mod_params):
+        shift, scale, gate = mod_params.chunk(3, dim=-1)
+        if self.use_nunchaku_awq:
+            if self.scale_shift != 0:
+                scale.add_(self.scale_shift)
+            return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)
+        else:
+            return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+    def forward(
+        self,
+        image: torch.Tensor,
+        text: torch.Tensor,
+        temb: torch.Tensor,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.use_nunchaku_awq:
+            img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+            txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
+
+            # nunchaku's mod_params is [B, 6*dim] instead of [B, dim*6]
+            img_mod_params = (
+                img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
+            )
+            txt_mod_params = (
+                txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
+            )
+
+            img_mod_attn, img_mod_mlp = img_mod_params.chunk(2, dim=-1)  # [B, 3*dim] each
+            txt_mod_attn, txt_mod_mlp = txt_mod_params.chunk(2, dim=-1)  # [B, 3*dim] each
+        else:
+            img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
+            txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
+
+        img_normed = self.img_norm1(image)
+        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
+
+        txt_normed = self.txt_norm1(text)
+        txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
+
+        img_attn_out, txt_attn_out = self.attn(
+            image=img_modulated,
+            text=txt_modulated,
+            rotary_emb=rotary_emb,
+            attn_mask=attn_mask,
+            attn_kwargs=attn_kwargs,
+        )
+
+        image = image + img_gate * img_attn_out
+        text = text + txt_gate * txt_attn_out
+
+        img_normed_2 = self.img_norm2(image)
+        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
+
+        txt_normed_2 = self.txt_norm2(text)
+        txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
+
+        img_mlp_out = self.img_mlp(img_modulated_2)
+        txt_mlp_out = self.txt_mlp(txt_modulated_2)
+
+        image = image + img_gate_2 * img_mlp_out
+        text = text + txt_gate_2 * txt_mlp_out
+
+        return text, image
+
+
+class QwenImageDiTNunchaku(QwenImageDiT):
+    def __init__(
+        self,
+        num_layers: int = 60,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        use_nunchaku_awq: bool = True,
+        use_nunchaku_attn: bool = True,
+        nunchaku_rank: int = 32,
+    ):
+        super().__init__()
+
+        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True, device=device)
+
+        self.time_text_embed = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
+
+        self.txt_norm = RMSNorm(3584, eps=1e-6, device=device, dtype=dtype)
+
+        self.img_in = nn.Linear(64, 3072, device=device, dtype=dtype)
+        self.txt_in = nn.Linear(3584, 3072, device=device, dtype=dtype)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                QwenImageTransformerBlockNunchaku(
+                    dim=3072,
+                    num_attention_heads=24,
+                    attention_head_dim=128,
+                    device=device,
+                    dtype=dtype,
+                    scale_shift=0,
+                    use_nunchaku_awq=use_nunchaku_awq,
+                    use_nunchaku_attn=use_nunchaku_attn,
+                    nunchaku_rank=nunchaku_rank,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
+        self.proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+        num_layers: int = 60,
+        use_nunchaku_awq: bool = True,
+        use_nunchaku_attn: bool = True,
+        nunchaku_rank: int = 32,
+    ):
+        model = cls(
+            device="meta",
+            dtype=dtype,
+            num_layers=num_layers,
+            use_nunchaku_awq=use_nunchaku_awq,
+            use_nunchaku_attn=use_nunchaku_attn,
+            nunchaku_rank=nunchaku_rank,
+        )
+        model = model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, non_blocking=True)
+        return model
+
+    def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = False):
+        fuse_dict = {}
+        for args in lora_args:
+            key = args["key"]
+            if any(suffix in key for suffix in {"add_q_proj", "add_k_proj", "add_v_proj"}):
+                fuse_key = f"{key.rsplit('.', 1)[0]}.add_qkv_proj"
+                type = key.rsplit(".", 1)[-1].split("_")[1]
+                fuse_dict[fuse_key] = fuse_dict.get(fuse_key, {})
+                fuse_dict[fuse_key][type] = args
+                continue
+
+            if any(suffix in key for suffix in {"to_q", "to_k", "to_v"}):
+                fuse_key = f"{key.rsplit('.', 1)[0]}.to_qkv"
+                type = key.rsplit(".", 1)[-1].split("_")[1]
+                fuse_dict[fuse_key] = fuse_dict.get(fuse_key, {})
+                fuse_dict[fuse_key][type] = args
+                continue
+
+            module = self.get_submodule(key)
+            if not isinstance(module, (LoRALinear, LoRAConv2d, LoRASVDQW4A4Linear, LoRAAWQW4A16Linear)):
+                raise ValueError(f"Unsupported lora key: {key}")
+
+            if fused and not isinstance(module, LoRAAWQW4A16Linear):
+                module.add_frozen_lora(**args)
+            else:
+                module.add_lora(**args)
+
+        for key in fuse_dict.keys():
+            module = self.get_submodule(key)
+            if not isinstance(module, LoRASVDQW4A4Linear):
+                raise ValueError(f"Unsupported lora key: {key}")
+            module.add_qkv_lora(
+                name=args["name"],
+                scale=fuse_dict[key]["q"]["scale"],
+                rank=fuse_dict[key]["q"]["rank"],
+                alpha=fuse_dict[key]["q"]["alpha"],
+                q_up=fuse_dict[key]["q"]["up"],
+                q_down=fuse_dict[key]["q"]["down"],
+                k_up=fuse_dict[key]["k"]["up"],
+                k_down=fuse_dict[key]["k"]["down"],
+                v_up=fuse_dict[key]["v"]["up"],
+                v_down=fuse_dict[key]["v"]["down"],
+                device=fuse_dict[key]["q"]["device"],
+                dtype=fuse_dict[key]["q"]["dtype"],
+            )
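For orientation, a hedged sketch of how the new DiT might be instantiated from a pre-quantized checkpoint via the `from_state_dict` classmethod defined above; the checkpoint filename is a placeholder, and the actual wiring lives in the `qwen_image` pipeline changes summarized at the top of this diff:

```python
import torch
from safetensors.torch import load_file

from diffsynth_engine.models.qwen_image.qwen_image_dit_nunchaku import QwenImageDiTNunchaku

# "qwen_image_nunchaku.safetensors" is a hypothetical filename used for illustration
state_dict = load_file("qwen_image_nunchaku.safetensors")
dit = QwenImageDiTNunchaku.from_state_dict(
    state_dict,
    device="cuda:0",
    dtype=torch.bfloat16,
    use_nunchaku_awq=True,   # AWQ W4A16 for the img_mod/txt_mod projections
    use_nunchaku_attn=True,  # SVDQ W4A4 for the attention projections
    nunchaku_rank=32,
)
```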
{diffsynth_engine-0.6.1.dev27 → diffsynth_engine-0.6.1.dev29}/diffsynth_engine/pipelines/base.py
RENAMED
@@ -106,7 +106,8 @@ class BasePipeline:
         for key, param in state_dict.items():
             lora_args.append(
                 {
-                    "name":
+                    "name": lora_path,
+                    "key": key,
                     "scale": lora_scale,
                     "rank": param["rank"],
                     "alpha": param["alpha"],
@@ -130,7 +131,10 @@ class BasePipeline:
 
     @staticmethod
     def load_model_checkpoint(
-        checkpoint_path: str | List[str],
+        checkpoint_path: str | List[str],
+        device: str = "cpu",
+        dtype: torch.dtype = torch.float16,
+        convert_dtype: bool = True,
     ) -> Dict[str, torch.Tensor]:
         if isinstance(checkpoint_path, str):
             checkpoint_path = [checkpoint_path]
@@ -140,8 +144,11 @@ class BasePipeline:
             raise FileNotFoundError(f"{path} is not a file")
         elif path.endswith(".safetensors"):
             state_dict_ = load_file(path, device=device)
-
-
+            if convert_dtype:
+                for key, value in state_dict_.items():
+                    state_dict[key] = value.to(dtype)
+            else:
+                state_dict.update(state_dict_)
 
         elif path.endswith(".gguf"):
             state_dict.update(**load_gguf_checkpoint(path, device=device, dtype=dtype))