diffsynth 2.0.10__tar.gz → 2.0.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187) hide show
  1. {diffsynth-2.0.10 → diffsynth-2.0.12}/PKG-INFO +9 -1
  2. {diffsynth-2.0.10 → diffsynth-2.0.12}/README.md +90 -4
  3. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/configs/model_configs.py +11 -16
  4. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/configs/vram_management_module_maps.py +9 -1
  5. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/__init__.py +1 -0
  6. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/attention/attention.py +8 -8
  7. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/data/operators.py +10 -6
  8. diffsynth-2.0.12/diffsynth/core/offload_training/__init__.py +1 -0
  9. diffsynth-2.0.12/diffsynth/core/offload_training/manager.py +177 -0
  10. diffsynth-2.0.12/diffsynth/core/offload_training/memory_buffer.py +136 -0
  11. diffsynth-2.0.12/diffsynth/core/offload_training/offloader.py +71 -0
  12. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/__init__.py +1 -1
  13. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/base_pipeline.py +37 -4
  14. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/flow_match.py +62 -0
  15. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/loss.py +6 -1
  16. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/parsers.py +13 -0
  17. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/runner.py +32 -6
  18. diffsynth-2.0.12/diffsynth/diffusion/template.py +203 -0
  19. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/training_module.py +52 -0
  20. diffsynth-2.0.12/diffsynth/models/ace_step_residual_fsq.py +569 -0
  21. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ace_step_tokenizer.py +2 -4
  22. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/dinov3_image_encoder.py +8 -4
  23. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux2_dit.py +82 -137
  24. diffsynth-2.0.12/diffsynth/models/hidream_common.py +373 -0
  25. diffsynth-2.0.12/diffsynth/models/hidream_o1_image_dit.py +1910 -0
  26. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ltx2_audio_vae.py +4 -2
  27. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/siglip2_image_encoder.py +10 -4
  28. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/ace_step.py +15 -14
  29. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/flux2_image.py +49 -0
  30. diffsynth-2.0.12/diffsynth/pipelines/hidream_o1_image.py +420 -0
  31. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/qwen_image.py +1 -1
  32. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/wan_video.py +1 -1
  33. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py +5 -0
  34. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth.egg-info/PKG-INFO +9 -1
  35. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth.egg-info/SOURCES.txt +9 -0
  36. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth.egg-info/requires.txt +9 -0
  37. {diffsynth-2.0.10 → diffsynth-2.0.12}/pyproject.toml +11 -2
  38. {diffsynth-2.0.10 → diffsynth-2.0.12}/LICENSE +0 -0
  39. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/__init__.py +0 -0
  40. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/configs/__init__.py +0 -0
  41. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/attention/__init__.py +0 -0
  42. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/data/__init__.py +0 -0
  43. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/data/unified_dataset.py +0 -0
  44. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/device/__init__.py +0 -0
  45. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/device/npu_compatible_device.py +0 -0
  46. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/gradient/__init__.py +0 -0
  47. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/gradient/gradient_checkpoint.py +0 -0
  48. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/loader/__init__.py +0 -0
  49. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/loader/config.py +0 -0
  50. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/loader/file.py +0 -0
  51. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/loader/model.py +0 -0
  52. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/npu_patch/npu_fused_operator.py +0 -0
  53. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/vram/__init__.py +0 -0
  54. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/vram/disk_map.py +0 -0
  55. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/vram/initialization.py +0 -0
  56. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/core/vram/layers.py +0 -0
  57. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/ddim_scheduler.py +0 -0
  58. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/diffusion/logger.py +0 -0
  59. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ace_step_conditioner.py +0 -0
  60. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ace_step_dit.py +0 -0
  61. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ace_step_text_encoder.py +0 -0
  62. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ace_step_vae.py +0 -0
  63. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/anima_dit.py +0 -0
  64. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ernie_image_dit.py +0 -0
  65. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ernie_image_text_encoder.py +0 -0
  66. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux2_text_encoder.py +0 -0
  67. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux2_vae.py +0 -0
  68. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_controlnet.py +0 -0
  69. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_dit.py +0 -0
  70. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_infiniteyou.py +0 -0
  71. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_ipadapter.py +0 -0
  72. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_lora_encoder.py +0 -0
  73. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_lora_patcher.py +0 -0
  74. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_text_encoder_clip.py +0 -0
  75. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_text_encoder_t5.py +0 -0
  76. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_vae.py +0 -0
  77. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/flux_value_control.py +0 -0
  78. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/general_modules.py +0 -0
  79. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/joyai_image_dit.py +0 -0
  80. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/joyai_image_text_encoder.py +0 -0
  81. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/longcat_video_dit.py +0 -0
  82. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ltx2_common.py +0 -0
  83. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ltx2_dit.py +0 -0
  84. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ltx2_text_encoder.py +0 -0
  85. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ltx2_upsampler.py +0 -0
  86. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/ltx2_video_vae.py +0 -0
  87. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/model_loader.py +0 -0
  88. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/mova_audio_dit.py +0 -0
  89. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/mova_audio_vae.py +0 -0
  90. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/mova_dual_tower_bridge.py +0 -0
  91. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/nexus_gen.py +0 -0
  92. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/nexus_gen_ar_model.py +0 -0
  93. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/nexus_gen_projector.py +0 -0
  94. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/qwen_image_controlnet.py +0 -0
  95. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/qwen_image_dit.py +0 -0
  96. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/qwen_image_image2lora.py +0 -0
  97. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/qwen_image_text_encoder.py +0 -0
  98. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/qwen_image_vae.py +0 -0
  99. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/sd_text_encoder.py +0 -0
  100. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_text_encoder.py +0 -0
  101. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_unet.py +0 -0
  102. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_vae.py +0 -0
  103. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_xl_text_encoder.py +0 -0
  104. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/stable_diffusion_xl_unet.py +0 -0
  105. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/step1x_connector.py +0 -0
  106. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/step1x_text_encoder.py +0 -0
  107. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_animate_adapter.py +0 -0
  108. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_camera_controller.py +0 -0
  109. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_dit.py +0 -0
  110. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_dit_s2v.py +0 -0
  111. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_image_encoder.py +0 -0
  112. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_mot.py +0 -0
  113. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_motion_controller.py +0 -0
  114. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_text_encoder.py +0 -0
  115. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_vace.py +0 -0
  116. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wan_video_vae.py +0 -0
  117. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wantodance.py +0 -0
  118. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/wav2vec.py +0 -0
  119. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/z_image_controlnet.py +0 -0
  120. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/z_image_dit.py +0 -0
  121. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/z_image_image2lora.py +0 -0
  122. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/models/z_image_text_encoder.py +0 -0
  123. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/anima_image.py +0 -0
  124. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/ernie_image.py +0 -0
  125. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/flux_image.py +0 -0
  126. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/joyai_image.py +0 -0
  127. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/ltx2_audio_video.py +0 -0
  128. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/mova_audio_video.py +0 -0
  129. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/stable_diffusion.py +0 -0
  130. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/stable_diffusion_xl.py +0 -0
  131. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/pipelines/z_image.py +0 -0
  132. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/controlnet/__init__.py +0 -0
  133. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/controlnet/annotator.py +0 -0
  134. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/controlnet/controlnet_input.py +0 -0
  135. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/data/__init__.py +0 -0
  136. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/data/audio.py +0 -0
  137. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/data/audio_video.py +0 -0
  138. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/data/media_io_ltx2.py +0 -0
  139. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/lora/__init__.py +0 -0
  140. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/lora/flux.py +0 -0
  141. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/lora/general.py +0 -0
  142. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/lora/merge.py +0 -0
  143. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/lora/reset_rank.py +0 -0
  144. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/ses/__init__.py +0 -0
  145. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/ses/ses.py +0 -0
  146. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/__init__.py +0 -0
  147. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ace_step_conditioner.py +0 -0
  148. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ace_step_dit.py +0 -0
  149. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py +0 -0
  150. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py +0 -0
  151. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/anima_dit.py +0 -0
  152. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/dino_v3.py +0 -0
  153. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py +0 -0
  154. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux2_text_encoder.py +0 -0
  155. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_controlnet.py +0 -0
  156. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_dit.py +0 -0
  157. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_infiniteyou.py +0 -0
  158. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_ipadapter.py +0 -0
  159. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py +0 -0
  160. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py +0 -0
  161. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/flux_vae.py +0 -0
  162. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/joyai_image_text_encoder.py +0 -0
  163. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py +0 -0
  164. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ltx2_dit.py +0 -0
  165. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/ltx2_video_vae.py +0 -0
  166. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/nexus_gen.py +0 -0
  167. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/nexus_gen_projector.py +0 -0
  168. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py +0 -0
  169. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/stable_diffusion_text_encoder.py +0 -0
  170. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/stable_diffusion_vae.py +0 -0
  171. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py +0 -0
  172. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/step1x_connector.py +0 -0
  173. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py +0 -0
  174. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_dit.py +0 -0
  175. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py +0 -0
  176. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_mot.py +0 -0
  177. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_vace.py +0 -0
  178. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wan_video_vae.py +0 -0
  179. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py +0 -0
  180. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/z_image_dit.py +0 -0
  181. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/state_dict_converters/z_image_text_encoder.py +0 -0
  182. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/xfuser/__init__.py +0 -0
  183. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/utils/xfuser/xdit_context_parallel.py +0 -0
  184. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth/version.py +0 -0
  185. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth.egg-info/dependency_links.txt +0 -0
  186. {diffsynth-2.0.10 → diffsynth-2.0.12}/diffsynth.egg-info/top_level.txt +0 -0
  187. {diffsynth-2.0.10 → diffsynth-2.0.12}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth
