diffsynth 2.0.10__tar.gz → 2.0.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffsynth-2.0.10 → diffsynth-2.0.12}/PKG-INFO +9 -1
- {diffsynth-2.0.10 → diffsynth-2.0.12}/README.md +90 -4
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/configs/model_configs.py +11 -16
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/configs/vram_management_module_maps.py +9 -1
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/__init__.py +1 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/attention/attention.py +8 -8
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/data/operators.py +10 -6
- diffsynth-2.0.12/diffsynth/core/offload_training/__init__.py +1 -0
- diffsynth-2.0.12/diffsynth/core/offload_training/manager.py +177 -0
- diffsynth-2.0.12/diffsynth/core/offload_training/memory_buffer.py +136 -0
- diffsynth-2.0.12/diffsynth/core/offload_training/offloader.py +71 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/__init__.py +1 -1
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/base_pipeline.py +37 -4
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/flow_match.py +62 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/loss.py +6 -1
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/parsers.py +13 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/runner.py +32 -6
- diffsynth-2.0.12/diffsynth/diffusion/template.py +203 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/training_module.py +52 -0
- diffsynth-2.0.12/diffsynth/models/ace_step_residual_fsq.py +569 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ace_step_tokenizer.py +2 -4
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/dinov3_image_encoder.py +8 -4
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux2_dit.py +82 -137
- diffsynth-2.0.12/diffsynth/models/hidream_common.py +373 -0
- diffsynth-2.0.12/diffsynth/models/hidream_o1_image_dit.py +1910 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ltx2_audio_vae.py +4 -2
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/siglip2_image_encoder.py +10 -4
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/ace_step.py +15 -14
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/flux2_image.py +49 -0
- diffsynth-2.0.12/diffsynth/pipelines/hidream_o1_image.py +420 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/qwen_image.py +1 -1
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/wan_video.py +1 -1
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py +5 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth.egg-info/PKG-INFO +9 -1
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth.egg-info/SOURCES.txt +9 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth.egg-info/requires.txt +9 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/pyproject.toml +11 -2
- {diffsynth-2.0.10 → diffsynth-2.0.12}/LICENSE +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/configs/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/attention/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/data/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/data/unified_dataset.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/device/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/device/npu_compatible_device.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/gradient/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/gradient/gradient_checkpoint.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/loader/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/loader/config.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/loader/file.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/loader/model.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/npu_patch/npu_fused_operator.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/vram/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/vram/disk_map.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/vram/initialization.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/vram/layers.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/ddim_scheduler.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/logger.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ace_step_conditioner.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ace_step_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ace_step_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ace_step_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/anima_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ernie_image_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ernie_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux2_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux2_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_controlnet.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_infiniteyou.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_ipadapter.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_lora_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_lora_patcher.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_text_encoder_clip.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_text_encoder_t5.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_value_control.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/general_modules.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/joyai_image_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/joyai_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/longcat_video_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ltx2_common.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ltx2_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ltx2_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ltx2_upsampler.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ltx2_video_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/model_loader.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/mova_audio_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/mova_audio_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/mova_dual_tower_bridge.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/nexus_gen.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/nexus_gen_ar_model.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/nexus_gen_projector.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/qwen_image_controlnet.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/qwen_image_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/qwen_image_image2lora.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/qwen_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/qwen_image_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/sd_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_unet.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_xl_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_xl_unet.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/step1x_connector.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/step1x_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_animate_adapter.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_camera_controller.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_dit_s2v.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_image_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_mot.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_motion_controller.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_vace.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wantodance.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wav2vec.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/z_image_controlnet.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/z_image_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/z_image_image2lora.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/z_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/anima_image.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/ernie_image.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/flux_image.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/joyai_image.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/ltx2_audio_video.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/mova_audio_video.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/stable_diffusion.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/stable_diffusion_xl.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/z_image.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/controlnet/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/controlnet/annotator.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/controlnet/controlnet_input.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/data/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/data/audio.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/data/audio_video.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/data/media_io_ltx2.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/lora/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/lora/flux.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/lora/general.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/lora/merge.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/lora/reset_rank.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/ses/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/ses/ses.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ace_step_conditioner.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ace_step_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/anima_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/dino_v3.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux2_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_controlnet.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_infiniteyou.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_ipadapter.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/joyai_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ltx2_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ltx2_video_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/nexus_gen.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/nexus_gen_projector.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/stable_diffusion_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/stable_diffusion_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/step1x_connector.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_mot.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_vace.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_vae.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/z_image_dit.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/z_image_text_encoder.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/xfuser/__init__.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/xfuser/xdit_context_parallel.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/version.py +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth.egg-info/dependency_links.txt +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth.egg-info/top_level.txt +0 -0
- {diffsynth-2.0.10 → diffsynth-2.0.12}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: diffsynth
|
|
3
|
-
Version: 2.0.10
|
|
3
|
+
Version: 2.0.12
|
|
4
4
|
Summary: Enjoy the magic of Diffusion models!
|
|
5
5
|
Author: ModelScope Team
|
|
6
6
|
License: Apache-2.0
|
|
@@ -33,6 +33,14 @@ Requires-Dist: torch==2.7.1+cpu; extra == "npu"
|
|
|
33
33
|
Requires-Dist: torch-npu==2.7.1; extra == "npu"
|
|
34
34
|
Requires-Dist: torchvision==0.22.1+cpu; extra == "npu"
|
|
35
35
|
Provides-Extra: audio
|
|
36
|
+
Requires-Dist: av; extra == "audio"
|
|
36
37
|
Requires-Dist: torchaudio; extra == "audio"
|
|
37
38
|
Requires-Dist: torchcodec; extra == "audio"
|
|
39
|
+
Requires-Dist: librosa; extra == "audio"
|
|
40
|
+
Provides-Extra: all
|
|
41
|
+
Requires-Dist: av; extra == "all"
|
|
42
|
+
Requires-Dist: torchaudio; extra == "all"
|
|
43
|
+
Requires-Dist: torchcodec; extra == "all"
|
|
44
|
+
Requires-Dist: librosa; extra == "all"
|
|
45
|
+
Requires-Dist: streamlit; extra == "all"
|
|
38
46
|
Dynamic: license-file
|
|
@@ -34,6 +34,19 @@ We believe that a well-developed open-source code framework can lower the thresh
|
|
|
34
34
|
|
|
35
35
|
> Currently, this project has a small development team, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). As a result, new features will be developed relatively slowly, and our capacity to respond to and resolve issues is limited. We apologize for this and ask for developers' understanding.
|
|
36
36
|
|
|
37
|
+
- **May 18, 2026** Added **CPU Offload Training** support. This feature moves model weights layer by layer between CPU and GPU, significantly reducing GPU VRAM usage during training and enabling LoRA training of large models even on consumer-grade GPUs. It is compatible with all models. To enable it, add `--enable_model_cpu_offload` to your training command (currently only single-GPU training is supported). For details, see the [documentation](/docs/en/Training/Offload_Training.md).
|
|
38
|
+
|
|
39
|
+
- **May 14, 2026** HiDream-O1-Image has been open-sourced. Welcome a new member to the image model family! Support includes text-to-image generation, image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/HiDream-O1-Image.md) and [example code](/examples/hidream_o1_image/).
|
|
40
|
+
|
|
41
|
+
- **April 28, 2026** 🔥 We are excited to announce the release of **Diffusion Templates**, a plugin framework designed for Diffusion models that significantly lowers the barrier to training controllable generative models. Let's explore this cutting-edge technology together!
|
|
42
|
+
* Open-source code: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
|
|
43
|
+
* Technical report: [arXiv](https://arxiv.org/abs/2604.24351)
|
|
44
|
+
* Project homepage: [GitHub](https://modelscope.github.io/diffusion-templates-web/)
|
|
45
|
+
* Documentation: [English Version](https://diffsynth-studio-doc.readthedocs.io/en/latest/Diffusion_Templates/Introducing_Diffusion_Templates.html) | [Chinese Version](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/Diffusion_Templates/Introducing_Diffusion_Templates.html)
|
|
46
|
+
* Online demo: [ModelScope](https://modelscope.cn/studios/DiffSynth-Studio/Diffusion-Templates)
|
|
47
|
+
* Model collections: [ModelScope](https://modelscope.cn/collections/DiffSynth-Studio/KleinBase4B-Templates) | [ModelScope International](https://modelscope.ai/collections/DiffSynth-Studio/KleinBase4B-Templates) | [HuggingFace](https://huggingface.co/collections/DiffSynth-Studio/kleinbase4b-templates)
|
|
48
|
+
* Datasets: [ModelScope](https://modelscope.cn/collections/DiffSynth-Studio/ImagePulseV2) | [ModelScope International](https://modelscope.ai/collections/DiffSynth-Studio/ImagePulseV2) | [HuggingFace](https://huggingface.co/collections/DiffSynth-Studio/imagepulsev2)
|
|
49
|
+
|
|
37
50
|
- **April 27, 2026** We support ACE-Step-1.5! Support includes text-to-music generation, low VRAM inference, and LoRA training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/ACE-Step.md) and [example code](/examples/ace_step/).
|
|
38
51
|
|
|
39
52
|
- **April 27, 2026**: We have reinstated support for the Stable Diffusion v1.5 and SDXL models, providing academic research support exclusively for these two model types.
|
|
@@ -96,7 +109,7 @@ We believe that a well-developed open-source code framework can lower the thresh
|
|
|
96
109
|
|
|
97
110
|
- **August 20, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) model, improving the editing effect of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py).
|
|
98
111
|
|
|
99
|
-
- **August 19, 2025**
|
|
112
|
+
- **August 19, 2025** Qwen-Image-Edit has been open-sourced. Welcome a new member to the image editing model family!
|
|
100
113
|
|
|
101
114
|
- **August 18, 2025** We trained and open-sourced the Qwen-Image inpainting ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
|
|
102
115
|
|
|
@@ -112,7 +125,7 @@ We believe that a well-developed open-source code framework can lower the thresh
|
|
|
112
125
|
|
|
113
126
|
- **August 5, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) for Qwen-Image, achieving approximately 5x acceleration.
|
|
114
127
|
|
|
115
|
-
- **August 4, 2025**
|
|
128
|
+
- **August 4, 2025** Qwen-Image has been open-sourced. Welcome a new member to the image generation model family!
|
|
116
129
|
|
|
117
130
|
- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) open-sourced, a text-to-image model focused on aesthetic photography. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, LoRA training, and full training. For more details, please refer to [./examples/flux/](./examples/flux/).
|
|
118
131
|
|
|
@@ -479,6 +492,17 @@ Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
|
|
|
479
492
|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|
|
480
493
|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|
|
481
494
|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|
|
495
|
+
|[DiffSynth-Studio/Template-KleinBase4B-Aesthetic](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Aesthetic)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Aesthetic.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Aesthetic.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Aesthetic.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Aesthetic.py)|-|-|
|
|
496
|
+
|[DiffSynth-Studio/Template-KleinBase4B-Brightness](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Brightness)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Brightness.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Brightness.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Brightness.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Brightness.py)|-|-|
|
|
497
|
+
|[DiffSynth-Studio/Template-KleinBase4B-Age](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Age)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Age.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Age.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Age.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Age.py)|-|-|
|
|
498
|
+
|[DiffSynth-Studio/Template-KleinBase4B-ControlNet](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-ControlNet)|[code](/examples/flux2/model_inference/Template-KleinBase4B-ControlNet.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-ControlNet.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-ControlNet.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-ControlNet.py)|-|-|
|
|
499
|
+
|[DiffSynth-Studio/Template-KleinBase4B-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Edit)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Edit.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Edit.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Edit.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Edit.py)|-|-|
|
|
500
|
+
|[DiffSynth-Studio/Template-KleinBase4B-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Inpaint)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Inpaint.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Inpaint.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Inpaint.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Inpaint.py)|-|-|
|
|
501
|
+
|[DiffSynth-Studio/Template-KleinBase4B-PandaMeme](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-PandaMeme)|[code](/examples/flux2/model_inference/Template-KleinBase4B-PandaMeme.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-PandaMeme.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-PandaMeme.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-PandaMeme.py)|-|-|
|
|
502
|
+
|[DiffSynth-Studio/Template-KleinBase4B-Sharpness](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Sharpness)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Sharpness.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Sharpness.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Sharpness.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Sharpness.py)|-|-|
|
|
503
|
+
|[DiffSynth-Studio/Template-KleinBase4B-SoftRGB](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-SoftRGB)|[code](/examples/flux2/model_inference/Template-KleinBase4B-SoftRGB.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-SoftRGB.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-SoftRGB.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-SoftRGB.py)|-|-|
|
|
504
|
+
|[DiffSynth-Studio/Template-KleinBase4B-Upscaler](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Upscaler)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Upscaler.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Upscaler.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Upscaler.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Upscaler.py)|-|-|
|
|
505
|
+
|[DiffSynth-Studio/Template-KleinBase4B-ContentRef](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-ContentRef)|[code](/examples/flux2/model_inference/Template-KleinBase4B-ContentRef.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-ContentRef.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-ContentRef.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-ContentRef.py)|-|-|
|
|
482
506
|
|
|
483
507
|
</details>
|
|
484
508
|
|
|
@@ -864,6 +888,68 @@ Example code for JoyAI-Image is available at: [/examples/joyai_image/](/examples
|
|
|
864
888
|
|
|
865
889
|
</details>
|
|
866
890
|
|
|
891
|
+
#### HiDream-O1-Image: [/docs/en/Model_Details/HiDream-O1-Image.md](/docs/en/Model_Details/HiDream-O1-Image.md)
|
|
892
|
+
|
|
893
|
+
<details>
|
|
894
|
+
|
|
895
|
+
<summary>Quick Start</summary>
|
|
896
|
+
|
|
897
|
+
Running the following code will quickly load the [HiDream-ai/HiDream-O1-Image](https://modelscope.cn/models/HiDream-ai/HiDream-O1-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with as little as 3 GB of VRAM.
|
|
898
|
+
|
|
899
|
+
```python
|
|
900
|
+
from diffsynth.pipelines.hidream_o1_image import HiDreamO1ImagePipeline
|
|
901
|
+
from diffsynth.core.loader.config import ModelConfig
|
|
902
|
+
import torch
|
|
903
|
+
|
|
904
|
+
|
|
905
|
+
vram_config = {
|
|
906
|
+
"offload_dtype": torch.bfloat16,
|
|
907
|
+
"offload_device": "cpu",
|
|
908
|
+
"onload_dtype": torch.bfloat16,
|
|
909
|
+
"onload_device": "cpu",
|
|
910
|
+
"preparing_dtype": torch.bfloat16,
|
|
911
|
+
"preparing_device": "cuda",
|
|
912
|
+
"computation_dtype": torch.bfloat16,
|
|
913
|
+
"computation_device": "cuda",
|
|
914
|
+
}
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
pipe = HiDreamO1ImagePipeline.from_pretrained(
|
|
918
|
+
torch_dtype=torch.bfloat16,
|
|
919
|
+
device="cuda",
|
|
920
|
+
model_configs=[
|
|
921
|
+
ModelConfig(model_id="HiDream-ai/HiDream-O1-Image", origin_file_pattern="model-*.safetensors", **vram_config),
|
|
922
|
+
],
|
|
923
|
+
processor_config=ModelConfig(model_id="HiDream-ai/HiDream-O1-Image", origin_file_pattern="./"),
|
|
924
|
+
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
925
|
+
)
|
|
926
|
+
image = pipe(
|
|
927
|
+
prompt="medium shot, eye-level, front view. A woman is seated in an ornate bedroom, illuminated by candlelight, with a calm and composed expression. The subject is a young woman with fair skin, light brown hair styled in an updo with loose tendrils framing her face, and blue eyes. She wears a cream-colored satin robe with delicate floral embroidery and lace trim along the neckline. Her ears are adorned with pearl drop earrings. She is seated on a bed with a dark, intricately carved wooden headboard. To her left, a wooden nightstand holds three lit white candles and a candelabra with multiple lit candles in the background. The bed is covered with patterned pillows and a dark, textured blanket. The walls are paneled with dark wood and feature a large, ornate tapestry with muted earth tones. The lighting creates soft highlights on her face and robe, with warm shadows cast across the room.",
|
|
928
|
+
negative_prompt=" ",
|
|
929
|
+
cfg_scale=4.0,
|
|
930
|
+
height=2048,
|
|
931
|
+
width=2048,
|
|
932
|
+
seed=42,
|
|
933
|
+
num_inference_steps=50,
|
|
934
|
+
)
|
|
935
|
+
image.save("image.jpg")
|
|
936
|
+
```
|
|
937
|
+
|
|
938
|
+
</details>
|
|
939
|
+
|
|
940
|
+
<details>
|
|
941
|
+
|
|
942
|
+
<summary>Examples</summary>
|
|
943
|
+
|
|
944
|
+
Example code for HiDream-O1-Image is available at: [/examples/hidream_o1_image/](/examples/hidream_o1_image/)
|
|
945
|
+
|
|
946
|
+
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
|
947
|
+
|-|-|-|-|-|-|-|
|
|
948
|
+
|[HiDream-ai/HiDream-O1-Image](https://modelscope.cn/models/HiDream-ai/HiDream-O1-Image)|[code](/examples/hidream_o1_image/model_inference/HiDream-O1-Image.py)|[code](/examples/hidream_o1_image/model_inference_low_vram/HiDream-O1-Image.py)|[code](/examples/hidream_o1_image/model_training/full/HiDream-O1-Image.sh)|[code](/examples/hidream_o1_image/model_training/validate_full/HiDream-O1-Image.py)|[code](/examples/hidream_o1_image/model_training/lora/HiDream-O1-Image.sh)|[code](/examples/hidream_o1_image/model_training/validate_lora/HiDream-O1-Image.py)|
|
|
949
|
+
|[HiDream-ai/HiDream-O1-Image-Dev](https://modelscope.cn/models/HiDream-ai/HiDream-O1-Image-Dev)|[code](/examples/hidream_o1_image/model_inference/HiDream-O1-Image-Dev.py)|[code](/examples/hidream_o1_image/model_inference_low_vram/HiDream-O1-Image-Dev.py)|[code](/examples/hidream_o1_image/model_training/full/HiDream-O1-Image-Dev.sh)|[code](/examples/hidream_o1_image/model_training/validate_full/HiDream-O1-Image-Dev.py)|[code](/examples/hidream_o1_image/model_training/lora/HiDream-O1-Image-Dev.sh)|[code](/examples/hidream_o1_image/model_training/validate_lora/HiDream-O1-Image-Dev.py)|
|
|
950
|
+
|
|
951
|
+
</details>
|
|
952
|
+
|
|
867
953
|
### Video Synthesis
|
|
868
954
|
|
|
869
955
|
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
|
@@ -1138,8 +1224,8 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
|
|
|
1138
1224
|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|
|
1139
1225
|
|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|
|
|
1140
1226
|
|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|
|
|
1141
|
-
|[Wan-AI/
|
|
1142
|
-
|[Wan-AI/
|
|
1227
|
+
|[Wan-AI/Wan2.2-Dancer-14B (global model)](https://modelscope.cn/models/Wan-AI/Wan2.2-Dancer-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Dancer-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Dancer-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Dancer-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Dancer-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Dancer-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Dancer-14B-global.py)|
|
|
1228
|
+
|[Wan-AI/Wan2.2-Dancer-14B (local model)](https://modelscope.cn/models/Wan-AI/Wan2.2-Dancer-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Dancer-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Dancer-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Dancer-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Dancer-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Dancer-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Dancer-14B-local.py)|
|
|
1143
1229
|
|
|
1144
1230
|
</details>
|
|
1145
1231
|
|
|
@@ -309,7 +309,7 @@ wan_series = [
|
|
|
309
309
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
|
|
310
310
|
},
|
|
311
311
|
{
|
|
312
|
-
# Example: ModelConfig(model_id="Wan-AI/
|
|
312
|
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Dancer-14B", origin_file_pattern="global_model.safetensors")
|
|
313
313
|
"model_hash": "eb18873fc0ba77b541eb7b62dbcd2059",
|
|
314
314
|
"model_name": "wan_video_dit",
|
|
315
315
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
|
@@ -833,20 +833,6 @@ ltx2_series = [
|
|
|
833
833
|
"extra_kwargs": {"decoder_version": "ltx-2.3"},
|
|
834
834
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
|
835
835
|
},
|
|
836
|
-
{
|
|
837
|
-
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors")
|
|
838
|
-
"model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
|
|
839
|
-
"model_name": "ltx2_audio_vae_decoder",
|
|
840
|
-
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
|
841
|
-
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
|
842
|
-
},
|
|
843
|
-
{
|
|
844
|
-
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
|
|
845
|
-
"model_hash": "29338f3b95e7e312a3460a482e4f4554",
|
|
846
|
-
"model_name": "ltx2_audio_vae_encoder",
|
|
847
|
-
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
|
848
|
-
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
|
849
|
-
},
|
|
850
836
|
{
|
|
851
837
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors")
|
|
852
838
|
"model_hash": "cd436c99e69ec5c80f050f0944f02a15",
|
|
@@ -1040,7 +1026,16 @@ ace_step_series = [
|
|
|
1040
1026
|
},
|
|
1041
1027
|
]
|
|
1042
1028
|
|
|
1029
|
+
# Model configs for the HiDream-O1-Image family. The example pattern is a
# glob ("model-*.safetensors"), so every matching weight file is loaded.
hidream_o1_image_series = [
    {
        # Example: ModelConfig(model_id="HiDream-ai/HiDream-O1-Image", origin_file_pattern="model-*.safetensors")
        "model_hash": "58a7c1073d79556bfc61e05e6061b771",
        "model_name": "hidream_o1_image_dit",
        "model_class": "diffsynth.models.hidream_o1_image_dit.HiDreamO1ImageModel",
    },
]
|
|
1037
|
+
|
|
1043
1038
|
# Flat registry of every model config; the series are concatenated in this order.
MODEL_CONFIGS = (
    stable_diffusion_xl_series
    + stable_diffusion_series
    + qwen_image_series
    + wan_series
    + flux_series
    + flux2_series
    + ernie_image_series
    + z_image_series
    + ltx2_series
    + anima_series
    + mova_series
    + joyai_image_series
    + ace_step_series
    + hidream_o1_image_series
)
|
|
@@ -327,7 +327,7 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
|
|
327
327
|
"diffsynth.models.ace_step_tokenizer.AceStepTokenizer": {
|
|
328
328
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
|
329
329
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
330
|
-
"
|
|
330
|
+
"diffsynth.models.ace_step_residual_fsq.ResidualFSQ": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
331
331
|
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
332
332
|
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
333
333
|
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
@@ -372,6 +372,14 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
|
|
372
372
|
"diffsynth.models.stable_diffusion_text_encoder.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
373
373
|
"diffsynth.models.stable_diffusion_xl_text_encoder.CLIPTextModelWithProjection": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
374
374
|
},
|
|
375
|
+
"diffsynth.models.hidream_o1_image_dit.HiDreamO1ImageModel": {
|
|
376
|
+
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
|
377
|
+
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
378
|
+
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
379
|
+
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
380
|
+
"diffsynth.models.hidream_o1_image_dit.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
381
|
+
"diffsynth.models.hidream_o1_image_dit.Qwen3VLVisionModel": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
382
|
+
},
|
|
375
383
|
}
|
|
376
384
|
|
|
377
385
|
def QwenImageTextEncoder_Module_Map_Updater():
|
|
@@ -63,10 +63,10 @@ def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern
|
|
|
63
63
|
return out
|
|
64
64
|
|
|
65
65
|
|
|
66
|
-
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, is_causal=False):
    """Compute attention with ``torch.nn.functional.scaled_dot_product_attention``.

    q/k/v are rearranged from their ``*_pattern`` layouts into the
    "b n s d" layout used for the SDPA call. The result is rearranged back
    into ``out_pattern``. ``attn_mask``, ``scale`` and ``is_causal`` are
    forwarded to SDPA unchanged.
    """
    sdpa_layout = "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, sdpa_layout, dims)
    attended = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask, scale=scale, is_causal=is_causal
    )
    return rearrange_out(attended, out_pattern, sdpa_layout, dims)
|
|
72
72
|
|
|
@@ -81,10 +81,10 @@ def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_patte
|
|
|
81
81
|
return out
|
|
82
82
|
|
|
83
83
|
|
|
84
|
-
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None, is_causal=False):
    """Compute attention with ``flash_attn.flash_attn_func`` (FlashAttention-2).

    Inputs are rearranged into the "b s n d" layout used for the call. The
    output is rearranged back into ``out_pattern``. ``scale`` is passed as
    ``softmax_scale`` and ``is_causal`` as ``causal``.
    """
    fa2_layout = "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, fa2_layout, dims)
    attended = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale, causal=is_causal)
    return rearrange_out(attended, out_pattern, fa2_layout, dims)
|
|
90
90
|
|
|
@@ -105,17 +105,17 @@ def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_patt
|
|
|
105
105
|
return out
|
|
106
106
|
|
|
107
107
|
|
|
108
|
-
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, is_causal=False, compatibility_mode=False):
    """Run attention with the backend selected by ``ATTENTION_IMPLEMENTATION``.

    ``torch_sdpa`` is used instead of the configured backend in three cases:
    - ``compatibility_mode`` is set;
    - an explicit ``attn_mask`` is given;
    - ``is_causal`` is requested for a backend that this dispatcher calls
      without a causal flag (flash_attention_3, sage_attention, xformers).
    Without the last fallback, those backends would ignore ``is_causal`` and
    silently compute full (non-causal) attention.

    Args:
        q, k, v: query/key/value tensors laid out as ``q_pattern``/``k_pattern``/``v_pattern``.
        out_pattern: layout of the returned tensor.
        dims: forwarded to the backend wrappers' rearrange helpers.
        attn_mask: optional attention mask (only the SDPA path accepts it).
        scale: forwarded to the backend as the softmax scale.
        is_causal: request causal masking.
        compatibility_mode: force the ``torch_sdpa`` path.
    """
    # Backends whose wrappers are invoked below without any causal argument.
    backends_without_causal = ("flash_attention_3", "sage_attention", "xformers")
    use_sdpa_fallback = (
        compatibility_mode
        or (attn_mask is not None)
        or (is_causal and ATTENTION_IMPLEMENTATION in backends_without_causal)
    )
    if use_sdpa_fallback:
        return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale, is_causal=is_causal)
    if ATTENTION_IMPLEMENTATION == "flash_attention_3":
        return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
    if ATTENTION_IMPLEMENTATION == "flash_attention_2":
        return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale, is_causal=is_causal)
    if ATTENTION_IMPLEMENTATION == "sage_attention":
        return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
    if ATTENTION_IMPLEMENTATION == "xformers":
        return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
    return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale, is_causal=is_causal)
