diffsynth 2.0.10.tar.gz → 2.0.11.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffsynth-2.0.10 → diffsynth-2.0.11}/PKG-INFO +1 -1
- {diffsynth-2.0.10 → diffsynth-2.0.11}/README.md +22 -2
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/base_pipeline.py +37 -4
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/loss.py +5 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/parsers.py +6 -0
- diffsynth-2.0.11/diffsynth/diffusion/template.py +203 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/training_module.py +52 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/dinov3_image_encoder.py +8 -4
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux2_dit.py +82 -137
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/siglip2_image_encoder.py +10 -4
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/flux2_image.py +49 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/wan_video.py +1 -1
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth.egg-info/PKG-INFO +1 -1
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth.egg-info/SOURCES.txt +1 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/pyproject.toml +1 -1
- {diffsynth-2.0.10 → diffsynth-2.0.11}/LICENSE +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/configs/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/configs/model_configs.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/configs/vram_management_module_maps.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/attention/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/attention/attention.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/data/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/data/operators.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/data/unified_dataset.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/device/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/device/npu_compatible_device.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/gradient/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/gradient/gradient_checkpoint.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/loader/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/loader/config.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/loader/file.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/loader/model.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/npu_patch/npu_fused_operator.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/vram/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/vram/disk_map.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/vram/initialization.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/vram/layers.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/ddim_scheduler.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/flow_match.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/logger.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/runner.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ace_step_conditioner.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ace_step_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ace_step_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ace_step_tokenizer.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ace_step_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/anima_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ernie_image_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ernie_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux2_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux2_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_controlnet.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_infiniteyou.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_ipadapter.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_lora_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_lora_patcher.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_text_encoder_clip.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_text_encoder_t5.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_value_control.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/general_modules.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/joyai_image_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/joyai_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/longcat_video_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ltx2_audio_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ltx2_common.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ltx2_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ltx2_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ltx2_upsampler.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ltx2_video_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/model_loader.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/mova_audio_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/mova_audio_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/mova_dual_tower_bridge.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/nexus_gen.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/nexus_gen_ar_model.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/nexus_gen_projector.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/qwen_image_controlnet.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/qwen_image_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/qwen_image_image2lora.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/qwen_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/qwen_image_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/sd_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/stable_diffusion_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/stable_diffusion_unet.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/stable_diffusion_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/stable_diffusion_xl_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/stable_diffusion_xl_unet.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/step1x_connector.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/step1x_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_animate_adapter.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_camera_controller.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_dit_s2v.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_image_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_mot.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_motion_controller.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_vace.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wantodance.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wav2vec.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/z_image_controlnet.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/z_image_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/z_image_image2lora.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/z_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/ace_step.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/anima_image.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/ernie_image.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/flux_image.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/joyai_image.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/ltx2_audio_video.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/mova_audio_video.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/qwen_image.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/stable_diffusion.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/stable_diffusion_xl.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/z_image.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/controlnet/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/controlnet/annotator.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/controlnet/controlnet_input.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/data/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/data/audio.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/data/audio_video.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/data/media_io_ltx2.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/lora/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/lora/flux.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/lora/general.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/lora/merge.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/lora/reset_rank.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/ses/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/ses/ses.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ace_step_conditioner.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ace_step_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/anima_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/dino_v3.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux2_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_controlnet.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_infiniteyou.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_ipadapter.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/joyai_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ltx2_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ltx2_video_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/nexus_gen.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/nexus_gen_projector.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/stable_diffusion_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/stable_diffusion_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/step1x_connector.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wan_video_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wan_video_mot.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wan_video_vace.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wan_video_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/z_image_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/z_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/xfuser/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/xfuser/xdit_context_parallel.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/version.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth.egg-info/dependency_links.txt +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth.egg-info/requires.txt +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth.egg-info/top_level.txt +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.11}/setup.cfg +0 -0
README.md

@@ -34,6 +34,15 @@ We believe that a well-developed open-source code framework can lower the thresh
 
 > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
 
+- **April 28, 2026** 🔥 We are excited to announce the release of **Diffusion Templates**, a plugin framework designed for Diffusion models that significantly lowers the barrier to training controllable generative models. Let's explore this cutting-edge technology together!
+* Open-source code: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
+* Technical report: [arXiv](https://arxiv.org/abs/2604.24351)
+* Project homepage: [GitHub](https://modelscope.github.io/diffusion-templates-web/)
+* Documentation: [English Version](https://diffsynth-studio-doc.readthedocs.io/en/latest/Diffusion_Templates/Introducing_Diffusion_Templates.html) | [Chinese Version](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/Diffusion_Templates/Introducing_Diffusion_Templates.html)
+* Online demo: [ModelScope](https://modelscope.cn/studios/DiffSynth-Studio/Diffusion-Templates)
+* Model collections: [ModelScope](https://modelscope.cn/collections/DiffSynth-Studio/KleinBase4B-Templates) | [ModelScope International](https://modelscope.ai/collections/DiffSynth-Studio/KleinBase4B-Templates) | [HuggingFace](https://huggingface.co/collections/DiffSynth-Studio/kleinbase4b-templates)
+* Datasets: [ModelScope](https://modelscope.cn/collections/DiffSynth-Studio/ImagePulseV2) | [ModelScope International](https://modelscope.ai/collections/DiffSynth-Studio/ImagePulseV2) | [HuggingFace](https://huggingface.co/collections/DiffSynth-Studio/imagepulsev2)
+
 - **April 27, 2026** We support ACE-Step-1.5! Support includes text-to-music generation, low VRAM inference, and LoRA training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/ACE-Step.md) and [example code](/examples/ace_step/).
 
 - **April 27, 2026**: We have reinstated support for the Stable Diffusion v1.5 and SDXL models, providing academic research support exclusively for these two model types.

@@ -96,7 +105,7 @@ We believe that a well-developed open-source code framework can lower the thresh
 
 - **August 20, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) model, improving the editing effect of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
 
-- **August 19, 2025**
+- **August 19, 2025** Qwen-Image-Edit open-sourced, welcome a new member to the image editing model family!
 
 - **August 18, 2025** We trained and open-sourced the Qwen-Image inpainting ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
 

@@ -112,7 +121,7 @@ We believe that a well-developed open-source code framework can lower the thresh
 
 - **August 5, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) for Qwen-Image, achieving approximately 5x acceleration.
 
-- **August 4, 2025**
+- **August 4, 2025** Qwen-Image open-sourced, welcome a new member to the image generation model family!
 
 - **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) open-sourced, a text-to-image model focused on aesthetic photography. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, LoRA training, and full training. For more details, please refer to [./examples/flux/](./examples/flux/).
 

@@ -479,6 +488,17 @@ Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
 |[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
 |[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
 |[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
+|[DiffSynth-Studio/Template-KleinBase4B-Aesthetic](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Aesthetic)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Aesthetic.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Aesthetic.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Aesthetic.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Aesthetic.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-Brightness](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Brightness)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Brightness.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Brightness.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Brightness.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Brightness.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-Age](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Age)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Age.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Age.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Age.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Age.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-ControlNet](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-ControlNet)|[code](/examples/flux2/model_inference/Template-KleinBase4B-ControlNet.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-ControlNet.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-ControlNet.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-ControlNet.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Edit)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Edit.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Edit.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Edit.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Edit.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Inpaint)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Inpaint.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Inpaint.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Inpaint.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Inpaint.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-PandaMeme](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-PandaMeme)|[code](/examples/flux2/model_inference/Template-KleinBase4B-PandaMeme.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-PandaMeme.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-PandaMeme.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-PandaMeme.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-Sharpness](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Sharpness)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Sharpness.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Sharpness.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Sharpness.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Sharpness.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-SoftRGB](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-SoftRGB)|[code](/examples/flux2/model_inference/Template-KleinBase4B-SoftRGB.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-SoftRGB.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-SoftRGB.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-SoftRGB.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-Upscaler](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Upscaler)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Upscaler.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Upscaler.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Upscaler.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Upscaler.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-ContentRef](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-ContentRef)|[code](/examples/flux2/model_inference/Template-KleinBase4B-ContentRef.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-ContentRef.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-ContentRef.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-ContentRef.py)|-|-|
 
 </details>
 
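The Diffusion Templates announcement above is backed by the new `TemplatePipeline` in `diffsynth/diffusion/template.py` (its full source appears further below). The following is a hedged sketch of how a template model could be driven at inference time, not one of the official examples: the base pipeline object, the `"image"` key in `template_inputs`, and the `ModelConfig(model_id=...)` construction are assumptions; the released example scripts live under /examples/flux2/ and are not part of this diff.

    import torch
    from diffsynth.core import ModelConfig
    from diffsynth.diffusion.template import TemplatePipeline

    def run_template(pipe, reference_image, prompt):
        # `pipe` is a regular DiffSynth image pipeline (e.g. FLUX.2), built the usual way;
        # `reference_image` is the condition consumed by the template's `process_inputs`.
        template_pipe = TemplatePipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device="cuda",
            model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Edit")],  # assumed constructor usage
        )
        # Each entry in `template_inputs` selects a template via `model_id` (an index into
        # `model_configs`); the merged outputs (kv_cache, lora, ...) are forwarded to `pipe(...)`.
        return template_pipe(
            pipe=pipe,
            template_inputs=[{"model_id": 0, "image": reference_image}],  # input keys are assumptions
            prompt=prompt,
        )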
diffsynth/diffusion/base_pipeline.py

@@ -3,12 +3,13 @@ import torch
 import numpy as np
 from einops import repeat, reduce
 from typing import Union
-from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
+from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type, enable_vram_management
 from ..core.device.npu_compatible_device import get_device_type
 from ..utils.lora import GeneralLoRALoader
 from ..models.model_loader import ModelPool
 from ..utils.controlnet import ControlNetInput
 from ..core.device import get_device_name, IS_NPU_AVAILABLE
+from .template import load_template_model, load_template_data_processor
 
 
 class PipelineUnit:

@@ -319,14 +320,21 @@ class BasePipeline(torch.nn.Module):
 
 
     def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
+        # Positive side forward
         if inputs_shared.get("positive_only_lora", None) is not None:
-            self.clear_lora(verbose=0)
             self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
         noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
+        if inputs_shared.get("positive_only_lora", None) is not None:
+            self.clear_lora(verbose=0)
+
         if cfg_scale != 1.0:
-
-
+            # Negative side forward
+            if inputs_shared.get("negative_only_lora", None) is not None:
+                self.load_lora(self.dit, state_dict=inputs_shared["negative_only_lora"], verbose=0)
             noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
+            if inputs_shared.get("negative_only_lora", None) is not None:
+                self.clear_lora(verbose=0)
+
             if isinstance(noise_pred_posi, tuple):
                 # Separately handling different output types of latents, eg. video and audio latents.
                 noise_pred = tuple(

@@ -338,6 +346,31 @@ class BasePipeline(torch.nn.Module):
         else:
             noise_pred = noise_pred_posi
         return noise_pred
+
+
+    def load_training_template_model(self, model_config: ModelConfig = None):
+        if model_config is not None:
+            model_config.download_if_necessary()
+            self.template_model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
+            self.template_data_processor = load_template_data_processor(model_config.path)()
+
+
+    def enable_lora_hot_loading(self, model: torch.nn.Module):
+        if hasattr(model, "vram_management_enabled") and getattr(model, "vram_management_enabled"):
+            return model
+        module_map = {torch.nn.Linear: AutoWrappedLinear}
+        vram_config = {
+            "offload_dtype": self.torch_dtype,
+            "offload_device": self.device,
+            "onload_dtype": self.torch_dtype,
+            "onload_device": self.device,
+            "preparing_dtype": self.torch_dtype,
+            "preparing_device": self.device,
+            "computation_dtype": self.torch_dtype,
+            "computation_device": self.device,
+        }
+        model = enable_vram_management(model, module_map, vram_config=vram_config)
+        return model
 
     def compile_pipeline(self, mode: str = "default", dynamic: bool = True, fullgraph: bool = False, compile_models: list = None, **kwargs):
         """
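Two things change in base_pipeline.py beyond the new import: `cfg_guided_model_fn` now loads an optional `positive_only_lora` just for the positive forward pass and clears it immediately afterwards, and it can load a separate `negative_only_lora` for the negative pass, so each branch of classifier-free guidance sees only its own adapter. The new helpers (`load_training_template_model`, `enable_lora_hot_loading`) attach a template model to the pipeline and wrap every `torch.nn.Linear` with `AutoWrappedLinear` so LoRA weights can be swapped at run time. A hedged sketch of the setup wiring (the call site is an assumption; the real wiring is in `DiffusionTrainingModule`, shown later in this diff):

    # Hypothetical setup code; `pipe` is a BasePipeline subclass and `template_model_config` a ModelConfig.
    pipe.load_training_template_model(template_model_config)  # sets pipe.template_model / pipe.template_data_processor
    pipe.dit = pipe.enable_lora_hot_loading(pipe.dit)         # makes the DiT's Linear layers LoRA-hot-swappable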
diffsynth/diffusion/loss.py

@@ -3,6 +3,11 @@ import torch
 
 
 def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
+    if "lora" in inputs:
+        # Image-to-LoRA models need to load lora here.
+        pipe.clear_lora(verbose=0)
+        pipe.load_lora(pipe.dit, state_dict=inputs["lora"], hotload=True, verbose=0)
+
     max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
     min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
 
diffsynth/diffusion/parsers.py

@@ -60,6 +60,11 @@ def add_gradient_config(parser: argparse.ArgumentParser):
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
     return parser
 
+def add_template_model_config(parser: argparse.ArgumentParser):
+    parser.add_argument("--template_model_id_or_path", type=str, default=None, help="Model ID of path of template models.")
+    parser.add_argument("--enable_lora_hot_loading", default=False, action="store_true", help="Whether to enable LoRA hot-loading. Only available for image-to-lora models.")
+    return parser
+
 def add_general_config(parser: argparse.ArgumentParser):
     parser = add_dataset_base_config(parser)
     parser = add_model_config(parser)

@@ -67,4 +72,5 @@ def add_general_config(parser: argparse.ArgumentParser):
     parser = add_output_config(parser)
     parser = add_lora_config(parser)
     parser = add_gradient_config(parser)
+    parser = add_template_model_config(parser)
     return parser
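`add_template_model_config` is chained into `add_general_config`, so training scripts built on these parsers automatically gain the two new flags. A minimal sketch (the argv values are illustrative, not taken from the package):

    import argparse
    from diffsynth.diffusion.parsers import add_template_model_config

    parser = add_template_model_config(argparse.ArgumentParser())
    args = parser.parse_args([
        "--template_model_id_or_path", "DiffSynth-Studio/Template-KleinBase4B-Edit",
        "--enable_lora_hot_loading",
    ])
    print(args.template_model_id_or_path, args.enable_lora_hot_loading)  # prints the model id and True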
diffsynth/diffusion/template.py (new file)

@@ -0,0 +1,203 @@
+import torch, os, importlib, warnings, json, inspect
+from typing import Dict, List, Tuple, Union
+from ..core import ModelConfig, load_model
+from ..core.device.npu_compatible_device import get_device_type
+from ..utils.lora.merge import merge_lora
+
+
+KVCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]]
+
+
+class TemplateModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @torch.no_grad()
+    def process_inputs(self, **kwargs):
+        return {}
+
+    def forward(self, **kwargs):
+        raise NotImplementedError()
+
+
+def check_template_model_format(model):
+    if not hasattr(model, "process_inputs"):
+        raise NotImplementedError("`process_inputs` is not implemented in the Template model.")
+    if "kwargs" not in inspect.signature(model.process_inputs).parameters:
+        raise NotImplementedError("`**kwargs` is not included in `process_inputs`.")
+    if not hasattr(model, "forward"):
+        raise NotImplementedError("`forward` is not implemented in the Template model.")
+    if "kwargs" not in inspect.signature(model.forward).parameters:
+        raise NotImplementedError("`**kwargs` is not included in `forward`.")
+
+
+def load_template_model(path, torch_dtype=torch.bfloat16, device="cuda", verbose=1):
+    spec = importlib.util.spec_from_file_location("template_model", os.path.join(path, "model.py"))
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    template_model_path = getattr(module, 'TEMPLATE_MODEL_PATH') if hasattr(module, 'TEMPLATE_MODEL_PATH') else None
+    if template_model_path is not None:
+        # With `TEMPLATE_MODEL_PATH`, a pretrained model will be loaded.
+        model = load_model(
+            model_class=getattr(module, 'TEMPLATE_MODEL'),
+            config=getattr(module, 'TEMPLATE_MODEL_CONFIG') if hasattr(module, 'TEMPLATE_MODEL_CONFIG') else None,
+            path=os.path.join(path, getattr(module, 'TEMPLATE_MODEL_PATH')),
+            torch_dtype=torch_dtype,
+            device=device,
+        )
+    else:
+        # Without `TEMPLATE_MODEL_PATH`, a randomly initialized model or a non-model module will be loaded.
+        model = module.TEMPLATE_MODEL()
+        if hasattr(model, "to"):
+            model = model.to(dtype=torch_dtype, device=device)
+        if hasattr(model, "eval"):
+            model = model.eval()
+    check_template_model_format(model)
+    if verbose > 0:
+        metadata = {
+            "model_architecture": getattr(module, 'TEMPLATE_MODEL').__name__,
+            "code_path": os.path.join(path, "model.py"),
+            "weight_path": template_model_path,
+        }
+        print(f"Template model loaded: {json.dumps(metadata, indent=4)}")
+    return model
+
+
+def load_template_data_processor(path):
+    spec = importlib.util.spec_from_file_location("template_model", os.path.join(path, "model.py"))
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    if hasattr(module, 'TEMPLATE_DATA_PROCESSOR'):
+        processor = getattr(module, 'TEMPLATE_DATA_PROCESSOR')
+        return processor
+    else:
+        return None
+
+
+class TemplatePipeline(torch.nn.Module):
+    def __init__(
+        self,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: Union[str, torch.device] = get_device_type(),
+        model_configs: list[ModelConfig] = [],
+        lazy_loading: bool = False,
+    ):
+        super().__init__()
+        self.torch_dtype = torch_dtype
+        self.device = device
+        self.model_configs = model_configs
+        self.lazy_loading = lazy_loading
+        if lazy_loading:
+            for model_config in model_configs:
+                TemplatePipeline.check_vram_config(model_config)
+                model_config.download_if_necessary()
+            self.models = None
+        else:
+            models = []
+            for model_config in model_configs:
+                TemplatePipeline.check_vram_config(model_config)
+                model_config.download_if_necessary()
+                model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
+                models.append(model)
+            self.models = torch.nn.ModuleList(models)
+
+    def merge_kv_cache(self, kv_cache_list: List[KVCache]) -> KVCache:
+        names = {}
+        for kv_cache in kv_cache_list:
+            for name in kv_cache:
+                names[name] = None
+        kv_cache_merged = {}
+        for name in names:
+            kv_list = [kv_cache.get(name) for kv_cache in kv_cache_list]
+            kv_list = [kv for kv in kv_list if kv is not None]
+            if len(kv_list) > 0:
+                k = torch.concat([kv[0] for kv in kv_list], dim=1)
+                v = torch.concat([kv[1] for kv in kv_list], dim=1)
+                kv_cache_merged[name] = (k, v)
+        return kv_cache_merged
+
+    def merge_template_cache(self, template_cache_list):
+        params = sorted(list(set(sum([list(template_cache.keys()) for template_cache in template_cache_list], []))))
+        template_cache_merged = {}
+        for param in params:
+            data = [template_cache[param] for template_cache in template_cache_list if param in template_cache]
+            if param == "kv_cache":
+                data = self.merge_kv_cache(data)
+            elif param == "lora":
+                data = merge_lora(data)
+            elif len(data) == 1:
+                data = data[0]
+            else:
+                print(f"Conflict detected: `{param}` appears in the outputs of multiple Template models. Only the first one will be retained.")
+                data = data[0]
+            template_cache_merged[param] = data
+        return template_cache_merged
+
+    @staticmethod
+    def check_vram_config(model_config: ModelConfig):
+        params = [
+            model_config.offload_device, model_config.offload_dtype,
+            model_config.onload_device, model_config.onload_dtype,
+            model_config.preparing_device, model_config.preparing_dtype,
+            model_config.computation_device, model_config.computation_dtype,
+        ]
+        for param in params:
+            if param is not None:
+                warnings.warn("TemplatePipeline doesn't support VRAM management. VRAM config will be ignored.")
+
+    @staticmethod
+    def from_pretrained(
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: Union[str, torch.device] = get_device_type(),
+        model_configs: list[ModelConfig] = [],
+        lazy_loading: bool = False,
+    ):
+        pipe = TemplatePipeline(torch_dtype, device, model_configs, lazy_loading)
+        return pipe
+
+    def fetch_model(self, model_id):
+        if self.lazy_loading:
+            model_config = self.model_configs[model_id]
+            model_config.download_if_necessary()
+            model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
+        else:
+            model = self.models[model_id]
+        return model
+
+    def call_single_side(self, pipe=None, inputs: List[Dict] = None):
+        model = None
+        onload_model_id = -1
+        template_cache = []
+        for i in inputs:
+            model_id = i.get("model_id", 0)
+            if model_id != onload_model_id:
+                model = self.fetch_model(model_id)
+                onload_model_id = model_id
+            cache = model.process_inputs(pipe=pipe, **i)
+            cache = model.forward(pipe=pipe, **cache)
+            template_cache.append(cache)
+        template_cache = self.merge_template_cache(template_cache)
+        return template_cache
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        pipe=None,
+        template_inputs: List[Dict] = None,
+        negative_template_inputs: List[Dict] = None,
+        **kwargs,
+    ):
+        template_cache = self.call_single_side(pipe=pipe, inputs=template_inputs or [])
+        negative_template_cache = self.call_single_side(pipe=pipe, inputs=negative_template_inputs or [])
+        required_params = list(inspect.signature(pipe.__call__).parameters.keys())
+        for param in template_cache:
+            if param in required_params:
+                kwargs[param] = template_cache[param]
+            else:
+                print(f"`{param}` is not included in the inputs of `{pipe.__class__.__name__}`. This parameter will be ignored.")
+        for param in negative_template_cache:
+            if "negative_" + param in required_params:
+                kwargs["negative_" + param] = negative_template_cache[param]
+            else:
+                print(f"`{'negative_' + param}` is not included in the inputs of `{pipe.__class__.__name__}`. This parameter will be ignored.")
+        return pipe(**kwargs)
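`load_template_model` imports a `model.py` shipped inside each template repository and reads a few module-level attributes: `TEMPLATE_MODEL` (required), plus the optional `TEMPLATE_MODEL_PATH`, `TEMPLATE_MODEL_CONFIG`, and `TEMPLATE_DATA_PROCESSOR`. The sketch below shows what a minimal template package could look like. It is an illustration, not one of the released Template-KleinBase4B models, and the returned `cfg_scale` key is only an example of emitting a parameter the target pipeline's `__call__` accepts; the structural requirements (a `TEMPLATE_MODEL` class whose `process_inputs` and `forward` both take `**kwargs`) come from `check_template_model_format`.

    # model.py inside a hypothetical template repository
    import torch
    from diffsynth.diffusion.template import TemplateModel


    class MyTemplate(TemplateModel):
        @torch.no_grad()
        def process_inputs(self, strength: float = 1.0, pipe=None, **kwargs):
            # Turn raw user/dataset inputs into tensors; `**kwargs` is mandatory.
            return {"strength": torch.tensor(float(strength))}

        def forward(self, strength=None, pipe=None, **kwargs):
            # Keys of the returned dict are matched against the target pipeline's __call__ parameters;
            # the special keys "kv_cache" and "lora" get merged by TemplatePipeline / the training units.
            return {"cfg_scale": 1.0 + strength.item()}


    TEMPLATE_MODEL = MyTemplate
    # Optional: TEMPLATE_MODEL_PATH = "model.safetensors"    # load pretrained weights via load_model instead
    # Optional: TEMPLATE_MODEL_CONFIG = {...}                 # config forwarded to load_model
    # Optional: TEMPLATE_DATA_PROCESSOR = SomeProcessorClass  # pre-processing applied during training

During training, `DiffusionTrainingModule.load_training_template_model` (next file) appends pipeline units that run `process_inputs` and `forward` around each step; at inference the same `model.py` is driven through `TemplatePipeline.__call__`, which merges the outputs of all selected templates and forwards them to the wrapped pipeline.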
diffsynth/diffusion/training_module.py

@@ -6,6 +6,7 @@ from peft import LoraConfig, inject_adapter_in_model
 
 
 class GeneralUnit_RemoveCache(PipelineUnit):
+    # Only used for training
     def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):
         super().__init__(take_over=True)
         self.required_params = required_params

@@ -27,6 +28,47 @@ class GeneralUnit_RemoveCache(PipelineUnit):
         return inputs_shared, inputs_posi, inputs_nega
 
 
+class GeneralUnit_TemplateProcessInputs(PipelineUnit):
+    # Only used for training
+    def __init__(self, data_processor):
+        super().__init__(
+            input_params=("template_inputs",),
+            output_params=("template_inputs",),
+        )
+        self.data_processor = data_processor
+
+    def process(self, pipe, template_inputs):
+        if not hasattr(pipe, "template_model") or template_inputs is None:
+            return {}
+        if self.data_processor is not None:
+            template_inputs = self.data_processor(**template_inputs)
+        template_inputs = pipe.template_model.process_inputs(pipe=pipe, **template_inputs)
+        return {"template_inputs": template_inputs}
+
+
+class GeneralUnit_TemplateForward(PipelineUnit):
+    # Only used for training
+    def __init__(self, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
+        super().__init__(
+            input_params=("template_inputs",),
+            output_params=("kv_cache",),
+            onload_model_names=("template_model",)
+        )
+        self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
+
+    def process(self, pipe, template_inputs):
+        if not hasattr(pipe, "template_model") or template_inputs is None:
+            return {}
+        template_cache = pipe.template_model.forward(
+            **template_inputs,
+            pipe=pipe,
+            use_gradient_checkpointing=self.use_gradient_checkpointing,
+            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload,
+        )
+        return template_cache
+
+
 class DiffusionTrainingModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

@@ -209,6 +251,16 @@ class DiffusionTrainingModule(torch.nn.Module):
         else:
             lora_target_modules = lora_target_modules.split(",")
         return lora_target_modules
+
+
+    def load_training_template_model(self, pipe, path_or_model_id, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
+        if path_or_model_id is None:
+            return pipe
+        model_config = self.parse_path_or_model_id(path_or_model_id)
+        pipe.load_training_template_model(model_config)
+        pipe.units.append(GeneralUnit_TemplateProcessInputs(pipe.template_data_processor))
+        pipe.units.append(GeneralUnit_TemplateForward(use_gradient_checkpointing, use_gradient_checkpointing_offload))
+        return pipe
 
 
     def switch_pipe_to_training_mode(
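A hedged sketch of how a concrete training module could use the new hook; the subclass, the pipeline construction, and everything except `load_training_template_model`, `enable_lora_hot_loading`, and the two parser flags are assumptions:

    from diffsynth.diffusion.training_module import DiffusionTrainingModule

    class MyTrainingModule(DiffusionTrainingModule):  # hypothetical subclass
        def __init__(self, args, pipe):
            super().__init__()
            # `pipe` is a BasePipeline already switched to training mode by the existing machinery.
            pipe = self.load_training_template_model(
                pipe,
                args.template_model_id_or_path,          # from --template_model_id_or_path
                use_gradient_checkpointing=True,
            )
            if args.enable_lora_hot_loading:             # from --enable_lora_hot_loading
                pipe.dit = pipe.enable_lora_hot_loading(pipe.dit)
            self.pipe = pipe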
diffsynth/models/dinov3_image_encoder.py

@@ -1,12 +1,16 @@
-(four original import lines; their content is truncated in the diff view, only a bare `import` fragment is visible)
+import torch, warnings
+try:
+    from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTModel
+except:
+    warnings.warn(f"Cannot import `DINOv3ViTModel`. `DINOv3ImageEncoder` is not available. Please update `transformers` by `pip install -U transformers`.")
+    DINOv3ViTModel = torch.nn.Module
 from ..core.device.npu_compatible_device import get_device_type
 
 
 class DINOv3ImageEncoder(DINOv3ViTModel):
     def __init__(self):
+        from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
+        from transformers import DINOv3ViTImageProcessor
         config = DINOv3ViTConfig(
             architectures = [
                 "DINOv3ViTModel"