3
- Version: 2.0.10
3
+ Version: 2.0.12
4
4
  Summary: Enjoy the magic of Diffusion models!
5
5
  Author: ModelScope Team
6
6
  License: Apache-2.0
@@ -33,6 +33,14 @@ Requires-Dist: torch==2.7.1+cpu; extra == "npu"
33
33
  Requires-Dist: torch-npu==2.7.1; extra == "npu"
34
34
  Requires-Dist: torchvision==0.22.1+cpu; extra == "npu"
35
35
  Provides-Extra: audio
36
+ Requires-Dist: av; extra == "audio"
36
37
  Requires-Dist: torchaudio; extra == "audio"
37
38
  Requires-Dist: torchcodec; extra == "audio"
39
+ Requires-Dist: librosa; extra == "audio"
40
+ Provides-Extra: all
41
+ Requires-Dist: av; extra == "all"
42
+ Requires-Dist: torchaudio; extra == "all"
43
+ Requires-Dist: torchcodec; extra == "all"
44
+ Requires-Dist: librosa; extra == "all"
45
+ Requires-Dist: streamlit; extra == "all"
38
46
  Dynamic: license-file
@@ -34,6 +34,19 @@ We believe that a well-developed open-source code framework can lower the thresh
34
34
 
35
35
  > Currently, this project has a small development team, and most of the work is handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). As a result, new features will be developed relatively slowly, and our capacity to respond to and resolve issues is limited. We apologize for this and appreciate developers' understanding.
