diffsynth 2.0.10.tar.gz → 2.0.11.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. {diffsynth-2.0.10 → diffsynth-2.0.11}/PKG-INFO +1 -1
  2. {diffsynth-2.0.10 → diffsynth-2.0.11}/README.md +22 -2
  3. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/base_pipeline.py +37 -4
  4. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/loss.py +5 -0
  5. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/parsers.py +6 -0
  6. diffsynth-2.0.11/diffsynth/diffusion/template.py +203 -0
  7. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/training_module.py +52 -0
  8. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/dinov3_image_encoder.py +8 -4
  9. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux2_dit.py +82 -137
  10. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/siglip2_image_encoder.py +10 -4
  11. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/flux2_image.py +49 -0
  12. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/wan_video.py +1 -1
  13. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth.egg-info/PKG-INFO +1 -1
  14. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth.egg-info/SOURCES.txt +1 -0
  15. {diffsynth-2.0.10 → diffsynth-2.0.11}/pyproject.toml +1 -1
  16. {diffsynth-2.0.10 → diffsynth-2.0.11}/LICENSE +0 -0
  17. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/__init__.py +0 -0
  18. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/configs/__init__.py +0 -0
  19. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/configs/model_configs.py +0 -0
  20. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/configs/vram_management_module_maps.py +0 -0
  21. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/__init__.py +0 -0
  22. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/attention/__init__.py +0 -0
  23. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/attention/attention.py +0 -0
  24. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/data/__init__.py +0 -0
  25. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/data/operators.py +0 -0
  26. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/data/unified_dataset.py +0 -0
  27. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/device/__init__.py +0 -0
  28. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/device/npu_compatible_device.py +0 -0
  29. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/gradient/__init__.py +0 -0
  30. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/gradient/gradient_checkpoint.py +0 -0
  31. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/loader/__init__.py +0 -0
  32. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/loader/config.py +0 -0
  33. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/loader/file.py +0 -0
  34. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/loader/model.py +0 -0
  35. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/npu_patch/npu_fused_operator.py +0 -0
  36. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/vram/__init__.py +0 -0
  37. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/vram/disk_map.py +0 -0
  38. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/vram/initialization.py +0 -0
  39. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/core/vram/layers.py +0 -0
  40. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/__init__.py +0 -0
  41. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/ddim_scheduler.py +0 -0
  42. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/flow_match.py +0 -0
  43. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/logger.py +0 -0
  44. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/diffusion/runner.py +0 -0
  45. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ace_step_conditioner.py +0 -0
  46. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ace_step_dit.py +0 -0
  47. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ace_step_text_encoder.py +0 -0
  48. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ace_step_tokenizer.py +0 -0
  49. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ace_step_vae.py +0 -0
  50. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/anima_dit.py +0 -0
  51. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ernie_image_dit.py +0 -0
  52. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ernie_image_text_encoder.py +0 -0
  53. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux2_text_encoder.py +0 -0
  54. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux2_vae.py +0 -0
  55. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_controlnet.py +0 -0
  56. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_dit.py +0 -0
  57. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_infiniteyou.py +0 -0
  58. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_ipadapter.py +0 -0
  59. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_lora_encoder.py +0 -0
  60. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_lora_patcher.py +0 -0
  61. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_text_encoder_clip.py +0 -0
  62. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_text_encoder_t5.py +0 -0
  63. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_vae.py +0 -0
  64. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/flux_value_control.py +0 -0
  65. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/general_modules.py +0 -0
  66. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/joyai_image_dit.py +0 -0
  67. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/joyai_image_text_encoder.py +0 -0
  68. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/longcat_video_dit.py +0 -0
  69. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ltx2_audio_vae.py +0 -0
  70. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ltx2_common.py +0 -0
  71. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ltx2_dit.py +0 -0
  72. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ltx2_text_encoder.py +0 -0
  73. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ltx2_upsampler.py +0 -0
  74. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/ltx2_video_vae.py +0 -0
  75. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/model_loader.py +0 -0
  76. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/mova_audio_dit.py +0 -0
  77. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/mova_audio_vae.py +0 -0
  78. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/mova_dual_tower_bridge.py +0 -0
  79. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/nexus_gen.py +0 -0
  80. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/nexus_gen_ar_model.py +0 -0
  81. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/nexus_gen_projector.py +0 -0
  82. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/qwen_image_controlnet.py +0 -0
  83. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/qwen_image_dit.py +0 -0
  84. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/qwen_image_image2lora.py +0 -0
  85. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/qwen_image_text_encoder.py +0 -0
  86. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/qwen_image_vae.py +0 -0
  87. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/sd_text_encoder.py +0 -0
  88. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/stable_diffusion_text_encoder.py +0 -0
  89. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/stable_diffusion_unet.py +0 -0
  90. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/stable_diffusion_vae.py +0 -0
  91. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/stable_diffusion_xl_text_encoder.py +0 -0
  92. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/stable_diffusion_xl_unet.py +0 -0
  93. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/step1x_connector.py +0 -0
  94. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/step1x_text_encoder.py +0 -0
  95. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_animate_adapter.py +0 -0
  96. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_camera_controller.py +0 -0
  97. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_dit.py +0 -0
  98. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_dit_s2v.py +0 -0
  99. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_image_encoder.py +0 -0
  100. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_mot.py +0 -0
  101. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_motion_controller.py +0 -0
  102. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_text_encoder.py +0 -0
  103. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_vace.py +0 -0
  104. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wan_video_vae.py +0 -0
  105. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wantodance.py +0 -0
  106. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/wav2vec.py +0 -0
  107. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/z_image_controlnet.py +0 -0
  108. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/z_image_dit.py +0 -0
  109. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/z_image_image2lora.py +0 -0
  110. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/models/z_image_text_encoder.py +0 -0
  111. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/ace_step.py +0 -0
  112. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/anima_image.py +0 -0
  113. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/ernie_image.py +0 -0
  114. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/flux_image.py +0 -0
  115. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/joyai_image.py +0 -0
  116. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/ltx2_audio_video.py +0 -0
  117. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/mova_audio_video.py +0 -0
  118. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/qwen_image.py +0 -0
  119. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/stable_diffusion.py +0 -0
  120. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/stable_diffusion_xl.py +0 -0
  121. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/pipelines/z_image.py +0 -0
  122. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/controlnet/__init__.py +0 -0
  123. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/controlnet/annotator.py +0 -0
  124. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/controlnet/controlnet_input.py +0 -0
  125. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/data/__init__.py +0 -0
  126. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/data/audio.py +0 -0
  127. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/data/audio_video.py +0 -0
  128. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/data/media_io_ltx2.py +0 -0
  129. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/lora/__init__.py +0 -0
  130. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/lora/flux.py +0 -0
  131. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/lora/general.py +0 -0
  132. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/lora/merge.py +0 -0
  133. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/lora/reset_rank.py +0 -0
  134. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/ses/__init__.py +0 -0
  135. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/ses/ses.py +0 -0
  136. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/__init__.py +0 -0
  137. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ace_step_conditioner.py +0 -0
  138. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ace_step_dit.py +0 -0
  139. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py +0 -0
  140. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py +0 -0
  141. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/anima_dit.py +0 -0
  142. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/dino_v3.py +0 -0
  143. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py +0 -0
  144. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux2_text_encoder.py +0 -0
  145. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_controlnet.py +0 -0
  146. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_dit.py +0 -0
  147. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_infiniteyou.py +0 -0
  148. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_ipadapter.py +0 -0
  149. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py +0 -0
  150. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py +0 -0
  151. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/flux_vae.py +0 -0
  152. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/joyai_image_text_encoder.py +0 -0
  153. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py +0 -0
  154. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ltx2_dit.py +0 -0
  155. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py +0 -0
  156. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/ltx2_video_vae.py +0 -0
  157. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/nexus_gen.py +0 -0
  158. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/nexus_gen_projector.py +0 -0
  159. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py +0 -0
  160. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/stable_diffusion_text_encoder.py +0 -0
  161. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/stable_diffusion_vae.py +0 -0
  162. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py +0 -0
  163. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/step1x_connector.py +0 -0
  164. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py +0 -0
  165. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wan_video_dit.py +0 -0
  166. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py +0 -0
  167. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wan_video_mot.py +0 -0
  168. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wan_video_vace.py +0 -0
  169. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wan_video_vae.py +0 -0
  170. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py +0 -0
  171. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/z_image_dit.py +0 -0
  172. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/state_dict_converters/z_image_text_encoder.py +0 -0
  173. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/xfuser/__init__.py +0 -0
  174. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/utils/xfuser/xdit_context_parallel.py +0 -0
  175. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth/version.py +0 -0
  176. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth.egg-info/dependency_links.txt +0 -0
  177. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth.egg-info/requires.txt +0 -0
  178. {diffsynth-2.0.10 → diffsynth-2.0.11}/diffsynth.egg-info/top_level.txt +0 -0
  179. {diffsynth-2.0.10 → diffsynth-2.0.11}/setup.cfg +0 -0
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth
-Version: 2.0.10
+Version: 2.0.11
 Summary: Enjoy the magic of Diffusion models!
 Author: ModelScope Team
 License: Apache-2.0
--- a/README.md
+++ b/README.md
@@ -34,6 +34,15 @@ We believe that a well-developed open-source code framework can lower the thresh
 
 > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
 
+- **April 28, 2026** 🔥 We are excited to announce the release of **Diffusion Templates**, a plugin framework designed for Diffusion models that significantly lowers the barrier to training controllable generative models. Let's explore this cutting-edge technology together!
+  * Open-source code: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
+  * Technical report: [arXiv](https://arxiv.org/abs/2604.24351)
+  * Project homepage: [GitHub](https://modelscope.github.io/diffusion-templates-web/)
+  * Documentation: [English Version](https://diffsynth-studio-doc.readthedocs.io/en/latest/Diffusion_Templates/Introducing_Diffusion_Templates.html) | [Chinese Version](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/Diffusion_Templates/Introducing_Diffusion_Templates.html)
+  * Online demo: [ModelScope](https://modelscope.cn/studios/DiffSynth-Studio/Diffusion-Templates)
+  * Model collections: [ModelScope](https://modelscope.cn/collections/DiffSynth-Studio/KleinBase4B-Templates) | [ModelScope International](https://modelscope.ai/collections/DiffSynth-Studio/KleinBase4B-Templates) | [HuggingFace](https://huggingface.co/collections/DiffSynth-Studio/kleinbase4b-templates)
+  * Datasets: [ModelScope](https://modelscope.cn/collections/DiffSynth-Studio/ImagePulseV2) | [ModelScope International](https://modelscope.ai/collections/DiffSynth-Studio/ImagePulseV2) | [HuggingFace](https://huggingface.co/collections/DiffSynth-Studio/imagepulsev2)
+
 - **April 27, 2026** We support ACE-Step-1.5! Support includes text-to-music generation, low VRAM inference, and LoRA training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/ACE-Step.md) and [example code](/examples/ace_step/).
 
 - **April 27, 2026**: We have reinstated support for the Stable Diffusion v1.5 and SDXL models, providing academic research support exclusively for these two model types.
@@ -96,7 +105,7 @@ We believe that a well-developed open-source code framework can lower the thresh
 
 - **August 20, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) model, improving the editing effect of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
 
--**August 19, 2025** 🔥 Qwen-Image-Edit open-sourced, welcome a new member to the image editing model family!
+-**August 19, 2025** Qwen-Image-Edit open-sourced, welcome a new member to the image editing model family!
 
 - **August 18, 2025** We trained and open-sourced the Qwen-Image inpainting ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
 
@@ -112,7 +121,7 @@ We believe that a well-developed open-source code framework can lower the thresh
 
 - **August 5, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) for Qwen-Image, achieving approximately 5x acceleration.
 
--**August 4, 2025** 🔥 Qwen-Image open-sourced, welcome a new member to the image generation model family!
+-**August 4, 2025** Qwen-Image open-sourced, welcome a new member to the image generation model family!
 
 - **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) open-sourced, a text-to-image model focused on aesthetic photography. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, LoRA training, and full training. For more details, please refer to [./examples/flux/](./examples/flux/).
 
@@ -479,6 +488,17 @@ Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
 |[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
 |[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
 |[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
+|[DiffSynth-Studio/Template-KleinBase4B-Aesthetic](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Aesthetic)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Aesthetic.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Aesthetic.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Aesthetic.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Aesthetic.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-Brightness](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Brightness)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Brightness.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Brightness.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Brightness.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Brightness.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-Age](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Age)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Age.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Age.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Age.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Age.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-ControlNet](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-ControlNet)|[code](/examples/flux2/model_inference/Template-KleinBase4B-ControlNet.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-ControlNet.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-ControlNet.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-ControlNet.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Edit)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Edit.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Edit.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Edit.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Edit.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Inpaint)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Inpaint.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Inpaint.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Inpaint.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Inpaint.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-PandaMeme](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-PandaMeme)|[code](/examples/flux2/model_inference/Template-KleinBase4B-PandaMeme.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-PandaMeme.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-PandaMeme.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-PandaMeme.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-Sharpness](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Sharpness)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Sharpness.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Sharpness.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Sharpness.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Sharpness.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-SoftRGB](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-SoftRGB)|[code](/examples/flux2/model_inference/Template-KleinBase4B-SoftRGB.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-SoftRGB.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-SoftRGB.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-SoftRGB.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-Upscaler](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Upscaler)|[code](/examples/flux2/model_inference/Template-KleinBase4B-Upscaler.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Upscaler.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-Upscaler.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-Upscaler.py)|-|-|
+|[DiffSynth-Studio/Template-KleinBase4B-ContentRef](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-ContentRef)|[code](/examples/flux2/model_inference/Template-KleinBase4B-ContentRef.py)|[code](/examples/flux2/model_inference_low_vram/Template-KleinBase4B-ContentRef.py)|[code](/examples/flux2/model_training/full/Template-KleinBase4B-ContentRef.sh)|[code](/examples/flux2/model_training/validate_full/Template-KleinBase4B-ContentRef.py)|-|-|
 
 </details>
 
--- a/diffsynth/diffusion/base_pipeline.py
+++ b/diffsynth/diffusion/base_pipeline.py
@@ -3,12 +3,13 @@ import torch
 import numpy as np
 from einops import repeat, reduce
 from typing import Union
-from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
+from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type, enable_vram_management
 from ..core.device.npu_compatible_device import get_device_type
 from ..utils.lora import GeneralLoRALoader
 from ..models.model_loader import ModelPool
 from ..utils.controlnet import ControlNetInput
 from ..core.device import get_device_name, IS_NPU_AVAILABLE
+from .template import load_template_model, load_template_data_processor
 
 
 class PipelineUnit:
@@ -319,14 +320,21 @@ class BasePipeline(torch.nn.Module):
 
 
     def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
+        # Positive side forward
         if inputs_shared.get("positive_only_lora", None) is not None:
-            self.clear_lora(verbose=0)
             self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
         noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
+        if inputs_shared.get("positive_only_lora", None) is not None:
+            self.clear_lora(verbose=0)
+
         if cfg_scale != 1.0:
-            if inputs_shared.get("positive_only_lora", None) is not None:
-                self.clear_lora(verbose=0)
+            # Negative side forward
+            if inputs_shared.get("negative_only_lora", None) is not None:
+                self.load_lora(self.dit, state_dict=inputs_shared["negative_only_lora"], verbose=0)
             noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
+            if inputs_shared.get("negative_only_lora", None) is not None:
+                self.clear_lora(verbose=0)
+
             if isinstance(noise_pred_posi, tuple):
                 # Separately handling different output types of latents, eg. video and audio latents.
                 noise_pred = tuple(
@@ -338,6 +346,31 @@
         else:
             noise_pred = noise_pred_posi
         return noise_pred
+
+
+    def load_training_template_model(self, model_config: ModelConfig = None):
+        if model_config is not None:
+            model_config.download_if_necessary()
+            self.template_model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
+            self.template_data_processor = load_template_data_processor(model_config.path)()
+
+
+    def enable_lora_hot_loading(self, model: torch.nn.Module):
+        if hasattr(model, "vram_management_enabled") and getattr(model, "vram_management_enabled"):
+            return model
+        module_map = {torch.nn.Linear: AutoWrappedLinear}
+        vram_config = {
+            "offload_dtype": self.torch_dtype,
+            "offload_device": self.device,
+            "onload_dtype": self.torch_dtype,
+            "onload_device": self.device,
+            "preparing_dtype": self.torch_dtype,
+            "preparing_device": self.device,
+            "computation_dtype": self.torch_dtype,
+            "computation_device": self.device,
+        }
+        model = enable_vram_management(model, module_map, vram_config=vram_config)
+        return model
 
     def compile_pipeline(self, mode: str = "default", dynamic: bool = True, fullgraph: bool = False, compile_models: list = None, **kwargs):
         """
--- a/diffsynth/diffusion/loss.py
+++ b/diffsynth/diffusion/loss.py
@@ -3,6 +3,11 @@ import torch
 
 
 def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
+    if "lora" in inputs:
+        # Image-to-LoRA models need to load lora here.
+        pipe.clear_lora(verbose=0)
+        pipe.load_lora(pipe.dit, state_dict=inputs["lora"], hotload=True, verbose=0)
+
     max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
     min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
 
--- a/diffsynth/diffusion/parsers.py
+++ b/diffsynth/diffusion/parsers.py
@@ -60,6 +60,11 @@ def add_gradient_config(parser: argparse.ArgumentParser):
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
     return parser
 
+def add_template_model_config(parser: argparse.ArgumentParser):
+    parser.add_argument("--template_model_id_or_path", type=str, default=None, help="Model ID or path of template models.")
+    parser.add_argument("--enable_lora_hot_loading", default=False, action="store_true", help="Whether to enable LoRA hot-loading. Only available for image-to-lora models.")
+    return parser
+
 def add_general_config(parser: argparse.ArgumentParser):
     parser = add_dataset_base_config(parser)
     parser = add_model_config(parser)
@@ -67,4 +72,5 @@ def add_general_config(parser: argparse.ArgumentParser):
     parser = add_output_config(parser)
     parser = add_lora_config(parser)
     parser = add_gradient_config(parser)
+    parser = add_template_model_config(parser)
     return parser
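The two new flags plug into `add_general_config`, so every training script built on these parsers picks them up automatically. A quick sketch of what they parse to, assuming `add_general_config` can be imported from `diffsynth/diffusion/parsers.py` and that no other flag in the combined parser is required:

```python
import argparse
from diffsynth.diffusion.parsers import add_general_config  # path per this diff

parser = add_general_config(argparse.ArgumentParser())
args = parser.parse_args([
    "--template_model_id_or_path", "DiffSynth-Studio/Template-KleinBase4B-Edit",
    "--enable_lora_hot_loading",
])
assert args.template_model_id_or_path.endswith("Template-KleinBase4B-Edit")
assert args.enable_lora_hot_loading is True  # store_true flag, default False
```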
--- /dev/null
+++ b/diffsynth/diffusion/template.py
@@ -0,0 +1,203 @@
+import torch, os, importlib, warnings, json, inspect
+from typing import Dict, List, Tuple, Union
+from ..core import ModelConfig, load_model
+from ..core.device.npu_compatible_device import get_device_type
+from ..utils.lora.merge import merge_lora
+
+
+KVCache = Dict[str, Tuple[torch.Tensor, torch.Tensor]]
+
+
+class TemplateModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @torch.no_grad()
+    def process_inputs(self, **kwargs):
+        return {}
+
+    def forward(self, **kwargs):
+        raise NotImplementedError()
+
+
+def check_template_model_format(model):
+    if not hasattr(model, "process_inputs"):
+        raise NotImplementedError("`process_inputs` is not implemented in the Template model.")
+    if "kwargs" not in inspect.signature(model.process_inputs).parameters:
+        raise NotImplementedError("`**kwargs` is not included in `process_inputs`.")
+    if not hasattr(model, "forward"):
+        raise NotImplementedError("`forward` is not implemented in the Template model.")
+    if "kwargs" not in inspect.signature(model.forward).parameters:
+        raise NotImplementedError("`**kwargs` is not included in `forward`.")
+
+
+def load_template_model(path, torch_dtype=torch.bfloat16, device="cuda", verbose=1):
+    spec = importlib.util.spec_from_file_location("template_model", os.path.join(path, "model.py"))
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    template_model_path = getattr(module, 'TEMPLATE_MODEL_PATH') if hasattr(module, 'TEMPLATE_MODEL_PATH') else None
+    if template_model_path is not None:
+        # With `TEMPLATE_MODEL_PATH`, a pretrained model will be loaded.
+        model = load_model(
+            model_class=getattr(module, 'TEMPLATE_MODEL'),
+            config=getattr(module, 'TEMPLATE_MODEL_CONFIG') if hasattr(module, 'TEMPLATE_MODEL_CONFIG') else None,
+            path=os.path.join(path, getattr(module, 'TEMPLATE_MODEL_PATH')),
+            torch_dtype=torch_dtype,
+            device=device,
+        )
+    else:
+        # Without `TEMPLATE_MODEL_PATH`, a randomly initialized model or a non-model module will be loaded.
+        model = module.TEMPLATE_MODEL()
+        if hasattr(model, "to"):
+            model = model.to(dtype=torch_dtype, device=device)
+        if hasattr(model, "eval"):
+            model = model.eval()
+    check_template_model_format(model)
+    if verbose > 0:
+        metadata = {
+            "model_architecture": getattr(module, 'TEMPLATE_MODEL').__name__,
+            "code_path": os.path.join(path, "model.py"),
+            "weight_path": template_model_path,
+        }
+        print(f"Template model loaded: {json.dumps(metadata, indent=4)}")
+    return model
+
+
+def load_template_data_processor(path):
+    spec = importlib.util.spec_from_file_location("template_model", os.path.join(path, "model.py"))
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    if hasattr(module, 'TEMPLATE_DATA_PROCESSOR'):
+        processor = getattr(module, 'TEMPLATE_DATA_PROCESSOR')
+        return processor
+    else:
+        return None
+
+
+class TemplatePipeline(torch.nn.Module):
+    def __init__(
+        self,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: Union[str, torch.device] = get_device_type(),
+        model_configs: list[ModelConfig] = [],
+        lazy_loading: bool = False,
+    ):
+        super().__init__()
+        self.torch_dtype = torch_dtype
+        self.device = device
+        self.model_configs = model_configs
+        self.lazy_loading = lazy_loading
+        if lazy_loading:
+            for model_config in model_configs:
+                TemplatePipeline.check_vram_config(model_config)
+                model_config.download_if_necessary()
+            self.models = None
+        else:
+            models = []
+            for model_config in model_configs:
+                TemplatePipeline.check_vram_config(model_config)
+                model_config.download_if_necessary()
+                model = load_template_model(model_config.path, torch_dtype=torch_dtype, device=device)
+                models.append(model)
+            self.models = torch.nn.ModuleList(models)
+
+    def merge_kv_cache(self, kv_cache_list: List[KVCache]) -> KVCache:
+        names = {}
+        for kv_cache in kv_cache_list:
+            for name in kv_cache:
+                names[name] = None
+        kv_cache_merged = {}
+        for name in names:
+            kv_list = [kv_cache.get(name) for kv_cache in kv_cache_list]
+            kv_list = [kv for kv in kv_list if kv is not None]
+            if len(kv_list) > 0:
+                k = torch.concat([kv[0] for kv in kv_list], dim=1)
+                v = torch.concat([kv[1] for kv in kv_list], dim=1)
+                kv_cache_merged[name] = (k, v)
+        return kv_cache_merged
+
+    def merge_template_cache(self, template_cache_list):
+        params = sorted(list(set(sum([list(template_cache.keys()) for template_cache in template_cache_list], []))))
+        template_cache_merged = {}
+        for param in params:
+            data = [template_cache[param] for template_cache in template_cache_list if param in template_cache]
+            if param == "kv_cache":
+                data = self.merge_kv_cache(data)
+            elif param == "lora":
+                data = merge_lora(data)
+            elif len(data) == 1:
+                data = data[0]
+            else:
+                print(f"Conflict detected: `{param}` appears in the outputs of multiple Template models. Only the first one will be retained.")
+                data = data[0]
+            template_cache_merged[param] = data
+        return template_cache_merged
+
+    @staticmethod
+    def check_vram_config(model_config: ModelConfig):
+        params = [
+            model_config.offload_device, model_config.offload_dtype,
+            model_config.onload_device, model_config.onload_dtype,
+            model_config.preparing_device, model_config.preparing_dtype,
+            model_config.computation_device, model_config.computation_dtype,
+        ]
+        for param in params:
+            if param is not None:
+                warnings.warn("TemplatePipeline doesn't support VRAM management. VRAM config will be ignored.")
+
+    @staticmethod
+    def from_pretrained(
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: Union[str, torch.device] = get_device_type(),
+        model_configs: list[ModelConfig] = [],
+        lazy_loading: bool = False,
+    ):
+        pipe = TemplatePipeline(torch_dtype, device, model_configs, lazy_loading)
+        return pipe
+
+    def fetch_model(self, model_id):
+        if self.lazy_loading:
+            model_config = self.model_configs[model_id]
+            model_config.download_if_necessary()
+            model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
+        else:
+            model = self.models[model_id]
+        return model
+
+    def call_single_side(self, pipe=None, inputs: List[Dict] = None):
+        model = None
+        onload_model_id = -1
+        template_cache = []
+        for i in inputs:
+            model_id = i.get("model_id", 0)
+            if model_id != onload_model_id:
+                model = self.fetch_model(model_id)
+                onload_model_id = model_id
+            cache = model.process_inputs(pipe=pipe, **i)
+            cache = model.forward(pipe=pipe, **cache)
+            template_cache.append(cache)
+        template_cache = self.merge_template_cache(template_cache)
+        return template_cache
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        pipe=None,
+        template_inputs: List[Dict] = None,
+        negative_template_inputs: List[Dict] = None,
+        **kwargs,
+    ):
+        template_cache = self.call_single_side(pipe=pipe, inputs=template_inputs or [])
+        negative_template_cache = self.call_single_side(pipe=pipe, inputs=negative_template_inputs or [])
+        required_params = list(inspect.signature(pipe.__call__).parameters.keys())
+        for param in template_cache:
+            if param in required_params:
+                kwargs[param] = template_cache[param]
+            else:
+                print(f"`{param}` is not included in the inputs of `{pipe.__class__.__name__}`. This parameter will be ignored.")
+        for param in negative_template_cache:
+            if "negative_" + param in required_params:
+                kwargs["negative_" + param] = negative_template_cache[param]
+            else:
+                print(f"`{'negative_' + param}` is not included in the inputs of `{pipe.__class__.__name__}`. This parameter will be ignored.")
+        return pipe(**kwargs)
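A template, as loaded by `load_template_model` above, is a directory whose `model.py` exports `TEMPLATE_MODEL` plus optional `TEMPLATE_MODEL_PATH`, `TEMPLATE_MODEL_CONFIG`, and `TEMPLATE_DATA_PROCESSOR`. A minimal sketch of such a `model.py` that satisfies `check_template_model_format`; the brightness conditioning itself is purely illustrative:

```python
# model.py -- a minimal template plugin sketch. Only the exported names and the
# **kwargs signatures are required by check_template_model_format; everything
# else here is illustrative.
import torch

class BrightnessTemplate(torch.nn.Module):
    @torch.no_grad()
    def process_inputs(self, pipe=None, brightness=0.5, **kwargs):
        # Turn raw user inputs into tensors for forward().
        return {"brightness": torch.tensor(brightness)}

    def forward(self, pipe=None, brightness=None, **kwargs):
        # Return a "template cache"; keys such as "kv_cache" or "lora" receive
        # special merging in TemplatePipeline.merge_template_cache.
        return {"brightness_embedding": brightness}

class BrightnessDataProcessor:
    # Instantiated once by the pipeline; real processors reshape training samples.
    def __call__(self, **kwargs):
        return kwargs  # identity processor for this sketch

TEMPLATE_MODEL = BrightnessTemplate
TEMPLATE_DATA_PROCESSOR = BrightnessDataProcessor
# TEMPLATE_MODEL_PATH = "model.safetensors"  # optional: load pretrained weights
```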
--- a/diffsynth/diffusion/training_module.py
+++ b/diffsynth/diffusion/training_module.py
@@ -6,6 +6,7 @@ from peft import LoraConfig, inject_adapter_in_model
 
 
 class GeneralUnit_RemoveCache(PipelineUnit):
+    # Only used for training
    def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):
        super().__init__(take_over=True)
        self.required_params = required_params
@@ -27,6 +28,47 @@ class GeneralUnit_RemoveCache(PipelineUnit):
         return inputs_shared, inputs_posi, inputs_nega
 
 
+class GeneralUnit_TemplateProcessInputs(PipelineUnit):
+    # Only used for training
+    def __init__(self, data_processor):
+        super().__init__(
+            input_params=("template_inputs",),
+            output_params=("template_inputs",),
+        )
+        self.data_processor = data_processor
+
+    def process(self, pipe, template_inputs):
+        if not hasattr(pipe, "template_model") or template_inputs is None:
+            return {}
+        if self.data_processor is not None:
+            template_inputs = self.data_processor(**template_inputs)
+        template_inputs = pipe.template_model.process_inputs(pipe=pipe, **template_inputs)
+        return {"template_inputs": template_inputs}
+
+
+class GeneralUnit_TemplateForward(PipelineUnit):
+    # Only used for training
+    def __init__(self, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
+        super().__init__(
+            input_params=("template_inputs",),
+            output_params=("kv_cache",),
+            onload_model_names=("template_model",)
+        )
+        self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
+
+    def process(self, pipe, template_inputs):
+        if not hasattr(pipe, "template_model") or template_inputs is None:
+            return {}
+        template_cache = pipe.template_model.forward(
+            **template_inputs,
+            pipe=pipe,
+            use_gradient_checkpointing=self.use_gradient_checkpointing,
+            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload,
+        )
+        return template_cache
+
+
 class DiffusionTrainingModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -209,6 +251,16 @@ class DiffusionTrainingModule(torch.nn.Module):
         else:
             lora_target_modules = lora_target_modules.split(",")
         return lora_target_modules
+
+
+    def load_training_template_model(self, pipe, path_or_model_id, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
+        if path_or_model_id is None:
+            return pipe
+        model_config = self.parse_path_or_model_id(path_or_model_id)
+        pipe.load_training_template_model(model_config)
+        pipe.units.append(GeneralUnit_TemplateProcessInputs(pipe.template_data_processor))
+        pipe.units.append(GeneralUnit_TemplateForward(use_gradient_checkpointing, use_gradient_checkpointing_offload))
+        return pipe
 
 
     def switch_pipe_to_training_mode(
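Combined with `BasePipeline.load_training_template_model` from earlier in this diff, a training module attaches a template in one call, which appends the two units above to the pipeline. A hedged sketch of that wiring; `self`, `pipe`, and `args` are assumed training-script context, and the argument values are illustrative:

```python
# Hedged sketch: attaching a template model during training setup. The method
# and flag names come from this diff; the surrounding training module is assumed.
pipe = self.load_training_template_model(
    pipe,
    args.template_model_id_or_path,            # new flag from parsers.py
    use_gradient_checkpointing=True,           # illustrative values
    use_gradient_checkpointing_offload=False,
)
if args.enable_lora_hot_loading:               # image-to-lora models only
    pipe.dit = pipe.enable_lora_hot_loading(pipe.dit)
```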
--- a/diffsynth/models/dinov3_image_encoder.py
+++ b/diffsynth/models/dinov3_image_encoder.py
@@ -1,12 +1,16 @@
-from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTModel, DINOv3ViTConfig
-from transformers import DINOv3ViTImageProcessor
-import torch
-
+import torch, warnings
+try:
+    from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTModel
+except:
+    warnings.warn(f"Cannot import `DINOv3ViTModel`. `DINOv3ImageEncoder` is not available. Please update `transformers` by `pip install -U transformers`.")
+    DINOv3ViTModel = torch.nn.Module
 from ..core.device.npu_compatible_device import get_device_type
 
 
 class DINOv3ImageEncoder(DINOv3ViTModel):
     def __init__(self):
+        from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
+        from transformers import DINOv3ViTImageProcessor
         config = DINOv3ViTConfig(
             architectures = [
                 "DINOv3ViTModel"