diffsynth 2.0.11__tar.gz → 2.0.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffsynth-2.0.11 → diffsynth-2.0.12}/PKG-INFO +9 -1
- {diffsynth-2.0.11 → diffsynth-2.0.12}/README.md +68 -2
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/configs/model_configs.py +11 -16
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/configs/vram_management_module_maps.py +9 -1
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/__init__.py +1 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/attention/attention.py +8 -8
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/data/operators.py +10 -6
- diffsynth-2.0.12/diffsynth/core/offload_training/__init__.py +1 -0
- diffsynth-2.0.12/diffsynth/core/offload_training/manager.py +177 -0
- diffsynth-2.0.12/diffsynth/core/offload_training/memory_buffer.py +136 -0
- diffsynth-2.0.12/diffsynth/core/offload_training/offloader.py +71 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/diffusion/__init__.py +1 -1
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/diffusion/flow_match.py +62 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/diffusion/loss.py +1 -1
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/diffusion/parsers.py +7 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/diffusion/runner.py +32 -6
- diffsynth-2.0.12/diffsynth/models/ace_step_residual_fsq.py +569 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ace_step_tokenizer.py +2 -4
- diffsynth-2.0.12/diffsynth/models/hidream_common.py +373 -0
- diffsynth-2.0.12/diffsynth/models/hidream_o1_image_dit.py +1910 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ltx2_audio_vae.py +4 -2
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/ace_step.py +15 -14
- diffsynth-2.0.12/diffsynth/pipelines/hidream_o1_image.py +420 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/qwen_image.py +1 -1
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py +5 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth.egg-info/PKG-INFO +9 -1
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth.egg-info/SOURCES.txt +8 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth.egg-info/requires.txt +9 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/pyproject.toml +11 -2
- {diffsynth-2.0.11 → diffsynth-2.0.12}/LICENSE +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/configs/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/attention/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/data/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/data/unified_dataset.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/device/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/device/npu_compatible_device.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/gradient/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/gradient/gradient_checkpoint.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/loader/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/loader/config.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/loader/file.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/loader/model.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/npu_patch/npu_fused_operator.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/vram/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/vram/disk_map.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/vram/initialization.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/core/vram/layers.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/diffusion/base_pipeline.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/diffusion/ddim_scheduler.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/diffusion/logger.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/diffusion/template.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/diffusion/training_module.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ace_step_conditioner.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ace_step_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ace_step_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ace_step_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/anima_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/dinov3_image_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ernie_image_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ernie_image_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux2_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux2_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux2_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux_controlnet.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux_infiniteyou.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux_ipadapter.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux_lora_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux_lora_patcher.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux_text_encoder_clip.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux_text_encoder_t5.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/flux_value_control.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/general_modules.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/joyai_image_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/joyai_image_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/longcat_video_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ltx2_common.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ltx2_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ltx2_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ltx2_upsampler.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/ltx2_video_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/model_loader.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/mova_audio_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/mova_audio_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/mova_dual_tower_bridge.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/nexus_gen.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/nexus_gen_ar_model.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/nexus_gen_projector.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/qwen_image_controlnet.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/qwen_image_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/qwen_image_image2lora.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/qwen_image_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/qwen_image_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/sd_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/siglip2_image_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_unet.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_xl_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_xl_unet.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/step1x_connector.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/step1x_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/wan_video_animate_adapter.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/wan_video_camera_controller.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/wan_video_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/wan_video_dit_s2v.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/wan_video_image_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/wan_video_mot.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/wan_video_motion_controller.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/wan_video_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/wan_video_vace.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/wan_video_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/wantodance.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/wav2vec.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/z_image_controlnet.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/z_image_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/z_image_image2lora.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/models/z_image_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/anima_image.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/ernie_image.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/flux2_image.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/flux_image.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/joyai_image.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/ltx2_audio_video.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/mova_audio_video.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/stable_diffusion.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/stable_diffusion_xl.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/wan_video.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/pipelines/z_image.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/controlnet/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/controlnet/annotator.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/controlnet/controlnet_input.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/data/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/data/audio.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/data/audio_video.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/data/media_io_ltx2.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/lora/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/lora/flux.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/lora/general.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/lora/merge.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/lora/reset_rank.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/ses/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/ses/ses.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ace_step_conditioner.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ace_step_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/anima_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/dino_v3.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux2_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_controlnet.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_infiniteyou.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_ipadapter.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/joyai_image_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ltx2_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ltx2_video_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/nexus_gen.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/nexus_gen_projector.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/stable_diffusion_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/stable_diffusion_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/step1x_connector.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_mot.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_vace.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_vae.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/z_image_dit.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/z_image_text_encoder.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/xfuser/__init__.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/utils/xfuser/xdit_context_parallel.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth/version.py +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth.egg-info/dependency_links.txt +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/diffsynth.egg-info/top_level.txt +0 -0
- {diffsynth-2.0.11 → diffsynth-2.0.12}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: diffsynth
|
|
3
|
-
Version: 2.0.11
|
|
3
|
+
Version: 2.0.12
|
|
4
4
|
Summary: Enjoy the magic of Diffusion models!
|
|
5
5
|
Author: ModelScope Team
|
|
6
6
|
License: Apache-2.0
|
|
@@ -33,6 +33,14 @@ Requires-Dist: torch==2.7.1+cpu; extra == "npu"
|
|
|
33
33
|
Requires-Dist: torch-npu==2.7.1; extra == "npu"
|
|
34
34
|
Requires-Dist: torchvision==0.22.1+cpu; extra == "npu"
|
|
35
35
|
Provides-Extra: audio
|
|
36
|
+
Requires-Dist: av; extra == "audio"
|
|
36
37
|
Requires-Dist: torchaudio; extra == "audio"
|
|
37
38
|
Requires-Dist: torchcodec; extra == "audio"
|
|
39
|
+
Requires-Dist: librosa; extra == "audio"
|
|
40
|
+
Provides-Extra: all
|
|
41
|
+
Requires-Dist: av; extra == "all"
|
|
42
|
+
Requires-Dist: torchaudio; extra == "all"
|
|
43
|
+
Requires-Dist: torchcodec; extra == "all"
|
|
44
|
+
Requires-Dist: librosa; extra == "all"
|
|
45
|
+
Requires-Dist: streamlit; extra == "all"
|
|
38
46
|
Dynamic: license-file
|
|
@@ -34,6 +34,10 @@ We believe that a well-developed open-source code framework can lower the thresh
|
|
|
34
34
|
|
|
35
35
|
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
|
36
36
|
|
|
37
|
+
- **May 18, 2026** Added **CPU Offload Training** support. By moving model weights layer by layer between CPU and GPU, it significantly reduces GPU VRAM usage during training, enabling LoRA training of large models even on consumer-grade GPUs. It is compatible with all models. Simply add `--enable_model_cpu_offload` to your training command to enable it (currently, only single-GPU training is supported). For details, see the [documentation](/docs/en/Training/Offload_Training.md).
|
|
38
|
+
|
|
39
|
+
- **May 14, 2026** HiDream-O1-Image is now open-sourced. Welcome the newest member of the image model family! Support includes text-to-image generation, image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/HiDream-O1-Image.md) and [example code](/examples/hidream_o1_image/).
|
|
40
|
+
|
|
37
41
|
- **April 28, 2026** 🔥 We are excited to announce the release of **Diffusion Templates**, a plugin framework designed for Diffusion models that significantly lowers the barrier to training controllable generative models. Let's explore this cutting-edge technology together!
|
|
38
42
|
* Open-source code: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
|
|
39
43
|
* Technical report: [arXiv](https://arxiv.org/abs/2604.24351)
|
|
@@ -884,6 +888,68 @@ Example code for JoyAI-Image is available at: [/examples/joyai_image/](/examples
|
|
|
884
888
|
|
|
885
889
|
</details>
|
|
886
890
|
|
|
891
|
+
#### HiDream-O1-Image: [/docs/en/Model_Details/HiDream-O1-Image.md](/docs/en/Model_Details/HiDream-O1-Image.md)
|
|
892
|
+
|
|
893
|
+
<details>
|
|
894
|
+
|
|
895
|
+
<summary>Quick Start</summary>
|
|
896
|
+
|
|
897
|
+
Running the following code will quickly load the [HiDream-ai/HiDream-O1-Image](https://modelscope.cn/HiDream-ai/HiDream-O1-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
|
|
898
|
+
|
|
899
|
+
```python
|
|
900
|
+
from diffsynth.pipelines.hidream_o1_image import HiDreamO1ImagePipeline
|
|
901
|
+
from diffsynth.core.loader.config import ModelConfig
|
|
902
|
+
import torch
|
|
903
|
+
|
|
904
|
+
|
|
905
|
+
vram_config = {
|
|
906
|
+
"offload_dtype": torch.bfloat16,
|
|
907
|
+
"offload_device": "cpu",
|
|
908
|
+
"onload_dtype": torch.bfloat16,
|
|
909
|
+
"onload_device": "cpu",
|
|
910
|
+
"preparing_dtype": torch.bfloat16,
|
|
911
|
+
"preparing_device": "cuda",
|
|
912
|
+
"computation_dtype": torch.bfloat16,
|
|
913
|
+
"computation_device": "cuda",
|
|
914
|
+
}
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
pipe = HiDreamO1ImagePipeline.from_pretrained(
|
|
918
|
+
torch_dtype=torch.bfloat16,
|
|
919
|
+
device="cuda",
|
|
920
|
+
model_configs=[
|
|
921
|
+
ModelConfig(model_id="HiDream-ai/HiDream-O1-Image", origin_file_pattern="model-*.safetensors", **vram_config),
|
|
922
|
+
],
|
|
923
|
+
processor_config=ModelConfig(model_id="HiDream-ai/HiDream-O1-Image", origin_file_pattern="./"),
|
|
924
|
+
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
925
|
+
)
|
|
926
|
+
image = pipe(
|
|
927
|
+
prompt="medium shot, eye-level, front view. A woman is seated in an ornate bedroom, illuminated by candlelight, with a calm and composed expression. The subject is a young woman with fair skin, light brown hair styled in an updo with loose tendrils framing her face, and blue eyes. She wears a cream-colored satin robe with delicate floral embroidery and lace trim along the neckline. Her ears are adorned with pearl drop earrings. She is seated on a bed with a dark, intricately carved wooden headboard. To her left, a wooden nightstand holds three lit white candles and a candelabra with multiple lit candles in the background. The bed is covered with patterned pillows and a dark, textured blanket. The walls are paneled with dark wood and feature a large, ornate tapestry with muted earth tones. The lighting creates soft highlights on her face and robe, with warm shadows cast across the room.",
|
|
928
|
+
negative_prompt=" ",
|
|
929
|
+
cfg_scale=4.0,
|
|
930
|
+
height=2048,
|
|
931
|
+
width=2048,
|
|
932
|
+
seed=42,
|
|
933
|
+
num_inference_steps=50,
|
|
934
|
+
)
|
|
935
|
+
image.save("image.jpg")
|
|
936
|
+
```
|
|
937
|
+
|
|
938
|
+
</details>
|
|
939
|
+
|
|
940
|
+
<details>
|
|
941
|
+
|
|
942
|
+
<summary>Examples</summary>
|
|
943
|
+
|
|
944
|
+
Example code for HiDream-O1-Image is available at: [/examples/hidream_o1_image/](/examples/hidream_o1_image/)
|
|
945
|
+
|
|
946
|
+
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
|
947
|
+
|-|-|-|-|-|-|-|
|
|
948
|
+
|[HiDream-ai/HiDream-O1-Image](https://modelscope.cn/HiDream-ai/HiDream-O1-Image)|[code](/examples/hidream_o1_image/model_inference/HiDream-O1-Image.py)|[code](/examples/hidream_o1_image/model_inference_low_vram/HiDream-O1-Image.py)|[code](/examples/hidream_o1_image/model_training/full/HiDream-O1-Image.sh)|[code](/examples/hidream_o1_image/model_training/validate_full/HiDream-O1-Image.py)|[code](/examples/hidream_o1_image/model_training/lora/HiDream-O1-Image.sh)|[code](/examples/hidream_o1_image/model_training/validate_lora/HiDream-O1-Image.py)|
|
|
949
|
+
|[HiDream-ai/HiDream-O1-Image-Dev](https://modelscope.cn/HiDream-ai/HiDream-O1-Image-Dev)|[code](/examples/hidream_o1_image/model_inference/HiDream-O1-Image-Dev.py)|[code](/examples/hidream_o1_image/model_inference_low_vram/HiDream-O1-Image-Dev.py)|[code](/examples/hidream_o1_image/model_training/full/HiDream-O1-Image-Dev.sh)|[code](/examples/hidream_o1_image/model_training/validate_full/HiDream-O1-Image-Dev.py)|[code](/examples/hidream_o1_image/model_training/lora/HiDream-O1-Image-Dev.sh)|[code](/examples/hidream_o1_image/model_training/validate_lora/HiDream-O1-Image-Dev.py)|
|
|
950
|
+
|
|
951
|
+
</details>
|
|
952
|
+
|
|
887
953
|
### Video Synthesis
|
|
888
954
|
|
|
889
955
|
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
|
@@ -1158,8 +1224,8 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
|
|
|
1158
1224
|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|
|
1159
1225
|
|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|
|
|
1160
1226
|
|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|
|
|
1161
|
-
|[Wan-AI/
|
|
1162
|
-
|[Wan-AI/
|
|
1227
|
+
|[Wan-AI/Wan2.2-Dancer-14B (global model)](https://modelscope.cn/models/Wan-AI/Wan2.2-Dancer-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Dancer-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Dancer-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Dancer-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Dancer-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Dancer-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Dancer-14B-global.py)|
|
|
1228
|
+
|[Wan-AI/Wan2.2-Dancer-14B (local model)](https://modelscope.cn/models/Wan-AI/Wan2.2-Dancer-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Dancer-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Dancer-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Dancer-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Dancer-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Dancer-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Dancer-14B-local.py)|
|
|
1163
1229
|
|
|
1164
1230
|
</details>
|
|
1165
1231
|
|
|
@@ -309,7 +309,7 @@ wan_series = [
|
|
|
309
309
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
|
|
310
310
|
},
|
|
311
311
|
{
|
|
312
|
-
# Example: ModelConfig(model_id="Wan-AI/
|
|
312
|
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Dancer-14B", origin_file_pattern="global_model.safetensors")
|
|
313
313
|
"model_hash": "eb18873fc0ba77b541eb7b62dbcd2059",
|
|
314
314
|
"model_name": "wan_video_dit",
|
|
315
315
|
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
|
@@ -833,20 +833,6 @@ ltx2_series = [
|
|
|
833
833
|
"extra_kwargs": {"decoder_version": "ltx-2.3"},
|
|
834
834
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
|
835
835
|
},
|
|
836
|
-
{
|
|
837
|
-
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors")
|
|
838
|
-
"model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
|
|
839
|
-
"model_name": "ltx2_audio_vae_decoder",
|
|
840
|
-
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
|
841
|
-
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
|
842
|
-
},
|
|
843
|
-
{
|
|
844
|
-
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
|
|
845
|
-
"model_hash": "29338f3b95e7e312a3460a482e4f4554",
|
|
846
|
-
"model_name": "ltx2_audio_vae_encoder",
|
|
847
|
-
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
|
848
|
-
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
|
849
|
-
},
|
|
850
836
|
{
|
|
851
837
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors")
|
|
852
838
|
"model_hash": "cd436c99e69ec5c80f050f0944f02a15",
|
|
@@ -1040,7 +1026,16 @@ ace_step_series = [
|
|
|
1040
1026
|
},
|
|
1041
1027
|
]
|
|
1042
1028
|
|
|
1029
|
+
hidream_o1_image_series = [
|
|
1030
|
+
{
|
|
1031
|
+
# Example: ModelConfig(model_id="HiDream-ai/HiDream-O1-Image", origin_file_pattern="model-*.safetensors")
|
|
1032
|
+
"model_hash": "58a7c1073d79556bfc61e05e6061b771",
|
|
1033
|
+
"model_name": "hidream_o1_image_dit",
|
|
1034
|
+
"model_class": "diffsynth.models.hidream_o1_image_dit.HiDreamO1ImageModel",
|
|
1035
|
+
},
|
|
1036
|
+
]
|
|
1037
|
+
|
|
1043
1038
|
MODEL_CONFIGS = (
|
|
1044
1039
|
stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
|
|
1045
|
-
+ z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
|
|
1040
|
+
+ z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series + hidream_o1_image_series
|
|
1046
1041
|
)
|
|
@@ -327,7 +327,7 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
|
|
327
327
|
"diffsynth.models.ace_step_tokenizer.AceStepTokenizer": {
|
|
328
328
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
|
329
329
|
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
330
|
-
"
|
|
330
|
+
"diffsynth.models.ace_step_residual_fsq.ResidualFSQ": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
331
331
|
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
332
332
|
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
333
333
|
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
@@ -372,6 +372,14 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
|
|
372
372
|
"diffsynth.models.stable_diffusion_text_encoder.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
373
373
|
"diffsynth.models.stable_diffusion_xl_text_encoder.CLIPTextModelWithProjection": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
374
374
|
},
|
|
375
|
+
"diffsynth.models.hidream_o1_image_dit.HiDreamO1ImageModel": {
|
|
376
|
+
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
|
377
|
+
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
378
|
+
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
379
|
+
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
380
|
+
"diffsynth.models.hidream_o1_image_dit.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
381
|
+
"diffsynth.models.hidream_o1_image_dit.Qwen3VLVisionModel": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
382
|
+
},
|
|
375
383
|
}
|
|
376
384
|
|
|
377
385
|
def QwenImageTextEncoder_Module_Map_Updater():
|
|
@@ -63,10 +63,10 @@ def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern
|
|
|
63
63
|
return out
|
|
64
64
|
|
|
65
65
|
|
|
66
|
-
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
|
|
66
|
+
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, is_causal=False):
    """Compute attention with ``torch.nn.functional.scaled_dot_product_attention``.

    The inputs are rearranged from their declared patterns to ``"b n s d"``,
    attention is evaluated, and the result is rearranged to ``out_pattern``.

    Args:
        q, k, v: Query / key / value tensors laid out as described by
            ``q_pattern`` / ``k_pattern`` / ``v_pattern``.
        out_pattern: Layout of the returned tensor.
        dims: Forwarded unchanged to ``rearrange_qkv`` / ``rearrange_out``.
        attn_mask: Optional boolean or additive mask forwarded to SDPA.
        scale: Optional softmax scale forwarded to SDPA.
        is_causal: Apply causal masking.

    Returns:
        The attention output in ``out_pattern`` layout.
    """
    required_in_pattern, required_out_pattern = "b n s d", "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    if is_causal and attn_mask is not None:
        # SDPA raises when an explicit mask is combined with is_causal=True, so the
        # causal constraint (lower-triangular, aligned to the top-left corner, which
        # is what is_causal=True means for SDPA) is merged into the mask instead.
        keep = torch.ones(q.shape[-2], k.shape[-2], dtype=torch.bool, device=attn_mask.device).tril()
        if attn_mask.dtype == torch.bool:
            attn_mask = attn_mask & keep
        else:
            # Addition (rather than masked_fill on attn_mask) keeps broadcasting
            # working for masks that have singleton query/key dimensions.
            causal_bias = torch.zeros(keep.shape, dtype=attn_mask.dtype, device=attn_mask.device)
            attn_mask = attn_mask + causal_bias.masked_fill(~keep, float("-inf"))
        is_causal = False
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale, is_causal=is_causal)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out
|
|
72
72
|
|
|
@@ -81,10 +81,10 @@ def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_patte
|
|
|
81
81
|
return out
|
|
82
82
|
|
|
83
83
|
|
|
84
|
-
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
|
84
|
+
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None, is_causal=False):
    """Compute attention with ``flash_attn.flash_attn_func``.

    Inputs are rearranged to the ``"b s n d"`` layout before the call and the
    result is rearranged from that layout to ``out_pattern``.

    Args:
        q, k, v: Query / key / value tensors laid out as described by
            ``q_pattern`` / ``k_pattern`` / ``v_pattern``.
        out_pattern: Layout of the returned tensor.
        dims: Forwarded unchanged to ``rearrange_qkv`` / ``rearrange_out``.
        scale: Passed to flash-attn as ``softmax_scale``.
        is_causal: Passed to flash-attn as ``causal``.

    Returns:
        The attention output in ``out_pattern`` layout.
    """
    # NOTE(review): when q and k have different sequence lengths, the causal
    # alignment used by flash-attn may differ from torch SDPA's - confirm
    # before relying on identical results across backends.
    kernel_layout = "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, kernel_layout, dims)
    attn_out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale, causal=is_causal)
    attn_out = rearrange_out(attn_out, out_pattern, kernel_layout, dims)
    return attn_out
|
|
90
90
|
|
|
@@ -105,17 +105,17 @@ def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_patt
|
|
|
105
105
|
return out
|
|
106
106
|
|
|
107
107
|
|
|
108
|
-
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
|
|
108
|
+
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, is_causal=False, compatibility_mode=False):
    """Dispatch attention to the backend selected by ``ATTENTION_IMPLEMENTATION``.

    Args:
        q, k, v: Query / key / value tensors laid out as described by
            ``q_pattern`` / ``k_pattern`` / ``v_pattern``.
        out_pattern: Layout of the returned tensor.
        dims: Forwarded unchanged to the selected backend.
        attn_mask: Optional attention mask. Any mask forces the torch SDPA path.
        scale: Optional softmax scale.
        is_causal: Apply causal masking.
        compatibility_mode: Force the torch SDPA path regardless of the
            configured backend.

    Returns:
        The attention output in ``out_pattern`` layout.
    """
    if compatibility_mode or (attn_mask is not None):
        return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale, is_causal=is_causal)
    if ATTENTION_IMPLEMENTATION == "flash_attention_2":
        return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale, is_causal=is_causal)
    if is_causal:
        # The flash_attention_3 / sage_attention / xformers calls below carry no
        # causal argument. Dispatching a causal request to them would silently
        # compute full (non-causal) attention, so use torch SDPA instead.
        return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale, is_causal=True)
    if ATTENTION_IMPLEMENTATION == "flash_attention_3":
        return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
    if ATTENTION_IMPLEMENTATION == "sage_attention":
        return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
    if ATTENTION_IMPLEMENTATION == "xformers":
        return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
    # Unknown or unset backend: default to torch SDPA.
    return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale, is_causal=is_causal)