|
|
@@ -2,8 +2,6 @@ import math, warnings
|
|
|
2
2
|
import torch, torchvision, imageio, os
|
|
3
3
|
import imageio.v3 as iio
|
|
4
4
|
from PIL import Image
|
|
5
|
-
import torchaudio
|
|
6
|
-
from diffsynth.utils.data.audio import read_audio
|
|
7
5
|
|
|
8
6
|
|
|
9
7
|
class DataProcessingPipeline:
|
|
@@ -249,9 +247,11 @@ class ToAbsolutePath(DataProcessingOperator):
|
|
|
249
247
|
class LoadAudio(DataProcessingOperator):
    """Load an audio file with librosa, resampled to ``sr`` Hz.

    librosa is imported when the operator is constructed rather than at
    module import time. Calling the operator returns only the waveform; the
    sample rate reported by the loader is discarded.
    """

    def __init__(self, sr=16000):
        self.sr = sr
        from librosa import load
        self.audio_loader = load

    def __call__(self, data: str):
        waveform, _ = self.audio_loader(data, sr=self.sr)
        return waveform
|
|
256
256
|
|
|
257
257
|
|
|
@@ -259,13 +259,15 @@ class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):
|
|
|
259
259
|
|
|
260
260
|
def __init__(self, num_frames=121, time_division_factor=8, time_division_remainder=1, frame_rate=24, fix_frame_rate=True):
    # Frame-sampling settings; __call__ later derives the audio duration from
    # the sampled frame count divided by frame_rate.
    FrameSamplerByRateMixin.__init__(
        self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate
    )
    # torchaudio is imported here instead of at module import time.
    from torchaudio import load as torchaudio_load
    self.audio_loader = torchaudio_load