36
36
 
37
+ - **May 18, 2026** Added **CPU Offload Training** support. By moving model weights between CPU and GPU layer by layer, it significantly reduces VRAM usage during training, enabling LoRA training of large models even on consumer-grade GPUs; it is compatible with all models. To enable it, simply add `--enable_model_cpu_offload` to your training command (currently only single-GPU training is supported). For details, see the [documentation](/docs/en/Training/Offload_Training.md).
38
+
39
+ - **May 14, 2026** HiDream-O1-Image open-sourced, welcome a new member to the image model family! Support includes text-to-image generation, image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/HiDream-O1-Image.md) and [example code](/examples/hidream_o1_image/).
40
+
41
+ - **April 28, 2026** 🔥 We are excited to announce the release of **Diffusion Templates**, a plugin framework designed for Diffusion models that significantly lowers the barrier to training controllable generative models. Let's explore this cutting-edge technology together!
42
+ * Open-source code: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
43
+ * Technical report: [arXiv](https://arxiv.org/abs/2604.24351)
44
+ * Project homepage: [GitHub](https://modelscope.github.io/diffusion-templates-web/)
45
+ * Documentation: [English Version](https://diffsynth-studio-doc.readthedocs.io/en/latest/Diffusion_Templates/Introducing_Diffusion_Templates.html) | [Chinese Version](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/Diffusion_Templates/Introducing_Diffusion_Templates.html)
46
+ * Online demo: [ModelScope](https://modelscope.cn/studios/DiffSynth-Studio/Diffusion-Templates)
47
+ * Model collections: [ModelScope](https://modelscope.cn/collections/DiffSynth-Studio/KleinBase4B-Templates) | [ModelScope International](https://modelscope.ai/collections/DiffSynth-Studio/KleinBase4B-Templates) | [HuggingFace](https://huggingface.co/collections/DiffSynth-Studio/kleinbase4b-templates)
48
+ * Datasets: [ModelScope](https://modelscope.cn/collections/DiffSynth-Studio/ImagePulseV2) | [ModelScope International](https://modelscope.ai/collections/DiffSynth-Studio/ImagePulseV2) | [HuggingFace](https://huggingface.co/collections/DiffSynth-Studio/imagepulsev2)
49
+
37
50
  - **April 27, 2026** We support ACE-Step-1.5! Support includes text-to-music generation, low VRAM inference, and LoRA training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/ACE-Step.md) and [example code](/examples/ace_step/).
38
51
 
39
52
  - **April 27, 2026**: We have reinstated support for the Stable Diffusion v1.5 and SDXL models, providing academic research support exclusively for these two model types.
@@ -96,7 +109,7 @@ We believe that a well-developed open-source code framework can lower the thresh
96
109
 
97
110
  - **August 20, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) model, improving the editing effect of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
98
111
 
99
- - **August 19, 2025** 🔥 Qwen-Image-Edit open-sourced, welcome a new member to the image editing model family!
112
+ - **August 19, 2025** Qwen-Image-Edit open-sourced, welcome a new member to the image editing model family!
100
113
 
101
114
  - **August 18, 2025** We trained and open-sourced the Qwen-Image inpainting ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
102
115
 
@@ -112,7 +125,7 @@ We believe that a well-developed open-source code framework can lower the thresh
112
125
 
113
126
  - **August 5, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) for Qwen-Image, achieving approximately 5x acceleration.
114
127
 
115
- - **August 4, 2025** 🔥 Qwen-Image open-sourced, welcome a new member to the image generation model family!
128
+ - **August 4, 2025** Qwen-Image open-sourced, welcome a new member to the image generation model family!
116
129
 
117
130
  - **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) open-sourced, a text-to-image model focused on aesthetic photography. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, LoRA training, and full training. For more details, please refer to [./examples/flux/](./examples/flux/).
118
131
 
@@ -479,6 +492,17 @@ Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
479
492
  |[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
480
493
  |[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
481
494
  |[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
495
+ |[DiffSynth-Studio/Template-KleinBase4B-Aesthetic](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Aesthetic)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Aesthetic.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Aesthetic.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Aesthetic.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Aesthetic.py)|-|-|
496
+ |[DiffSynth-Studio/Template-KleinBase4B-Brightness](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Brightness)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Brightness.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Brightness.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Brightness.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Brightness.py)|-|-|
497
+ |[DiffSynth-Studio/Template-KleinBase4B-Age](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Age)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Age.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Age.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Age.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Age.py)|-|-|
498
+ |[DiffSynth-Studio/Template-KleinBase4B-ControlNet](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-ControlNet)|[code](/examples/flux2/model_inference/Template-KleinBase4B-ControlNet.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-ControlNet.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-ControlNet.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-ControlNet.py)|-|-|
499
+ |[DiffSynth-Studio/Template-KleinBase4B-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Edit)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Edit.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Edit.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Edit.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Edit.py)|-|-|
500
+ |[DiffSynth-Studio/Template-KleinBase4B-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Inpaint)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Inpaint.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Inpaint.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Inpaint.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Inpaint.py)|-|-|
501
+ |[DiffSynth-Studio/Template-KleinBase4B-PandaMeme](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-PandaMeme)|[code](/examples/flux2/model_inference/Template-KleinBase4B-PandaMeme.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-PandaMeme.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-PandaMeme.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-PandaMeme.py)|-|-|
502
+ |[DiffSynth-Studio/Template-KleinBase4B-Sharpness](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Sharpness)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Sharpness.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Sharpness.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Sharpness.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Sharpness.py)|-|-|
503
+ |[DiffSynth-Studio/Template-KleinBase4B-SoftRGB](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-SoftRGB)|[code](/examples/flux2/model_inference/Template-KleinBase4B-SoftRGB.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-SoftRGB.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-SoftRGB.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-SoftRGB.py)|-|-|
504
+ |[DiffSynth-Studio/Template-KleinBase4B-Upscaler](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Upscaler)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Upscaler.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Upscaler.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Upscaler.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Upscaler.py)|-|-|
505
+ |[DiffSynth-Studio/Template-KleinBase4B-ContentRef](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-ContentRef)|[code](/examples/flux2/model_inference/Template-KleinBase4B-ContentRef.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-ContentRef.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-ContentRef.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-ContentRef.py)|-|-|
482
506
 