|
|
@@ -2,8 +2,6 @@ import math, warnings
|
|
|
2
2
|
import torch, torchvision, imageio, os
|
|
3
3
|
import imageio.v3 as iio
|
|
4
4
|
from PIL import Image
|
|
5
|
-
import torchaudio
|
|
6
|
-
from diffsynth.utils.data.audio import read_audio
|
|
7
5
|
|
|
8
6
|
|
|
9
7
|
class DataProcessingPipeline:
|
|
@@ -249,9 +247,11 @@ class ToAbsolutePath(DataProcessingOperator):
|
|
|
249
247
|
class LoadAudio(DataProcessingOperator):
    """Operator that loads an audio file with librosa at a fixed sample rate."""

    def __init__(self, sr=16000):
        # librosa is imported here rather than at module import time, so the
        # dependency is only needed when this operator is actually constructed.
        import librosa

        self.sr = sr
        self.audio_loader = librosa.load

    def __call__(self, data: str):
        """Load ``data`` through the stored loader with ``sr=self.sr`` and return only the audio."""
        # The loader returns (audio, sample_rate); the rate is discarded.
        waveform, _sample_rate = self.audio_loader(data, sr=self.sr)
        return waveform
|
|
256
256
|
|
|
257
257
|
|
|
@@ -259,13 +259,15 @@ class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):
|
|
|
259
259
|
|
|
260
260
|
def __init__(self, num_frames=121, time_division_factor=8, time_division_remainder=1, frame_rate=24, fix_frame_rate=True):
|
|
261
261
|
FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
|
|
262
|
+
import torchaudio
|
|
263
|
+
self.audio_loader = torchaudio.load
|
|
262
264
|
|
|
263
265
|
def __call__(self, data: str):
|
|
264
266
|
try:
|
|
265
267
|
reader = self.get_reader(data)
|
|
266
268
|
num_frames = self.get_num_frames(reader)
|
|
267
269
|
duration = num_frames / self.frame_rate
|
|
268
|
-
waveform, sample_rate =
|
|
270
|
+
waveform, sample_rate = self.audio_loader(data)
|
|
269
271
|
target_samples = int(duration * sample_rate)
|
|
270
272
|
current_samples = waveform.shape[-1]
|
|
271
273
|
if current_samples > target_samples:
|
|
@@ -285,10 +287,12 @@ class LoadPureAudioWithTorchaudio(DataProcessingOperator):
|
|
|
285
287
|
self.target_sample_rate = target_sample_rate
|
|
286
288
|
self.target_duration = target_duration
|
|
287
289
|
self.resample = True if target_sample_rate is not None else False
|
|
290
|
+
from diffsynth.utils.data.audio import read_audio
|
|
291
|
+
self.audio_loader = read_audio
|
|
288
292
|
|
|
289
293
|
def __call__(self, data: str):
|
|
290
294
|
try:
|
|
291
|
-
waveform, sample_rate =
|
|
295
|
+
waveform, sample_rate = self.audio_loader(data, resample=self.resample, resample_rate=self.target_sample_rate)
|
|
292
296
|
if self.target_duration is not None:
|
|
293
297
|
target_samples = int(self.target_duration * sample_rate)
|
|
294
298
|
current_samples = waveform.shape[-1]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .manager import OffloadTrainingManager
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Layer offloading for training — hook-based CPU offload.
|
|
3
|
+
|
|
4
|
+
Hook lifecycle per module:
|
|
5
|
+
|
|
6
|
+
No checkpointing:
|
|
7
|
+
forward_pre(load→GPU) → forward() → forward_hook(offload)
|
|
8
|
+
backward_pre(load→GPU) → backward() → backward_hook(offload)
|
|
9
|
+
|
|
10
|
+
With checkpointing (use_reentrant=False):
|
|
11
|
+
First forward:
|
|
12
|
+
forward_pre(load→GPU) → forward() → forward_hook(offload, mark in_recompute)
|
|
13
|
+
Recomputing forward (during backward):
|
|
14
|
+
forward_pre(load→GPU) → forward() → forward_hook(in_recompute=True → keep GPU)
|
|
15
|
+
Backward:
|
|
16
|
+
backward_pre(load→GPU) → backward() → backward_hook(offload)
|
|
17
|
+
"""
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn as nn
|
|
20
|
+
import warnings
|
|
21
|
+
from .offloader import StaticParamOffloader, TrainableParamOffloader, AlwaysOnGPUParamOffloader, BufferOffloader
|
|
22
|
+
from .memory_buffer import PinnedArenaPool, BaseBufferPool
|
|
23
|
+
warnings.filterwarnings("ignore", message="Full backward hook is firing when gradients are computed with respect to module outputs")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def has_parameters(module: nn.Module) -> bool:
    """Return True if ``module`` (including its descendants) owns at least one parameter."""
    # Stop at the first parameter instead of materialising the whole list.
    return next(iter(module.parameters()), None) is not None
|
|
28
|
+
|
|
29
|
+
def count_parameters(module: nn.Module) -> int:
    """Return the total number of elements across all parameters of ``module``."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