|
|
262
264
|
|
|
263
265
|
def __call__(self, data: str):
|
|
264
266
|
try:
|
|
265
267
|
reader = self.get_reader(data)
|
|
266
268
|
num_frames = self.get_num_frames(reader)
|
|
267
269
|
duration = num_frames / self.frame_rate
|
|
268
|
-
waveform, sample_rate =
|
|
270
|
+
waveform, sample_rate = self.audio_loader(data)
|
|
269
271
|
target_samples = int(duration * sample_rate)
|
|
270
272
|
current_samples = waveform.shape[-1]
|
|
271
273
|
if current_samples > target_samples:
|
|
@@ -285,10 +287,12 @@ class LoadPureAudioWithTorchaudio(DataProcessingOperator):
|
|
|
285
287
|
self.target_sample_rate = target_sample_rate
|
|
286
288
|
self.target_duration = target_duration
|
|
287
289
|
self.resample = True if target_sample_rate is not None else False
|
|
290
|
+
from diffsynth.utils.data.audio import read_audio
|
|
291
|
+
self.audio_loader = read_audio
|
|
288
292
|
|
|
289
293
|
def __call__(self, data: str):
|
|
290
294
|
try:
|
|
291
|
-
waveform, sample_rate =
|
|
295
|
+
waveform, sample_rate = self.audio_loader(data, resample=self.resample, resample_rate=self.target_sample_rate)
|
|
292
296
|
if self.target_duration is not None:
|
|
293
297
|
target_samples = int(self.target_duration * sample_rate)
|
|
294
298
|
current_samples = waveform.shape[-1]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .manager import OffloadTrainingManager
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Layer offloading for training — hook-based CPU offload.
|
|
3
|
+
|
|
4
|
+
Hook lifecycle per module:
|
|
5
|
+
|
|
6
|
+
No checkpointing:
|
|
7
|
+
forward_pre(load→GPU) → forward() → forward_hook(offload)
|
|
8
|
+
backward_pre(load→GPU) → backward() → backward_hook(offload)
|
|
9
|
+
|
|
10
|
+
With checkpointing (use_reentrant=False):
|
|
11
|
+
First forward:
|
|
12
|
+
forward_pre(load→GPU) → forward() → forward_hook(offload, mark in_recompute)
|
|
13
|
+
Recomputing forward (during backward):
|
|
14
|
+
forward_pre(load→GPU) → forward() → forward_hook(in_recompute=True → keep GPU)
|
|
15
|
+
Backward:
|
|
16
|
+
backward_pre(load→GPU) → backward() → backward_hook(offload)
|
|
17
|
+
"""
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn as nn
|
|
20
|
+
import warnings
|
|
21
|
+
from .offloader import StaticParamOffloader, TrainableParamOffloader, AlwaysOnGPUParamOffloader, BufferOffloader
|
|
22
|
+
from .memory_buffer import PinnedArenaPool, BaseBufferPool
|
|
23
|
+
warnings.filterwarnings("ignore", message="Full backward hook is firing when gradients are computed with respect to module outputs")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def has_parameters(module: nn.Module) -> bool:
    """Return True if ``module`` or any of its descendants owns a parameter."""
    # Stop at the first parameter instead of materialising the whole list.
    return next(module.parameters(), None) is not None
|
|
28
|
+
|
|
29
|
+
def count_parameters(module: nn.Module) -> int:
    """Return the total number of elements across all parameters of ``module``.

    ``Module.parameters()`` yields each shared parameter once, so tied
    weights are counted a single time.
    """
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
|
|
31
|
+
|
|
32
|
+
def is_leaf_module(module: nn.Module) -> bool:
    """Return True if ``module`` has no direct child modules."""
    return next(module.children(), None) is None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class UnitWiseParamManager:
    """Own one offloader per tensor of an offload unit and move them on demand.

    Offloader selection:
    - parameters with ``requires_grad=False`` get a ``StaticParamOffloader``
      (sharing ``memory_buffer``);
    - trainable parameters get a ``TrainableParamOffloader`` when
      ``enable_optimizer_cpu_offload`` is set, otherwise an
      ``AlwaysOnGPUParamOffloader``;
    - entries of ``buffers``, given as ``(owner_module, buffer_name, buffer)``
      triples, get a ``BufferOffloader``.

    Offloaders are keyed by ``id(tensor)``. ``onload_module`` and
    ``offload_module`` only touch tensors owned here, so a module that also
    holds tensors owned by another manager can be passed safely.

    Args:
        model: the unit's module. All of its parameters are managed when
            ``params`` is None.
        target_device: device passed to every offloader.
        enable_optimizer_cpu_offload: selects the offloader for trainable params.
        params: explicit parameter list overriding ``model.parameters()``.
        buffers: optional ``(module, name, buffer)`` triples to manage.
        memory_buffer: pool shared by the static-parameter and buffer offloaders.
    """

    def __init__(self, model: nn.Module, target_device: torch.device, enable_optimizer_cpu_offload: bool = False, params: list = None, buffers: list = None, memory_buffer: "BaseBufferPool" = None):
        self.model = model
        self.target_device = target_device
        self.param_offloaders = {}
        for param in (model.parameters() if params is None else params):
            if not param.requires_grad:
                offloader = StaticParamOffloader(param, target_device, memory_buffer=memory_buffer)
            elif enable_optimizer_cpu_offload:
                offloader = TrainableParamOffloader(param, target_device)
            else:
                offloader = AlwaysOnGPUParamOffloader(param, target_device)
            self.param_offloaders[id(param)] = offloader
        for mod, buf_name, buf in (buffers or ()):
            self.param_offloaders[id(buf)] = BufferOffloader(mod, buf_name, buf, target_device, memory_buffer=memory_buffer)

    def move_gradients_to_cpu(self):
        """Call ``offload_grad()`` on every owned offloader."""
        for offloader in self.param_offloaders.values():
            offloader.offload_grad()

    def _owned_offloaders(self, module: nn.Module):
        """Yield the offloaders for tensors in ``module``'s subtree that are owned here."""
        # Walk the whole subtree, not just the module's own tensors. With a
        # split threshold, a unit can be a non-leaf module whose parameters
        # belong to its children. With recurse=False those parameters would
        # never be onloaded before the unit's forward or backward.
        for param in module.parameters():
            offloader = self.param_offloaders.get(id(param))
            if offloader is not None:
                yield offloader
        for buf in module.buffers():
            offloader = self.param_offloaders.get(id(buf))
            if offloader is not None:
                yield offloader

    def onload_module(self, module: nn.Module):
        """Onload every owned parameter/buffer found in ``module``'s subtree."""
        for offloader in self._owned_offloaders(module):
            offloader.onload()

    def offload_module(self, module: nn.Module):
        """Offload every owned parameter/buffer found in ``module``'s subtree."""
        for offloader in self._owned_offloaders(module):
            offloader.offload()
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class UnitWiseHookManager:
    """Install onload/offload hooks for a single offload unit.

    Hook behaviour:
    - forward pre-hook: onload the unit's tensors;
    - forward hook: after the unit's first forward, record it in
      ``_in_recompute`` and offload. Later forwards (e.g. checkpoint
      recomputation during backward) leave the tensors in place;
    - backward pre-hook: onload the unit's tensors;
    - backward hook: offload. For a non-leaf unit it is attached to the
      parameter-holding leaf modules rather than to the unit itself.

    ``after_backward`` clears the recompute bookkeeping and calls
    ``move_gradients_to_cpu`` on the parameter manager.
    """

    def __init__(self, model: nn.Module, target_device: torch.device, enable_optimizer_cpu_offload: bool = False,
                 params: list = None, buffers: list = None, memory_buffer: BaseBufferPool = None):
        self.param_manager = UnitWiseParamManager(
            model,
            target_device,
            enable_optimizer_cpu_offload,
            params=params,
            buffers=buffers,
            memory_buffer=memory_buffer,
        )
        # Modules that have finished a forward since the last after_backward();
        # their forward hook no longer offloads.
        self._in_recompute: set = set()
        self._register_hooks(model)

    def _register_hooks(self, module: nn.Module):
        def forward_pre_hook(mod, args):
            self.param_manager.onload_module(mod)

        def forward_hook(mod, args, output):
            if mod not in self._in_recompute:
                self._in_recompute.add(mod)
                self.param_manager.offload_module(mod)

        def backward_pre_hook(mod, grad_output):
            self.param_manager.onload_module(mod)

        def backward_hook(mod, grad_input, grad_output):
            self.param_manager.offload_module(mod)

        module.register_forward_pre_hook(forward_pre_hook)
        module.register_forward_hook(forward_hook)
        module.register_full_backward_pre_hook(backward_pre_hook)

        if is_leaf_module(module):
            backward_hook_targets = [module]
        else:
            # A container's full backward hook fires before its children's
            # backward completes, so hook the parameter-holding leaves instead.
            backward_hook_targets = [
                leaf for leaf in module.modules()
                if is_leaf_module(leaf) and has_parameters(leaf)
            ]
        for target in backward_hook_targets:
            target.register_full_backward_hook(backward_hook)

    def after_backward(self):
        """Forget recompute state and move this unit's gradients off the device."""
        self._in_recompute.clear()
        self.param_manager.move_gradients_to_cpu()

    @property
    def managed_param_ids(self):
        """ids of every parameter/buffer owned by this unit's manager."""
        return {*self.param_manager.param_offloaders}
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class OffloadTrainingManager:
    """Partition a model into offload units and coordinate their offloading.

    Units are built in two passes:
    1. the modules returned by ``_find_units_recursive``;
    2. one extra unit per module holding parameters or buffers that no unit
       from the first pass manages.

    Args:
        model: module whose parameters and buffers are managed.
        target_device: device tensors are moved to while in use.
        enable_optimizer_cpu_offload: forwarded to every unit; selects the
            offloader used for trainable parameters.
        cpu_offload_split_threshold: ``None`` makes every parameter-holding
            leaf module its own unit. Otherwise the value is multiplied by
            ``1024 * 1024`` and compared with ``count_parameters`` (an element
            count): larger subtrees are split into their children, smaller
            ones become a single unit.
    """

    def __init__(self, model: nn.Module, target_device: torch.device, enable_optimizer_cpu_offload: bool = False, cpu_offload_split_threshold: int = None):
        self.model = model
        self.target_device = target_device
        self.enable_optimizer_cpu_offload = enable_optimizer_cpu_offload
        cpu_offload_split_threshold = cpu_offload_split_threshold * 1024 * 1024 if cpu_offload_split_threshold is not None else None
        self._register_units(model, target_device, enable_optimizer_cpu_offload, cpu_offload_split_threshold)

    def _register_units(self, model: nn.Module, target_device: torch.device, enable_optimizer_cpu_offload: bool, cpu_offload_split_threshold: int = None):
        """Create the shared pinned pool, the primary units, then orphan units."""
        self.memory_buffer = PinnedArenaPool.from_model(model)
        units = self._find_units_recursive(model, cpu_offload_split_threshold)
        self.units = [UnitWiseHookManager(u, target_device, enable_optimizer_cpu_offload, memory_buffer=self.memory_buffer) for u in units]

        managed_param_ids = set().union(*[unit.managed_param_ids for unit in self.units])
        orphan_params, orphan_buffers = self._find_orphan_params_and_buffers(model, managed_param_ids)
        for orphan_module in set(orphan_params.keys()) | set(orphan_buffers.keys()):
            params = orphan_params.get(orphan_module, [])
            buffers = orphan_buffers.get(orphan_module, [])
            self.units.append(UnitWiseHookManager(orphan_module, target_device, enable_optimizer_cpu_offload, params=params, buffers=buffers, memory_buffer=self.memory_buffer))

    def _find_orphan_params_and_buffers(self, model: nn.Module, managed_param_ids: set):
        """Group tensors that are not yet managed by their owning module.

        Returns:
            ``(params_by_module, buffers_by_module)``: the first maps a module
            to its unmanaged parameters; the second maps a module to
            ``(module, buffer_name, buffer)`` triples for its unmanaged buffers.
        """
        orphan_params_by_module = {}
        orphan_buffers_by_module = {}
        for _, mod in model.named_modules():
            for param in mod.parameters(recurse=False):
                if id(param) not in managed_param_ids:
                    orphan_params_by_module.setdefault(mod, []).append(param)
            for name, buf in mod.named_buffers(recurse=False):
                # Same ownership filter as for parameters, so a buffer that is
                # already managed never gets a second offloader.
                if id(buf) not in managed_param_ids:
                    orphan_buffers_by_module.setdefault(mod, []).append((mod, name, buf))
        return orphan_params_by_module, orphan_buffers_by_module

    def _find_units_recursive(self, module: nn.Module, cpu_offload_split_threshold: int = None) -> list:
        """Return the modules that become offload units under ``module``."""
        if cpu_offload_split_threshold is None:
            return [m for m in module.modules() if is_leaf_module(m) and has_parameters(m)]
        if self._should_force_recurse(module, cpu_offload_split_threshold):
            units = []
            for child in module.children():
                units.extend(self._find_units_recursive(child, cpu_offload_split_threshold))
            return units
        return [module]

    def _should_force_recurse(self, module: nn.Module, cpu_offload_split_threshold: int = None) -> bool:
        """Decide whether ``module`` must be split into its children.

        A module is split when it is not a leaf and at least one of these holds:
        - it exceeds the threshold;
        - its class does not define ``forward`` itself;
        - it exposes both ``encode`` and ``decode``.
        """
        if is_leaf_module(module):
            return False
        if (
            count_parameters(module) > cpu_offload_split_threshold
            or ('forward' not in type(module).__dict__)
            or (hasattr(module, 'encode') and hasattr(module, 'decode'))
        ):
            return True
        return False

    # run after backward() and before optimizer.step()
    def after_backward(self):
        """Finish a backward pass: reset every unit and offload gradients."""
        for unit in self.units:
            unit.after_backward()
        # torch.cuda.synchronize() raises on hosts without CUDA, so only wait
        # for outstanding device work when a CUDA runtime is present.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
|