483
507
  </details>
484
508
 
@@ -864,6 +888,68 @@ Example code for JoyAI-Image is available at: [/examples/joyai_image/](/examples
864
888
 
865
889
  </details>
866
890
 
891
+ #### HiDream-O1-Image: [/docs/en/Model_Details/HiDream-O1-Image.md](/docs/en/Model_Details/HiDream-O1-Image.md)
892
+
893
+ <details>
894
+
895
+ <summary>Quick Start</summary>
896
+
897
+ Running the following code will quickly load the [HiDream-ai/HiDream-O1-Image](https://modelscope.cn/models/HiDream-ai/HiDream-O1-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with as little as 3 GB of VRAM.
898
+
899
+ ```python
900
+ from diffsynth.pipelines.hidream_o1_image import HiDreamO1ImagePipeline
901
+ from diffsynth.core.loader.config import ModelConfig
902
+ import torch
903
+
904
+
905
+ vram_config = {
906
+ "offload_dtype": torch.bfloat16,
907
+ "offload_device": "cpu",
908
+ "onload_dtype": torch.bfloat16,
909
+ "onload_device": "cpu",
910
+ "preparing_dtype": torch.bfloat16,
911
+ "preparing_device": "cuda",
912
+ "computation_dtype": torch.bfloat16,
913
+ "computation_device": "cuda",
914
+ }
915
+
916
+
917
+ pipe = HiDreamO1ImagePipeline.from_pretrained(
918
+ torch_dtype=torch.bfloat16,
919
+ device="cuda",
920
+ model_configs=[
921
+ ModelConfig(model_id="HiDream-ai/HiDream-O1-Image", origin_file_pattern="model-*.safetensors", **vram_config),
922
+ ],
923
+ processor_config=ModelConfig(model_id="HiDream-ai/HiDream-O1-Image", origin_file_pattern="./"),
924
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
925
+ )
926
+ image = pipe(
927
+ prompt="medium shot, eye-level, front view. A woman is seated in an ornate bedroom, illuminated by candlelight, with a calm and composed expression. The subject is a young woman with fair skin, light brown hair styled in an updo with loose tendrils framing her face, and blue eyes. She wears a cream-colored satin robe with delicate floral embroidery and lace trim along the neckline. Her ears are adorned with pearl drop earrings. She is seated on a bed with a dark, intricately carved wooden headboard. To her left, a wooden nightstand holds three lit white candles and a candelabra with multiple lit candles in the background. The bed is covered with patterned pillows and a dark, textured blanket. The walls are paneled with dark wood and feature a large, ornate tapestry with muted earth tones. The lighting creates soft highlights on her face and robe, with warm shadows cast across the room.",
928
+ negative_prompt=" ",
929
+ cfg_scale=4.0,
930
+ height=2048,
931
+ width=2048,
932
+ seed=42,
933
+ num_inference_steps=50,
934
+ )
935
+ image.save("image.jpg")
936
+ ```
937
+
938
+ </details>
939
+
940
+ <details>
941
+
942
+ <summary>Examples</summary>
943
+
944
+ Example code for HiDream-O1-Image is available at: [/examples/hidream_o1_image/](/examples/hidream_o1_image/)
945
+
946
+ | Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
947
+ |-|-|-|-|-|-|-|
948
+ |[HiDream-ai/HiDream-O1-Image](https://modelscope.cn/models/HiDream-ai/HiDream-O1-Image)|[code](/examples/hidream_o1_image/model_inference/HiDream-O1-Image.py)|[code](/examples/hidream_o1_image/model_inference_low_vram/HiDream-O1-Image.py)|[code](/examples/hidream_o1_image/model_training/full/HiDream-O1-Image.sh)|[code](/examples/hidream_o1_image/model_training/validate_full/HiDream-O1-Image.py)|[code](/examples/hidream_o1_image/model_training/lora/HiDream-O1-Image.sh)|[code](/examples/hidream_o1_image/model_training/validate_lora/HiDream-O1-Image.py)|
949
+ |[HiDream-ai/HiDream-O1-Image-Dev](https://modelscope.cn/models/HiDream-ai/HiDream-O1-Image-Dev)|[code](/examples/hidream_o1_image/model_inference/HiDream-O1-Image-Dev.py)|[code](/examples/hidream_o1_image/model_inference_low_vram/HiDream-O1-Image-Dev.py)|[code](/examples/hidream_o1_image/model_training/full/HiDream-O1-Image-Dev.sh)|[code](/examples/hidream_o1_image/model_training/validate_full/HiDream-O1-Image-Dev.py)|[code](/examples/hidream_o1_image/model_training/lora/HiDream-O1-Image-Dev.sh)|[code](/examples/hidream_o1_image/model_training/validate_lora/HiDream-O1-Image-Dev.py)|
950
+
951
+ </details>
952
+
867
953
  ### Video Synthesis
868
954
 
869
955
  https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
@@ -1138,8 +1224,8 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
1138
1224
  |[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
1139
1225
  |[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|
1140
1226
  |[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|
1141
- |[Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py)|
1142
- |[Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py)|
1227
+ |[Wan-AI/Wan2.2-Dancer-14B (global model)](https://modelscope.cn/models/Wan-AI/Wan2.2-Dancer-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Dancer-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Dancer-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Dancer-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Dancer-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Dancer-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Dancer-14B-global.py)|
1228
+ |[Wan-AI/Wan2.2-Dancer-14B (local model)](https://modelscope.cn/models/Wan-AI/Wan2.2-Dancer-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Dancer-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Dancer-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Dancer-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Dancer-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Dancer-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Dancer-14B-local.py)|
1143
1229
 
1144
1230
  </details>
1145
1231
 
@@ -309,7 +309,7 @@ wan_series = [
309
309
  "state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
310
310
  },
311
311
  {
312
- # Example: ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors")
312
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Dancer-14B", origin_file_pattern="global_model.safetensors")
313
313
  "model_hash": "eb18873fc0ba77b541eb7b62dbcd2059",
314
314
  "model_name": "wan_video_dit",
315
315
  "model_class": "diffsynth.models.wan_video_dit.WanModel",
@@ -833,20 +833,6 @@ ltx2_series = [
833
833
  "extra_kwargs": {"decoder_version": "ltx-2.3"},
834
834
  "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
835
835
  },
836
- {
837
- # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors")
838
- "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
839
- "model_name": "ltx2_audio_vae_decoder",
840
- "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
841
- "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
842
- },
843
- {
844
- # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
845
- "model_hash": "29338f3b95e7e312a3460a482e4f4554",
846
- "model_name": "ltx2_audio_vae_encoder",
847
- "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
848
- "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
849
- },
850
836
  {
851
837
  # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors")
852
838
  "model_hash": "cd436c99e69ec5c80f050f0944f02a15",
@@ -1040,7 +1026,16 @@ ace_step_series = [
1040
1026
  },
1041
1027
  ]
1042
1028
 
1029
+ hidream_o1_image_series = [
1030
+ {
1031
+ # Example: ModelConfig(model_id="HiDream-ai/HiDream-O1-Image", origin_file_pattern="model-*.safetensors")
1032
+ "model_hash": "58a7c1073d79556bfc61e05e6061b771",
1033
+ "model_name": "hidream_o1_image_dit",
1034
+ "model_class": "diffsynth.models.hidream_o1_image_dit.HiDreamO1ImageModel",
1035
+ },
1036
+ ]
1037
+
1043
1038
  MODEL_CONFIGS = (
1044
1039
  stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
1045
- + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
1040
+ + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series + hidream_o1_image_series
1046
1041
  )
@@ -327,7 +327,7 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
327
327
  "diffsynth.models.ace_step_tokenizer.AceStepTokenizer": {
328
328
  "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
329
329
  "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
330
- "vector_quantize_pytorch.ResidualFSQ": "diffsynth.core.vram.layers.AutoWrappedModule",
330
+ "diffsynth.models.ace_step_residual_fsq.ResidualFSQ": "diffsynth.core.vram.layers.AutoWrappedModule",
331
331
  "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
332
332
  "transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
333
333
  "transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
@@ -372,6 +372,14 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
372
372
  "diffsynth.models.stable_diffusion_text_encoder.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
373
373
  "diffsynth.models.stable_diffusion_xl_text_encoder.CLIPTextModelWithProjection": "diffsynth.core.vram.layers.AutoWrappedModule",
374
374
  },
375
+ "diffsynth.models.hidream_o1_image_dit.HiDreamO1ImageModel": {
376
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
377
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
378
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
379
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
380
+ "diffsynth.models.hidream_o1_image_dit.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
381
+ "diffsynth.models.hidream_o1_image_dit.Qwen3VLVisionModel": "diffsynth.core.vram.layers.AutoWrappedModule",
382
+ },
375
383
  }
376
384
 
377
385
  def QwenImageTextEncoder_Module_Map_Updater():
@@ -4,3 +4,4 @@ from .gradient import *
4
4
  from .loader import *
5
5
  from .vram import *
6
6
  from .device import *
7
+ from .offload_training import *
@@ -63,10 +63,10 @@ def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern
63
63
  return out
64
64
 
65
65
 
66
- def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
66
+ def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, is_causal=False):
67
67
  required_in_pattern, required_out_pattern= "b n s d", "b n s d"
68
68
  q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
69
- out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
69
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale, is_causal=is_causal)
70
70
  out = rearrange_out(out, out_pattern, required_out_pattern, dims)
71
71
  return out
72
72
 
@@ -81,10 +81,10 @@ def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_patte
81
81
  return out
82
82
 
83
83
 
84
- def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
84
+ def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None, is_causal=False):
85
85
  required_in_pattern, required_out_pattern= "b s n d", "b s n d"
86
86
  q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
87
- out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
87
+ out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale, causal=is_causal)
88
88
  out = rearrange_out(out, out_pattern, required_out_pattern, dims)
89
89
  return out
90
90
 
@@ -105,17 +105,17 @@ def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_patt
105
105
  return out
106
106
 
107
107
 
108
- def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
108
+ def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, is_causal=False, compatibility_mode=False):
109
109
  if compatibility_mode or (attn_mask is not None):
110
- return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
110
+ return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale, is_causal=is_causal)
111
111
  else:
112
112
  if ATTENTION_IMPLEMENTATION == "flash_attention_3":
113
113
  return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
114
114
  elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
115
- return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
115
+ return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale, is_causal=is_causal)
116
116
  elif ATTENTION_IMPLEMENTATION == "sage_attention":
117
117
  return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
118
118
  elif ATTENTION_IMPLEMENTATION == "xformers":
119
119
  return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
120
120
  else:
121
- return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
121
+ return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale, is_causal=is_causal)
@@ -2,8 +2,6 @@ import math, warnings
2
2
  import torch, torchvision, imageio, os
3
3
  import imageio.v3 as iio
4
4
  from PIL import Image
5
- import torchaudio
6
- from diffsynth.utils.data.audio import read_audio
7
5
 
8
6
 
9
7
  class DataProcessingPipeline:
@@ -249,9 +247,11 @@ class ToAbsolutePath(DataProcessingOperator):
249
247
  class LoadAudio(DataProcessingOperator):
250
248
  def __init__(self, sr=16000):
251
249
  self.sr = sr
252
- def __call__(self, data: str):
253
250
  import librosa
254
- input_audio, sample_rate = librosa.load(data, sr=self.sr)
251
+ self.audio_loader = librosa.load
252
+
253
+ def __call__(self, data: str):
254
+ input_audio, sample_rate = self.audio_loader(data, sr=self.sr)
255
255
  return input_audio
256
256
 
257
257
 
@@ -259,13 +259,15 @@ class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):
259
259
 