|
|
31
|
+
|
|
32
|
+
def is_leaf_module(module: nn.Module) -> bool:
    """Return True if ``module`` has no child modules."""
    # Stop at the first child instead of materialising the whole list.
    return next(iter(module.children()), None) is None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class UnitWiseParamManager:
    """Hold one offloader per parameter/buffer of an offload unit and move them together.

    Offloaders are chosen per tensor:
      * parameters with ``requires_grad == False`` -> ``StaticParamOffloader``
      * trainable parameters -> ``TrainableParamOffloader`` when
        ``enable_optimizer_cpu_offload`` is set, otherwise ``AlwaysOnGPUParamOffloader``
      * explicitly supplied buffers -> ``BufferOffloader``

    Args:
        model: Module this manager is attached to.
        target_device: Device handed to every offloader.
        enable_optimizer_cpu_offload: Selects the offloader type for trainable parameters.
        params: Parameters to manage; defaults to all of ``model.parameters()``.
        buffers: Optional iterable of ``(owner_module, buffer_name, buffer)`` triples.
        memory_buffer: Pool handed to the static-parameter and buffer offloaders.
    """

    def __init__(self, model: nn.Module, target_device: torch.device, enable_optimizer_cpu_offload: bool = False, params: list = None, buffers: list = None, memory_buffer: "BaseBufferPool" = None):
        self.model = model
        self.target_device = target_device
        # Keyed by id(tensor); parameters are inserted first, then buffers.
        self.param_offloaders = {}
        for param in (model.parameters() if params is None else params):
            if not param.requires_grad:
                offloader = StaticParamOffloader(param, target_device, memory_buffer=memory_buffer)
            elif enable_optimizer_cpu_offload:
                offloader = TrainableParamOffloader(param, target_device)
            else:
                offloader = AlwaysOnGPUParamOffloader(param, target_device)
            self.param_offloaders[id(param)] = offloader
        for mod, buf_name, buf in (buffers or ()):
            self.param_offloaders[id(buf)] = BufferOffloader(mod, buf_name, buf, target_device, memory_buffer=memory_buffer)

    def _managed_offloaders(self, module: nn.Module):
        """Yield the offloaders of every managed tensor under ``module``.

        Offloaders are registered for all descendant parameters of a unit, so the
        lookup must also walk descendants; with ``recurse=False`` a non-leaf unit
        would never move the weights of its children. Filtering by the id table
        keeps tensors owned by other managers untouched.
        """
        for param in module.parameters():
            offloader = self.param_offloaders.get(id(param))
            if offloader is not None:
                yield offloader
        for buf in module.buffers():
            offloader = self.param_offloaders.get(id(buf))
            if offloader is not None:
                yield offloader

    def move_gradients_to_cpu(self):
        """Call ``offload_grad()`` on every registered offloader."""
        for offloader in self.param_offloaders.values():
            offloader.offload_grad()

    def onload_module(self, module: nn.Module):
        """Call ``onload()`` for every managed parameter and buffer under ``module``."""
        for offloader in self._managed_offloaders(module):
            offloader.onload()

    def offload_module(self, module: nn.Module):
        """Call ``offload()`` for every managed parameter and buffer under ``module``."""
        for offloader in self._managed_offloaders(module):
            offloader.offload()
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class UnitWiseHookManager:
    """Attach forward/backward hooks to one offload unit.

    The hooks ask the unit's ``UnitWiseParamManager`` to onload before forward
    and backward, and to offload afterwards. A module whose forward hook has
    already fired since the last ``after_backward()`` is recorded in
    ``_in_recompute``; on a later forward of that module (e.g. a checkpoint
    recomputation during backward) the offload is skipped.
    """

    def __init__(self, model: nn.Module, target_device: torch.device, enable_optimizer_cpu_offload: bool = False,
                 params: list = None, buffers: list = None, memory_buffer: BaseBufferPool = None):
        self.param_manager = UnitWiseParamManager(
            model,
            target_device,
            enable_optimizer_cpu_offload,
            params=params,
            buffers=buffers,
            memory_buffer=memory_buffer,
        )
        # Modules that have completed a forward since the last after_backward().
        self._in_recompute: set = set()
        self._register_hooks(model)

    def _register_hooks(self, module: nn.Module):
        """Register the onload/offload hooks on ``module`` and its backward-hook targets."""
        manager = self.param_manager
        seen = self._in_recompute

        def load_before_forward(mod, args):
            manager.onload_module(mod)

        def release_after_forward(mod, args, output):
            # A repeated forward keeps the tensors loaded; only the first one offloads.
            if mod not in seen:
                seen.add(mod)
                manager.offload_module(mod)

        def load_before_backward(mod, grad_output):
            manager.onload_module(mod)

        def release_after_backward(mod, grad_input, grad_output):
            manager.offload_module(mod)

        module.register_forward_pre_hook(load_before_forward)
        module.register_forward_hook(release_after_forward)
        module.register_full_backward_pre_hook(load_before_backward)

        if is_leaf_module(module):
            backward_targets = [module]
        else:
            # Parent module backward_hook fires before child backward completes.
            # Register on leaf children instead.
            backward_targets = [m for m in module.modules() if is_leaf_module(m) and has_parameters(m)]
        for target in backward_targets:
            target.register_full_backward_hook(release_after_backward)

    def after_backward(self):
        """Reset the recompute markers and offload gradients via the parameter manager."""
        self._in_recompute.clear()
        self.param_manager.move_gradients_to_cpu()

    @property
    def managed_param_ids(self):
        """Set of ``id()`` values of every tensor handled by this unit."""
        return set(self.param_manager.param_offloaders.keys())
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class OffloadTrainingManager:
    """Split a model into offload units and install CPU-offload hooks on each of them.

    Args:
        model: Model to manage.
        target_device: Device handed to every unit.
        enable_optimizer_cpu_offload: Forwarded to every unit.
        cpu_offload_split_threshold: Unit size limit, expressed in multiples of
            ``1024 * 1024`` parameters. ``None`` makes every leaf module that
            has parameters its own unit.
    """

    def __init__(self, model: nn.Module, target_device: torch.device, enable_optimizer_cpu_offload: bool = False, cpu_offload_split_threshold: int = None):
        self.model = model
        self.target_device = target_device
        self.enable_optimizer_cpu_offload = enable_optimizer_cpu_offload
        # Convert the user-facing threshold into a raw parameter count.
        cpu_offload_split_threshold = cpu_offload_split_threshold * 1024 * 1024 if cpu_offload_split_threshold is not None else None
        self._register_units(model, target_device, enable_optimizer_cpu_offload, cpu_offload_split_threshold)

    def _register_units(self, model: nn.Module, target_device: torch.device, enable_optimizer_cpu_offload: bool, cpu_offload_split_threshold: int = None):
        """Create the shared pinned pool, one hook manager per unit, and managers for leftovers."""
        self.memory_buffer = PinnedArenaPool.from_model(model)
        units = self._find_units_recursive(model, cpu_offload_split_threshold)
        self.units = [UnitWiseHookManager(u, target_device, enable_optimizer_cpu_offload, memory_buffer=self.memory_buffer) for u in units]

        managed_param_ids = set().union(*[unit.managed_param_ids for unit in self.units])
        orphan_params, orphan_buffers = self._find_orphan_params_and_buffers(model, managed_param_ids)
        # dict.fromkeys de-duplicates while keeping model traversal order, so the
        # unit list is built in the same order on every run (a set would not be).
        for orphan_module in dict.fromkeys([*orphan_params, *orphan_buffers]):
            params = orphan_params.get(orphan_module, [])
            buffers = orphan_buffers.get(orphan_module, [])
            self.units.append(UnitWiseHookManager(orphan_module, target_device, enable_optimizer_cpu_offload, params=params, buffers=buffers, memory_buffer=self.memory_buffer))

    def _find_orphan_params_and_buffers(self, model: nn.Module, managed_param_ids: set):
        """Group tensors not covered by the regular units by their owning module.

        Returns:
            ``(params_by_module, buffers_by_module)``. The first maps a module to
            its direct parameters whose id is not in ``managed_param_ids``. The
            second maps a module to ``(module, name, buffer)`` triples for all of
            its direct buffers; buffers are not filtered by ``managed_param_ids``.
        """
        orphan_params_by_module = {}
        orphan_buffers_by_module = {}
        for mod in model.modules():
            for param in mod.parameters(recurse=False):
                if id(param) not in managed_param_ids:
                    orphan_params_by_module.setdefault(mod, []).append(param)
            # Collect orphan buffers grouped by owner module
            for name, buf in mod.named_buffers(recurse=False):
                orphan_buffers_by_module.setdefault(mod, []).append((mod, name, buf))
        return orphan_params_by_module, orphan_buffers_by_module

    def _find_units_recursive(self, module: nn.Module, cpu_offload_split_threshold: int = None) -> list:
        """Return the modules that become offload units.

        Without a threshold every parameterised leaf module is a unit. With a
        threshold, a module is kept whole unless ``_should_force_recurse`` asks
        for it to be split into its children.
        """
        if cpu_offload_split_threshold is None:
            return [m for m in module.modules() if is_leaf_module(m) and has_parameters(m)]
        if not self._should_force_recurse(module, cpu_offload_split_threshold):
            return [module]
        units = []
        for child in module.children():
            units.extend(self._find_units_recursive(child, cpu_offload_split_threshold))
        return units

    def _should_force_recurse(self, module: nn.Module, cpu_offload_split_threshold: int = None) -> bool:
        """Decide whether ``module`` must be split into its children.

        Leaves are never split. A non-leaf module is split when it holds more
        parameters than the threshold, when its class does not define its own
        ``forward``, or when it exposes both ``encode`` and ``decode``.
        """
        if is_leaf_module(module):
            return False
        if count_parameters(module) > cpu_offload_split_threshold:
            return True
        if 'forward' not in type(module).__dict__:
            return True
        return hasattr(module, 'encode') and hasattr(module, 'decode')

    # run after backward() and before optimizer.step()
    def after_backward(self):
        """Finish the step on every unit, then wait for outstanding CUDA work."""
        for unit in self.units:
            unit.after_backward()
        # torch.cuda.synchronize() raises on hosts without CUDA, so only call it
        # when CUDA is usable.
        # NOTE(review): presumably this sync makes sure asynchronous gradient
        # copies have landed before optimizer.step() - confirm against the offloaders.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
