diffsynth 2.0.7__tar.gz → 2.0.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152) hide show
  1. {diffsynth-2.0.7 → diffsynth-2.0.8}/PKG-INFO +1 -1
  2. {diffsynth-2.0.7 → diffsynth-2.0.8}/README.md +69 -0
  3. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/configs/model_configs.py +17 -1
  4. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/configs/vram_management_module_maps.py +12 -0
  5. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/flow_match.py +15 -2
  6. diffsynth-2.0.8/diffsynth/models/ernie_image_dit.py +362 -0
  7. diffsynth-2.0.8/diffsynth/models/ernie_image_text_encoder.py +76 -0
  8. diffsynth-2.0.8/diffsynth/pipelines/ernie_image.py +266 -0
  9. diffsynth-2.0.8/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py +21 -0
  10. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth.egg-info/PKG-INFO +1 -1
  11. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth.egg-info/SOURCES.txt +4 -0
  12. {diffsynth-2.0.7 → diffsynth-2.0.8}/pyproject.toml +1 -1
  13. {diffsynth-2.0.7 → diffsynth-2.0.8}/LICENSE +0 -0
  14. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/__init__.py +0 -0
  15. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/configs/__init__.py +0 -0
  16. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/__init__.py +0 -0
  17. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/attention/__init__.py +0 -0
  18. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/attention/attention.py +0 -0
  19. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/data/__init__.py +0 -0
  20. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/data/operators.py +0 -0
  21. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/data/unified_dataset.py +0 -0
  22. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/device/__init__.py +0 -0
  23. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/device/npu_compatible_device.py +0 -0
  24. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/gradient/__init__.py +0 -0
  25. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/gradient/gradient_checkpoint.py +0 -0
  26. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/loader/__init__.py +0 -0
  27. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/loader/config.py +0 -0
  28. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/loader/file.py +0 -0
  29. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/loader/model.py +0 -0
  30. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/npu_patch/npu_fused_operator.py +0 -0
  31. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/vram/__init__.py +0 -0
  32. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/vram/disk_map.py +0 -0
  33. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/vram/initialization.py +0 -0
  34. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/vram/layers.py +0 -0
  35. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/__init__.py +0 -0
  36. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/base_pipeline.py +0 -0
  37. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/logger.py +0 -0
  38. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/loss.py +0 -0
  39. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/parsers.py +0 -0
  40. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/runner.py +0 -0
  41. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/training_module.py +0 -0
  42. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/anima_dit.py +0 -0
  43. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/dinov3_image_encoder.py +0 -0
  44. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux2_dit.py +0 -0
  45. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux2_text_encoder.py +0 -0
  46. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux2_vae.py +0 -0
  47. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_controlnet.py +0 -0
  48. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_dit.py +0 -0
  49. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_infiniteyou.py +0 -0
  50. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_ipadapter.py +0 -0
  51. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_lora_encoder.py +0 -0
  52. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_lora_patcher.py +0 -0
  53. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_text_encoder_clip.py +0 -0
  54. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_text_encoder_t5.py +0 -0
  55. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_vae.py +0 -0
  56. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_value_control.py +0 -0
  57. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/general_modules.py +0 -0
  58. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/longcat_video_dit.py +0 -0
  59. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/ltx2_audio_vae.py +0 -0
  60. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/ltx2_common.py +0 -0
  61. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/ltx2_dit.py +0 -0
  62. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/ltx2_text_encoder.py +0 -0
  63. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/ltx2_upsampler.py +0 -0
  64. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/ltx2_video_vae.py +0 -0
  65. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/model_loader.py +0 -0
  66. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/mova_audio_dit.py +0 -0
  67. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/mova_audio_vae.py +0 -0
  68. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/mova_dual_tower_bridge.py +0 -0
  69. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/nexus_gen.py +0 -0
  70. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/nexus_gen_ar_model.py +0 -0
  71. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/nexus_gen_projector.py +0 -0
  72. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/qwen_image_controlnet.py +0 -0
  73. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/qwen_image_dit.py +0 -0
  74. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/qwen_image_image2lora.py +0 -0
  75. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/qwen_image_text_encoder.py +0 -0
  76. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/qwen_image_vae.py +0 -0
  77. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/sd_text_encoder.py +0 -0
  78. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/siglip2_image_encoder.py +0 -0
  79. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/step1x_connector.py +0 -0
  80. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/step1x_text_encoder.py +0 -0
  81. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_animate_adapter.py +0 -0
  82. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_camera_controller.py +0 -0
  83. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_dit.py +0 -0
  84. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_dit_s2v.py +0 -0
  85. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_image_encoder.py +0 -0
  86. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_mot.py +0 -0
  87. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_motion_controller.py +0 -0
  88. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_text_encoder.py +0 -0
  89. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_vace.py +0 -0
  90. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_vae.py +0 -0
  91. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wantodance.py +0 -0
  92. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wav2vec.py +0 -0
  93. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/z_image_controlnet.py +0 -0
  94. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/z_image_dit.py +0 -0
  95. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/z_image_image2lora.py +0 -0
  96. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/z_image_text_encoder.py +0 -0
  97. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/anima_image.py +0 -0
  98. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/flux2_image.py +0 -0
  99. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/flux_image.py +0 -0
  100. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/ltx2_audio_video.py +0 -0
  101. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/mova_audio_video.py +0 -0
  102. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/qwen_image.py +0 -0
  103. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/wan_video.py +0 -0
  104. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/z_image.py +0 -0
  105. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/controlnet/__init__.py +0 -0
  106. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/controlnet/annotator.py +0 -0
  107. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/controlnet/controlnet_input.py +0 -0
  108. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/data/__init__.py +0 -0
  109. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/data/audio.py +0 -0
  110. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/data/audio_video.py +0 -0
  111. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/data/media_io_ltx2.py +0 -0
  112. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/lora/__init__.py +0 -0
  113. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/lora/flux.py +0 -0
  114. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/lora/general.py +0 -0
  115. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/lora/merge.py +0 -0
  116. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/lora/reset_rank.py +0 -0
  117. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/ses/__init__.py +0 -0
  118. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/ses/ses.py +0 -0
  119. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/__init__.py +0 -0
  120. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/anima_dit.py +0 -0
  121. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux2_text_encoder.py +0 -0
  122. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_controlnet.py +0 -0
  123. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_dit.py +0 -0
  124. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_infiniteyou.py +0 -0
  125. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_ipadapter.py +0 -0
  126. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py +0 -0
  127. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py +0 -0
  128. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_vae.py +0 -0
  129. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py +0 -0
  130. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/ltx2_dit.py +0 -0
  131. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py +0 -0
  132. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/ltx2_video_vae.py +0 -0
  133. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/nexus_gen.py +0 -0
  134. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/nexus_gen_projector.py +0 -0
  135. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py +0 -0
  136. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/step1x_connector.py +0 -0
  137. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py +0 -0
  138. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wan_video_dit.py +0 -0
  139. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py +0 -0
  140. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wan_video_mot.py +0 -0
  141. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wan_video_vace.py +0 -0
  142. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wan_video_vae.py +0 -0
  143. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py +0 -0
  144. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/z_image_dit.py +0 -0
  145. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/z_image_text_encoder.py +0 -0
  146. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/xfuser/__init__.py +0 -0
  147. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/xfuser/xdit_context_parallel.py +0 -0
  148. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/version.py +0 -0
  149. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth.egg-info/dependency_links.txt +0 -0
  150. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth.egg-info/requires.txt +0 -0
  151. {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth.egg-info/top_level.txt +0 -0
  152. {diffsynth-2.0.7 → diffsynth-2.0.8}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth
3
- Version: 2.0.7
3
+ Version: 2.0.8
4
4
  Summary: Enjoy the magic of Diffusion models!
5
5
  Author: ModelScope Team
6
6
  License: Apache-2.0
@@ -7,6 +7,7 @@
7
7
  [![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
8
8
  [![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
9
9
  [![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
10
+ [![Discord](https://badgen.net//discord/members/Mm9suEeUDc)](https://discord.gg/Mm9suEeUDc)
10
11
 
11
12
  [切换到中文版](./README_zh.md)
12
13
 
@@ -32,6 +33,7 @@ We believe that a well-developed open-source code framework can lower the thresh
32
33
  > DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
33
34
 
34
35
  > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
36
+
35
37
  - **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
36
38
 
37
39
  - **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
@@ -875,6 +877,67 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
875
877
 
876
878
  </details>
877
879
 
880
+ #### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
881
+
882
+ <details>
883
+
884
+ <summary>Quick Start</summary>
885
+
886
+ Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
887
+
888
+ ```python
889
+ from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
890
+ import torch
891
+
892
+ vram_config = {
893
+ "offload_dtype": torch.bfloat16,
894
+ "offload_device": "cpu",
895
+ "onload_dtype": torch.bfloat16,
896
+ "onload_device": "cpu",
897
+ "preparing_dtype": torch.bfloat16,
898
+ "preparing_device": "cuda",
899
+ "computation_dtype": torch.bfloat16,
900
+ "computation_device": "cuda",
901
+ }
902
+ pipe = ErnieImagePipeline.from_pretrained(
903
+ torch_dtype=torch.bfloat16,
904
+ device='cuda',
905
+ model_configs=[
906
+ ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
907
+ ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
908
+ ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
909
+ ],
910
+ tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
911
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
912
+ )
913
+
914
+ image = pipe(
915
+ prompt="一只黑白相间的中华田园犬",
916
+ negative_prompt="",
917
+ height=1024,
918
+ width=1024,
919
+ seed=42,
920
+ num_inference_steps=50,
921
+ cfg_scale=4.0,
922
+ )
923
+ image.save("output.jpg")
924
+ ```
925
+
926
+ </details>
927
+
928
+ <details>
929
+
930
+ <summary>Examples</summary>
931
+
932
+ Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
933
+
934
+ | Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
935
+ |-|-|-|-|-|-|-|
936
+ |[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
937
+ |[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
938
+
939
+ </details>
940
+
878
941
  ## Innovative Achievements
879
942
 
880
943
  DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
@@ -1029,3 +1092,9 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
1029
1092
  https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
1030
1093
 
1031
1094
  </details>
1095
+
1096
+ ## Contact Us
1097
+
1098
+ |Discord:https://discord.gg/Mm9suEeUDc|
1099
+ |-|
1100
+ |<img width="160" height="160" alt="Image" src="https://github.com/user-attachments/assets/29bdc97b-e35d-4fea-88d6-32e35182e458" />|
@@ -541,6 +541,22 @@ flux2_series = [
541
541
  },
542
542
  ]
543
543
 
544
+ ernie_image_series = [
545
+ {
546
+ # Example: ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
547
+ "model_hash": "584c13713849f1af4e03d5f1858b8b7b",
548
+ "model_name": "ernie_image_dit",
549
+ "model_class": "diffsynth.models.ernie_image_dit.ErnieImageDiT",
550
+ },
551
+ {
552
+ # Example: ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors")
553
+ "model_hash": "404ed9f40796a38dd34c1620f1920207",
554
+ "model_name": "ernie_image_text_encoder",
555
+ "model_class": "diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder",
556
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ernie_image_text_encoder.ErnieImageTextEncoderStateDictConverter",
557
+ },
558
+ ]
559
+
544
560
  z_image_series = [
545
561
  {
546
562
  # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
@@ -884,4 +900,4 @@ mova_series = [
884
900
  "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
885
901
  },
886
902
  ]
887
- MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series
903
+ MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series
@@ -267,6 +267,18 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
267
267
  "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
268
268
  "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
269
269
  },
270
+ "diffsynth.models.ernie_image_dit.ErnieImageDiT": {
271
+ "diffsynth.models.ernie_image_dit.ErnieImageRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
272
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
273
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
274
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
275
+ "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
276
+ },
277
+ "diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder": {
278
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
279
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
280
+ "transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
281
+ },
270
282
  }
271
283
 
272
284
  def QwenImageTextEncoder_Module_Map_Updater():
@@ -4,7 +4,7 @@ from typing_extensions import Literal
4
4
 
5
5
  class FlowMatchScheduler():
6
6
 
7
- def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
7
+ def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image"] = "FLUX.1"):
8
8
  self.set_timesteps_fn = {
9
9
  "FLUX.1": FlowMatchScheduler.set_timesteps_flux,
10
10
  "Wan": FlowMatchScheduler.set_timesteps_wan,
@@ -13,6 +13,7 @@ class FlowMatchScheduler():
13
13
  "Z-Image": FlowMatchScheduler.set_timesteps_z_image,
14
14
  "LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
15
15
  "Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
16
+ "ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image,
16
17
  }.get(template, FlowMatchScheduler.set_timesteps_flux)
17
18
  self.num_train_timesteps = 1000
18
19
 
@@ -129,6 +130,18 @@ class FlowMatchScheduler():
129
130
  timesteps = sigmas * num_train_timesteps
130
131
  return sigmas, timesteps
131
132
 
133
+ @staticmethod
134
+ def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0, shift=3.0):
135
+ sigma_min = 0.0
136
+ sigma_max = 1.0
137
+ num_train_timesteps = 1000
138
+ sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
139
+ sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
140
+ if shift is not None and shift != 1.0:
141
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
142
+ timesteps = sigmas * num_train_timesteps
143
+ return sigmas, timesteps
144
+
132
145
  @staticmethod
133
146
  def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
134
147
  sigma_min = 0.0
@@ -185,7 +198,7 @@ class FlowMatchScheduler():
185
198
  bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
186
199
  bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
187
200
  self.linear_timesteps_weights = bsmntw_weighing
188
-
201
+
189
202
  def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
190
203
  self.sigmas, self.timesteps = self.set_timesteps_fn(
191
204
  num_inference_steps=num_inference_steps,
@@ -0,0 +1,362 @@
1
+ """
2
+ Ernie-Image DiT for DiffSynth-Studio.
3
+
4
+ Refactored from diffusers ErnieImageTransformer2DModel to use DiffSynth core modules.
5
+ Default parameters from actual checkpoint config.json (PaddlePaddle/ERNIE-Image transformer).
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from typing import Optional, Tuple
12
+
13
+ from ..core.attention import attention_forward
14
+ from ..core.gradient import gradient_checkpoint_forward
15
+ from .flux2_dit import Timesteps, TimestepEmbedding
16
+
17
+
18
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
19
+ assert dim % 2 == 0
20
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
21
+ omega = 1.0 / (theta ** scale)
22
+ out = torch.einsum("...n,d->...nd", pos, omega)
23
+ return out.float()
24
+
25
+
26
+ class ErnieImageEmbedND3(nn.Module):
27
+ def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]):
28
+ super().__init__()
29
+ self.dim = dim
30
+ self.theta = theta
31
+ self.axes_dim = list(axes_dim)
32
+
33
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
34
+ emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
35
+ emb = emb.unsqueeze(2)
36
+ return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1)
37
+
38
+
39
+ class ErnieImagePatchEmbedDynamic(nn.Module):
40
+ def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
41
+ super().__init__()
42
+ self.patch_size = patch_size
43
+ self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
44
+
45
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
46
+ x = self.proj(x)
47
+ batch_size, dim, height, width = x.shape
48
+ return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()
49
+
50
+
51
+ class ErnieImageSingleStreamAttnProcessor:
52
+ def __call__(
53
+ self,
54
+ attn: "ErnieImageAttention",
55
+ hidden_states: torch.Tensor,
56
+ attention_mask: Optional[torch.Tensor] = None,
57
+ freqs_cis: Optional[torch.Tensor] = None,
58
+ ) -> torch.Tensor:
59
+ query = attn.to_q(hidden_states)
60
+ key = attn.to_k(hidden_states)
61
+ value = attn.to_v(hidden_states)
62
+
63
+ query = query.unflatten(-1, (attn.heads, -1))
64
+ key = key.unflatten(-1, (attn.heads, -1))
65
+ value = value.unflatten(-1, (attn.heads, -1))
66
+
67
+ if attn.norm_q is not None:
68
+ query = attn.norm_q(query)
69
+ if attn.norm_k is not None:
70
+ key = attn.norm_k(key)
71
+
72
+ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
73
+ rot_dim = freqs_cis.shape[-1]
74
+ x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
75
+ cos_ = torch.cos(freqs_cis).to(x.dtype)
76
+ sin_ = torch.sin(freqs_cis).to(x.dtype)
77
+ x1, x2 = x.chunk(2, dim=-1)
78
+ x_rotated = torch.cat((-x2, x1), dim=-1)
79
+ return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
80
+
81
+ if freqs_cis is not None:
82
+ query = apply_rotary_emb(query, freqs_cis)
83
+ key = apply_rotary_emb(key, freqs_cis)
84
+
85
+ if attention_mask is not None and attention_mask.ndim == 2:
86
+ attention_mask = attention_mask[:, None, None, :]
87
+
88
+ hidden_states = attention_forward(
89
+ query, key, value,
90
+ q_pattern="b s n d",
91
+ k_pattern="b s n d",
92
+ v_pattern="b s n d",
93
+ out_pattern="b s n d",
94
+ attn_mask=attention_mask,
95
+ )
96
+
97
+ hidden_states = hidden_states.flatten(2, 3)
98
+ hidden_states = hidden_states.to(query.dtype)
99
+ output = attn.to_out[0](hidden_states)
100
+
101
+ return output
102
+
103
+
104
+ class ErnieImageAttention(nn.Module):
105
+ def __init__(
106
+ self,
107
+ query_dim: int,
108
+ heads: int = 8,
109
+ dim_head: int = 64,
110
+ dropout: float = 0.0,
111
+ bias: bool = False,
112
+ qk_norm: str = "rms_norm",
113
+ out_bias: bool = True,
114
+ eps: float = 1e-5,
115
+ out_dim: int = None,
116
+ elementwise_affine: bool = True,
117
+ ):
118
+ super().__init__()
119
+
120
+ self.head_dim = dim_head
121
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
122
+ self.query_dim = query_dim
123
+ self.out_dim = out_dim if out_dim is not None else query_dim
124
+ self.heads = out_dim // dim_head if out_dim is not None else heads
125
+
126
+ self.use_bias = bias
127
+ self.dropout = dropout
128
+
129
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
130
+ self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
131
+ self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
132
+
133
+ if qk_norm == "layer_norm":
134
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
135
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
136
+ elif qk_norm == "rms_norm":
137
+ self.norm_q = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
138
+ self.norm_k = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
139
+ else:
140
+ raise ValueError(
141
+ f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'."
142
+ )
143
+
144
+ self.to_out = nn.ModuleList([])
145
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
146
+
147
+ self.processor = ErnieImageSingleStreamAttnProcessor()
148
+
149
+ def forward(
150
+ self,
151
+ hidden_states: torch.Tensor,
152
+ attention_mask: Optional[torch.Tensor] = None,
153
+ image_rotary_emb: Optional[torch.Tensor] = None,
154
+ ) -> torch.Tensor:
155
+ return self.processor(self, hidden_states, attention_mask, image_rotary_emb)
156
+
157
+
158
+ class ErnieImageFeedForward(nn.Module):
159
+ def __init__(self, hidden_size: int, ffn_hidden_size: int):
160
+ super().__init__()
161
+ self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
162
+ self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
163
+ self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
164
+
165
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
166
+ return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))
167
+
168
+
169
+ class ErnieImageRMSNorm(nn.Module):
170
+ def __init__(self, dim: int, eps: float = 1e-6):
171
+ super().__init__()
172
+ self.eps = eps
173
+ self.weight = nn.Parameter(torch.ones(dim))
174
+
175
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
176
+ input_dtype = hidden_states.dtype
177
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
178
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
179
+ hidden_states = hidden_states * self.weight
180
+ return hidden_states.to(input_dtype)
181
+
182
+
183
class ErnieImageSharedAdaLNBlock(nn.Module):
    """
    Transformer block modulated by shared AdaLN conditioning.

    Pre-norm self-attention followed by a gated feed-forward, each wrapped in
    (shift, scale, gate) modulation computed once outside and shared across
    all blocks. Operates on [S, B, H] tensors; attention runs in [B, S, H].
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        ffn_hidden_size: int,
        eps: float = 1e-6,
        qk_layernorm: bool = True,
    ):
        super().__init__()
        # RMSNorm applied before self-attention; AdaLN shift/scale follow it.
        self.adaLN_sa_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
        self.self_attention = ErnieImageAttention(
            query_dim=hidden_size,
            dim_head=hidden_size // num_heads,
            heads=num_heads,
            # Optional RMSNorm on queries/keys, matching the checkpoint config.
            qk_norm="rms_norm" if qk_layernorm else None,
            eps=eps,
            bias=False,
            out_bias=False,
        )
        # RMSNorm applied before the gated feed-forward.
        self.adaLN_mlp_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
        self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)

    def forward(
        self,
        x: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        temb: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: hidden states, [S, B, H].
            rotary_pos_emb: rotary embedding forwarded to the attention processor.
            temb: six modulation tensors (shift/scale/gate for attention, then
                for the MLP), each broadcastable against [S, B, H].
            attention_mask: optional boolean/additive mask forwarded to attention.

        Returns:
            Updated hidden states, [S, B, H], same dtype as the input.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
        residual = x
        x = self.adaLN_sa_ln(x)
        # Modulation in fp32 to avoid bf16/fp16 precision loss; cast back after.
        x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
        # Attention expects batch-first [B, S, H]; permute in and back out.
        x_bsh = x.permute(1, 0, 2)
        attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
        attn_out = attn_out.permute(1, 0, 2)
        # Gated residual add, again accumulated in fp32.
        x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
        residual = x
        x = self.adaLN_mlp_ln(x)
        x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
        return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype)
225
+
226
+
227
class ErnieImageAdaLNContinuous(nn.Module):
    """Final AdaLN layer: affine-free LayerNorm plus conditioning-driven scale/shift."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # elementwise_affine=False: all affine behavior comes from `linear` below.
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps)
        self.linear = nn.Linear(hidden_size, hidden_size * 2)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        # conditioning is per-batch [B, H]; unsqueeze(0) broadcasts it over the
        # leading sequence dimension of x ([S, B, H]).
        modulation = self.linear(conditioning)
        scale, shift = torch.chunk(modulation, 2, dim=-1)
        normed = self.norm(x)
        return normed * (scale.unsqueeze(0) + 1) + shift.unsqueeze(0)
238
+
239
+
240
class ErnieImageDiT(nn.Module):
    """
    Ernie-Image DiT model for DiffSynth-Studio.

    Architecture: SharedAdaLN + RoPE 3D + Joint Image-Text Attention.
    Internal format: [S, B, H] for transformer blocks, [B, S, H] for attention.
    """

    def __init__(
        self,
        hidden_size: int = 4096,
        num_attention_heads: int = 32,
        num_layers: int = 36,
        ffn_hidden_size: int = 12288,
        in_channels: int = 128,
        out_channels: int = 128,
        patch_size: int = 1,
        text_in_dim: int = 3072,
        rope_theta: int = 256,
        rope_axes_dim: Tuple[int, int, int] = (32, 48, 48),
        eps: float = 1e-6,
        qk_layernorm: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.num_layers = num_layers
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.text_in_dim = text_in_dim

        self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size)
        # Only project text features when their width differs from the DiT width.
        self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None
        self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(hidden_size, hidden_size)
        # 3-axis RoPE over head_dim, split per rope_axes_dim (sum should equal
        # head_dim -- presumably (t, y, x) axes; confirm against ErnieImageEmbedND3).
        self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
        # Shared AdaLN head producing six modulation tensors for every block;
        # zero-init so modulation starts as identity (shift=scale=gate=0).
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)
        self.layers = nn.ModuleList([
            ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm)
            for _ in range(num_layers)
        ])
        self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps)
        # Zero-init output head so the untrained model emits zeros.
        self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
        nn.init.zeros_(self.final_linear.weight)
        nn.init.zeros_(self.final_linear.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        text_bth: torch.Tensor,
        text_lens: torch.Tensor,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: latent image, [B, C, H, W].
            timestep: diffusion timestep per sample (shape compatible with time_proj).
            text_bth: padded text embeddings, [B, Tmax, text_in_dim].
            text_lens: number of valid text tokens per sample, [B].
            use_gradient_checkpointing: checkpoint each block (training only).
            use_gradient_checkpointing_offload: offload checkpointed activations.

        Returns:
            Predicted latent, [B, out_channels, H, W].
        """
        device, dtype = hidden_states.device, hidden_states.dtype
        B, C, H, W = hidden_states.shape
        p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
        N_img = Hp * Wp

        # Patchify image latents and move to sequence-first [S_img, B, H].
        img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous()

        if self.text_proj is not None and text_bth.numel() > 0:
            text_bth = self.text_proj(text_bth)
        Tmax = text_bth.shape[1]
        text_sbh = text_bth.transpose(0, 1).contiguous()

        # Joint sequence: image tokens first, then text tokens.
        x = torch.cat([img_sbh, text_sbh], dim=0)
        S = x.shape[0]

        # RoPE ids: text tokens get (t, 0, 0); image tokens get (text_len, y, x)
        # so image positions continue after each sample's text on the first axis.
        text_ids = torch.cat([
            torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1),
            torch.zeros((B, Tmax, 2), device=device)
        ], dim=-1) if Tmax > 0 else torch.zeros((B, 0, 3), device=device)
        grid_yx = torch.stack(
            torch.meshgrid(torch.arange(Hp, device=device, dtype=torch.float32),
                           torch.arange(Wp, device=device, dtype=torch.float32), indexing="ij"),
            dim=-1
        ).reshape(-1, 2)
        image_ids = torch.cat([
            text_lens.float().view(B, 1, 1).expand(-1, N_img, -1),
            grid_yx.view(1, N_img, 2).expand(B, -1, -1)
        ], dim=-1)
        # NOTE(review): ids are ordered [image, text] while the token sequence is
        # also [image, text] -- consistent, but confirm pos_embed expects that order.
        rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))

        # Attention mask: every image token is valid; text padding beyond
        # text_lens is masked out. Shape [B, 1, 1, S] for broadcasting over heads.
        valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool)
        attention_mask = torch.cat([
            torch.ones((B, N_img), device=device, dtype=torch.bool),
            valid_text
        ], dim=1)[:, None, None, :]

        # Timestep embedding -> shared AdaLN modulation, expanded over S.
        sample = self.time_proj(timestep.to(dtype))
        sample = sample.to(self.time_embedding.linear_1.weight.dtype)
        c = self.time_embedding(sample)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
            t.unsqueeze(0).expand(S, -1, -1).contiguous()
            for t in self.adaLN_modulation(c).chunk(6, dim=-1)
        ]

        for layer in self.layers:
            temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
            if torch.is_grad_enabled() and use_gradient_checkpointing:
                x = gradient_checkpoint_forward(
                    layer,
                    use_gradient_checkpointing,
                    use_gradient_checkpointing_offload,
                    x,
                    rotary_pos_emb,
                    temb,
                    attention_mask,
                )
            else:
                x = layer(x, rotary_pos_emb, temb, attention_mask)

        x = self.final_norm(x, c).type_as(x)
        # Keep only image tokens, project to patch pixels, and unpatchify.
        patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous()
        output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W)

        return output
@@ -0,0 +1,76 @@
1
+ """
2
+ Ernie-Image TextEncoder for DiffSynth-Studio.
3
+
4
+ Wraps transformers Ministral3Model to output text embeddings.
5
+ Pattern: lazy import + manual config dict + torch.nn.Module wrapper.
6
+ Only loads the text (language) model, ignoring vision components.
7
+ """
8
+
9
+ import torch
10
+
11
+
12
class ErnieImageTextEncoder(torch.nn.Module):
    """
    Text encoder using Ministral3Model (transformers).
    Only the text_config portion of the full Mistral3Model checkpoint.
    Uses the base model (no lm_head) since the checkpoint only has embeddings.
    """

    def __init__(self):
        super().__init__()
        # Lazy import keeps `transformers` an optional dependency until this
        # encoder is actually instantiated.
        from transformers import Ministral3Config, Ministral3Model

        # Hard-coded text_config of the released checkpoint (vision components
        # of the full multimodal config are intentionally omitted).
        text_config = {
            "attention_dropout": 0.0,
            "bos_token_id": 1,
            "dtype": "bfloat16",
            "eos_token_id": 2,
            "head_dim": 128,
            "hidden_act": "silu",
            "hidden_size": 3072,
            "initializer_range": 0.02,
            "intermediate_size": 9216,
            "max_position_embeddings": 262144,
            "model_type": "ministral3",
            "num_attention_heads": 32,
            "num_hidden_layers": 26,
            "num_key_value_heads": 8,
            "pad_token_id": 11,
            "rms_norm_eps": 1e-05,
            # YaRN long-context rope scaling parameters from the checkpoint.
            "rope_parameters": {
                "beta_fast": 32.0,
                "beta_slow": 1.0,
                "factor": 16.0,
                "llama_4_scaling_beta": 0.1,
                "mscale": 1.0,
                "mscale_all_dim": 1.0,
                "original_max_position_embeddings": 16384,
                "rope_theta": 1000000.0,
                "rope_type": "yarn",
                "type": "yarn",
            },
            "sliding_window": None,
            "tie_word_embeddings": True,
            "use_cache": True,
            "vocab_size": 131072,
        }
        config = Ministral3Config(**text_config)
        self.model = Ministral3Model(config)
        self.config = config

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        **kwargs,
    ):
        """
        Run the language model and return its hidden states.

        Returns:
            A 1-tuple whose single element is the tuple of all hidden states
            (input embeddings plus one entry per transformer layer).
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )
        return (outputs.hidden_states,)