260
260
  def __init__(self, num_frames=121, time_division_factor=8, time_division_remainder=1, frame_rate=24, fix_frame_rate=True):
261
261
  FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
262
+ import torchaudio
263
+ self.audio_loader = torchaudio.load
262
264
 
263
265
  def __call__(self, data: str):
264
266
  try:
265
267
  reader = self.get_reader(data)
266
268
  num_frames = self.get_num_frames(reader)
267
269
  duration = num_frames / self.frame_rate
268
- waveform, sample_rate = torchaudio.load(data)
270
+ waveform, sample_rate = self.audio_loader(data)
269
271
  target_samples = int(duration * sample_rate)
270
272
  current_samples = waveform.shape[-1]
271
273
  if current_samples > target_samples:
@@ -285,10 +287,12 @@ class LoadPureAudioWithTorchaudio(DataProcessingOperator):
285
287
  self.target_sample_rate = target_sample_rate
286
288
  self.target_duration = target_duration
287
289
  self.resample = True if target_sample_rate is not None else False
290
+ from diffsynth.utils.data.audio import read_audio
291
+ self.audio_loader = read_audio
288
292
 
289
293
  def __call__(self, data: str):
290
294
  try:
291
- waveform, sample_rate = read_audio(data, resample=self.resample, resample_rate=self.target_sample_rate)
295
+ waveform, sample_rate = self.audio_loader(data, resample=self.resample, resample_rate=self.target_sample_rate)
292
296
  if self.target_duration is not None:
293
297
  target_samples = int(self.target_duration * sample_rate)
294
298
  current_samples = waveform.shape[-1]
@@ -0,0 +1 @@
1
+ from .manager import OffloadTrainingManager
@@ -0,0 +1,177 @@
1
+ """
2
+ Layer offloading for training — hook-based CPU offload.
3
+
4
+ Hook lifecycle per module:
5
+
6
+ No checkpointing:
7
+ forward_pre(load→GPU) → forward() → forward_hook(offload)
8
+ backward_pre(load→GPU) → backward() → backward_hook(offload)
9
+
10
+ With checkpointing (use_reentrant=False):
11
+ First forward:
12
+ forward_pre(load→GPU) → forward() → forward_hook(offload, mark in_recompute)
13
+ Recomputing forward (during backward):
14
+ forward_pre(load→GPU) → forward() → forward_hook(in_recompute=True → keep GPU)
15
+ Backward:
16
+ backward_pre(load→GPU) → backward() → backward_hook(offload)
17
+ """
18
+ import torch
19
+ import torch.nn as nn
20
+ import warnings
21
+ from .offloader import StaticParamOffloader, TrainableParamOffloader, AlwaysOnGPUParamOffloader, BufferOffloader
22
+ from .memory_buffer import PinnedArenaPool, BaseBufferPool
23
+ warnings.filterwarnings("ignore", message="Full backward hook is firing when gradients are computed with respect to module outputs")
24
+
25
+
26
+ def has_parameters(module: nn.Module) -> bool:
27
+ return len(list(module.parameters())) > 0
28
+
29
+ def count_parameters(module: nn.Module) -> int:
30
+ return sum(p.numel() for p in module.parameters())
31
+
32
+ def is_leaf_module(module: nn.Module) -> bool:
33
+ return len(list(module.children())) == 0
34
+
35
+
36
+ class UnitWiseParamManager:
37
+ def __init__(self, model: nn.Module, target_device: torch.device, enable_optimizer_cpu_offload: bool = False, params: list = None, buffers: list = None, memory_buffer: BaseBufferPool = None):
38
+ self.model = model
39
+ self.target_device = target_device
40
+ self.param_offloaders = {}
41
+ for param in (model.parameters() if params is None else params):
42
+ if not param.requires_grad:
43
+ self.param_offloaders[id(param)] = StaticParamOffloader(param, target_device, memory_buffer=memory_buffer)
44
+ else:
45
+ if enable_optimizer_cpu_offload:
46
+ self.param_offloaders[id(param)] = TrainableParamOffloader(param, target_device)
47
+ else:
48
+ self.param_offloaders[id(param)] = AlwaysOnGPUParamOffloader(param, target_device)
49
+ if buffers is not None and len(buffers) > 0:
50
+ for mod, buf_name, buf in buffers:
51
+ self.param_offloaders[id(buf)] = BufferOffloader(mod, buf_name, buf, target_device, memory_buffer=memory_buffer)
52
+
53
+ def move_gradients_to_cpu(self):
54
+ for offloader in self.param_offloaders.values():
55
+ offloader.offload_grad()
56
+
57
+ def onload_module(self, module: nn.Module):
58
+ for param in module.parameters(recurse=False):
59
+ if id(param) in self.param_offloaders:
60
+ self.param_offloaders[id(param)].onload()
61
+ for name, buf in module.named_buffers(recurse=False):
62
+ if id(buf) in self.param_offloaders:
63
+ self.param_offloaders[id(buf)].onload()
64
+
65
+ def offload_module(self, module: nn.Module):
66
+ for param in module.parameters(recurse=False):
67
+ if id(param) in self.param_offloaders:
68
+ self.param_offloaders[id(param)].offload()
69
+ for name, buf in module.named_buffers(recurse=False):
70
+ if id(buf) in self.param_offloaders:
71
+ self.param_offloaders[id(buf)].offload()
72
+
73
+
74
+ class UnitWiseHookManager:
75
+ def __init__(self, model: nn.Module, target_device: torch.device, enable_optimizer_cpu_offload: bool = False,
76
+ params: list = None, buffers: list = None, memory_buffer: BaseBufferPool = None):
77
+ self.param_manager = UnitWiseParamManager(model, target_device, enable_optimizer_cpu_offload, params=params, buffers=buffers, memory_buffer=memory_buffer)
78
+ self._in_recompute: set = set()
79
+ self._register_hooks(model)
80
+
81
+ def _register_hooks(self, module: nn.Module):
82
+ def forward_pre_hook(mod, args):
83
+ self.param_manager.onload_module(mod)
84
+
85
+ def forward_hook(mod, args, output):
86
+ if mod in self._in_recompute:
87
+ return
88
+ self._in_recompute.add(mod)
89
+ self.param_manager.offload_module(mod)
90
+
91
+ def backward_pre_hook(mod, grad_output):
92
+ self.param_manager.onload_module(mod)
93
+
94
+ def backward_hook(mod, grad_input, grad_output):
95
+ self.param_manager.offload_module(mod)
96
+
97
+ module.register_forward_pre_hook(forward_pre_hook)
98
+ module.register_forward_hook(forward_hook)
99
+ module.register_full_backward_pre_hook(backward_pre_hook)
100
+ if is_leaf_module(module):
101
+ module.register_full_backward_hook(backward_hook)
102
+ else:
103
+ # Parent module backward_hook fires before child backward completes.
104
+ # Register on leaf children instead.
105
+ sub_modules = [m for m in module.modules() if is_leaf_module(m) and has_parameters(m)]
106
+ for sub_mod in sub_modules:
107
+ sub_mod.register_full_backward_hook(backward_hook)
108
+
109
+ def after_backward(self):
110
+ self._in_recompute.clear()
111
+ self.param_manager.move_gradients_to_cpu()
112
+
113
+ @property
114
+ def managed_param_ids(self):
115
+ return set(self.param_manager.param_offloaders.keys())
116
+
117
+
118
+ class OffloadTrainingManager:
119
+ def __init__(self, model: nn.Module, target_device: torch.device, enable_optimizer_cpu_offload: bool = False, cpu_offload_split_threshold: int = None):
120
+ self.model = model
121
+ self.target_device = target_device
122
+ self.enable_optimizer_cpu_offload = enable_optimizer_cpu_offload
123
+ cpu_offload_split_threshold = cpu_offload_split_threshold * 1024 * 1024 if cpu_offload_split_threshold is not None else None
124
+ self._register_units(model, target_device, enable_optimizer_cpu_offload, cpu_offload_split_threshold)
125
+
126
+ def _register_units(self, model: nn.Module, target_device: torch.device, enable_optimizer_cpu_offload: bool, cpu_offload_split_threshold: int = None):
127
+ self.memory_buffer = PinnedArenaPool.from_model(model)
128
+ units = self._find_units_recursive(model, cpu_offload_split_threshold)
129
+ self.units = [UnitWiseHookManager(u, target_device, enable_optimizer_cpu_offload, memory_buffer=self.memory_buffer) for u in units]
130
+
131
+ managed_param_ids = set().union(*[unit.managed_param_ids for unit in self.units])
132
+ orphan_params, orphan_buffers = self._find_orphan_params_and_buffers(model, managed_param_ids)
133
+ for orphan_module in set(orphan_params.keys()) | set(orphan_buffers.keys()):
134
+ params = orphan_params.get(orphan_module, [])
135
+ buffers = orphan_buffers.get(orphan_module, [])
136
+ self.units.append(UnitWiseHookManager(orphan_module, target_device, enable_optimizer_cpu_offload, params=params, buffers=buffers, memory_buffer=self.memory_buffer))
137
+
138
+ def _find_orphan_params_and_buffers(self, model: nn.Module, managed_param_ids: set):
139
+ orphan_params_by_module = {}
140
+ for _, mod in model.named_modules():
141
+ for param in mod.parameters(recurse=False):
142
+ if id(param) not in managed_param_ids:
143
+ orphan_params_by_module.setdefault(mod, []).append(param)
144
+ # Collect orphan buffers grouped by owner module
145
+ orphan_buffers_by_module = {}
146
+ for _, mod in model.named_modules():
147
+ for name, buf in mod.named_buffers(recurse=False):
148
+ orphan_buffers_by_module.setdefault(mod, []).append((mod, name, buf))
149
+
150
+ return orphan_params_by_module, orphan_buffers_by_module
151
+
152
+ def _find_units_recursive(self, module: nn.Module, cpu_offload_split_threshold: int = None) -> list:
153
+ if cpu_offload_split_threshold is None:
154
+ return [m for m in module.modules() if is_leaf_module(m) and has_parameters(m)]
155
+ if self._should_force_recurse(module, cpu_offload_split_threshold):
156
+ units = []
157
+ for child in module.children():
158
+ units.extend(self._find_units_recursive(child, cpu_offload_split_threshold))
159
+ return units
160
+ return [module]
161
+
162
+ def _should_force_recurse(self, module: nn.Module, cpu_offload_split_threshold: int = None) -> bool:
163
+ if is_leaf_module(module):
164
+ return False
165
+ if (
166
+ count_parameters(module) > cpu_offload_split_threshold
167
+ or ('forward' not in type(module).__dict__)
168
+ or (hasattr(module, 'encode') and hasattr(module, 'decode'))
169
+ ):
170
+ return True
171
+ return False
172
+
173
+ # run after backward() and before optimizer.step()
174
+ def after_backward(self):
175
+ for unit in self.units:
176
+ unit.after_backward()
177
+ torch.cuda.synchronize()