# Byte granularity used when sizing tensors and when placing them inside a pinned buffer.
ALIGNMENT = 64


def _align_up(x: int, alignment: int = ALIGNMENT) -> int:
    """Round ``x`` up to the nearest multiple of ``alignment``."""
    # Ceiling division: number of whole alignment-sized blocks needed to hold x.
    blocks = (x + alignment - 1) // alignment
    return blocks * alignment
|
|
7
|
+
|
|
8
|
+
def _next_power_of_two(x: int) -> int:
    """Return the smallest power of two that is >= ``x`` (1 for any ``x`` <= 1).

    ``(x - 1).bit_length()`` is the exponent of the result: an exact power
    ``2**k`` maps to ``k`` and is returned unchanged, while any other value
    rounds up to the next power.
    """
    if x <= 1:
        return 1
    exponent = (x - 1).bit_length()
    return 1 << exponent
|
|
15
|
+
|
|
16
|
+
def _prev_power_of_two(x: int) -> int:
    """Return the largest power of two that is <= ``x`` (1 for any ``x`` <= 1)."""
    if x <= 1:
        return 1
    # The highest set bit of x is the answer.
    highest_bit = x.bit_length() - 1
    return 1 << highest_bit
|
|
19
|
+
|
|
20
|
+
def _tensor_storage_size(tensor: torch.Tensor) -> int:
    """Return the byte size of ``tensor``'s elements, rounded up to ``ALIGNMENT``."""
    raw_bytes = tensor.numel() * tensor.element_size()
    return _align_up(raw_bytes)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BaseBufferPool:
    """Naive per-tensor pin_memory allocation. No pre-allocation, no memory saving."""

    def allocate_like(self, tensor: torch.Tensor, *, copy: bool = True, require_contiguous: bool = True, non_blocking: bool = False) -> torch.Tensor:
        """Return a pinned copy of ``tensor``.

        The keyword-only arguments mirror ``PinnedArenaPool.allocate_like`` so
        either pool can be used through the same call. They have no effect
        here: the result always comes from ``tensor.pin_memory()``.
        """
        return tensor.pin_memory()

    @classmethod
    def from_model(cls, model: torch.nn.Module, **kwargs):
        """Build a pool for ``model``; the naive pool needs no sizing, so all arguments are ignored."""
        return cls()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class PinnedBuffer:
|
|
36
|
+
"""Single pinned uint8 buffer with bump-pointer allocation. Lazy: actual memory allocated on first allocate_like."""
|
|
37
|
+
|
|
38
|
+
def __init__(self, size: int):
|
|
39
|
+
self._size = size
|
|
40
|
+
self._buf: torch.Tensor | None = None
|
|
41
|
+
self._offset = 0
|
|
42
|
+
|
|
43
|
+
def _ensure_allocated(self):
|
|
44
|
+
if self._buf is None:
|
|
45
|
+
self._buf = torch.empty(self._size, dtype=torch.uint8, device="cpu", pin_memory=True)
|
|
46
|
+
|
|
47
|
+
@property
|
|
48
|
+
def capacity(self) -> int:
|
|
49
|
+
return self._size
|
|
50
|
+
|
|
51
|
+
@property
|
|
52
|
+
def remaining(self) -> int:
|
|
53
|
+
return self._size if self._buf is None else self._buf.numel() - self._offset
|
|
54
|
+
|
|
55
|
+
@property
|
|
56
|
+
def used(self) -> int:
|
|
57
|
+
return self._offset
|
|
58
|
+
|
|
59
|
+
@classmethod
|
|
60
|
+
def from_tensor(cls, tensor: torch.Tensor, min_size: int = 1 * 1024**3):
|
|
61
|
+
size = max(_tensor_storage_size(tensor) + ALIGNMENT, min_size)
|
|
62
|
+
return cls(_next_power_of_two(size))
|
|
63
|
+
|
|
64
|
+
def allocate_like(self, tensor: torch.Tensor, *, copy: bool = True, non_blocking: bool = False) -> torch.Tensor | None:
|
|
65
|
+
"""Try to allocate a view for tensor. Returns None if not enough space."""
|
|
66
|
+
num_bytes = tensor.numel() * tensor.element_size()
|
|
67
|
+
if num_bytes > self.remaining:
|
|
68
|
+
return None
|
|
69
|
+
self._ensure_allocated()
|
|
70
|
+
view = self._buf.narrow(0, self._offset, num_bytes).view(tensor.dtype).reshape(tuple(tensor.shape))
|
|
71
|
+
if copy:
|
|
72
|
+
view.copy_(tensor, non_blocking=bool(non_blocking and tensor.device.type == "cuda"))
|
|
73
|
+
self._offset = _align_up(self._offset + num_bytes)
|
|
74
|
+
return view
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class PinnedArenaPool(BaseBufferPool):
    """Pinned arena pool - pre-allocate pinned memory, avoid per-tensor cudaHostAlloc overhead.

    Pool strategy:
      1. Sizing: from_model() scans all non-trainable params + buffers, sums their aligned sizes
         as total_bytes. max_chunk_size is raised to fit the largest single tensor.
      2. Decomposition: total_bytes is split into power-of-two chunks (min_chunk_size ~ max_chunk_size).
         Each chunk becomes one PinnedBuffer (lazy - actual pin_memory on first use).
      3. Allocation: allocate_like() sequentially probes each buffer for space (first-fit).
         Each PinnedBuffer uses bump-pointer with ALIGNMENT padding between tensors.
      4. Growth: if all existing buffers are full, _grow() appends a new PinnedBuffer sized
         to the requesting tensor (at least min_chunk_size, power-of-two rounded).
      5. Fallback: on any exception, a warning is emitted and allocation falls back to
         per-tensor pin_memory().
    """

    def __init__(self, total_bytes: int, min_chunk_size: int = 1 * 1024**3, max_chunk_size: int = 4 * 1024**3):
        # Both limits are rounded up to powers of two; min never exceeds max.
        self.min_chunk_size = _next_power_of_two(int(min_chunk_size))
        self.max_chunk_size = _next_power_of_two(int(max_chunk_size))
        self.min_chunk_size = min(self.min_chunk_size, self.max_chunk_size)
        self._buffers = [PinnedBuffer(s) for s in self._decompose(total_bytes, self.min_chunk_size, self.max_chunk_size)]

    @classmethod
    def from_model(cls, model: torch.nn.Module, min_chunk_size: int = 1 * 1024**3, max_chunk_size: int = 4 * 1024**3):
        """Size pool for all non-trainable params + buffers."""
        tensors = [p for p in model.parameters() if not p.requires_grad] + list(model.buffers())
        total = sum(_tensor_storage_size(t) for t in tensors)
        max_tensor_size = max((_tensor_storage_size(t) for t in tensors), default=0)
        # A chunk must be able to hold the largest single tensor.
        max_chunk_size = _next_power_of_two(max_tensor_size) if max_tensor_size > max_chunk_size else max_chunk_size
        return cls(total, min_chunk_size=min_chunk_size, max_chunk_size=max_chunk_size)

    @staticmethod
    def _decompose(total_bytes: int, min_chunk_size: int, max_chunk_size: int) -> list:
        """Decompose total_bytes into power-of-two chunks capped by min/max.

        The returned sizes are sorted largest first. Because every chunk is at
        least ``min_chunk_size``, their sum may exceed ``total_bytes``.
        """
        if total_bytes <= 0:
            return []
        chunks, remaining = [], total_bytes
        while remaining > 0:
            chunk = max(min(_prev_power_of_two(remaining), max_chunk_size), min_chunk_size)
            chunks.append(chunk)
            remaining -= chunk
        chunks.sort(reverse=True)
        return chunks

    def _grow(self, tensor: torch.Tensor):
        """Append a new buffer sized for ``tensor`` (at least ``min_chunk_size``)."""
        self._buffers.append(PinnedBuffer.from_tensor(tensor, min_size=self.min_chunk_size))

    def allocate_like(self, tensor: torch.Tensor, *, copy: bool = True, require_contiguous: bool = True, non_blocking: bool = False) -> torch.Tensor:
        """Allocate a pinned view. Falls back to per-tensor pin_memory on failure.

        Args:
            tensor: Source tensor; it is detached before use.
            copy: Copy the source values into the returned view.
            require_contiguous: Make a non-contiguous source contiguous first.
            non_blocking: Forwarded to ``PinnedBuffer.allocate_like``.

        Returns:
            A view into one of the pool's buffers, or the result of
            ``pin_memory()`` on the source when arena allocation failed.
        """
        src = tensor.detach()
        if require_contiguous and not src.is_contiguous():
            src = src.contiguous()
        try:
            # First-fit over the existing buffers.
            for buf in self._buffers:
                view = buf.allocate_like(src, copy=copy, non_blocking=non_blocking)
                if view is not None:
                    return view
            self._grow(src)
            return self._buffers[-1].allocate_like(src, copy=copy, non_blocking=non_blocking)
        except Exception as exc:
            # Best-effort pool: keep going with a per-tensor allocation, but report
            # the failure instead of hiding it.
            import warnings

            warnings.warn(
                f"PinnedArenaPool allocation failed ({type(exc).__name__}: {exc}); falling back to per-tensor pin_memory().",
                RuntimeWarning,
                stacklevel=2,
            )
            return src.pin_memory()
|