diffsynth 2.0.7__tar.gz → 2.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffsynth-2.0.7 → diffsynth-2.0.8}/PKG-INFO +1 -1
- {diffsynth-2.0.7 → diffsynth-2.0.8}/README.md +69 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/configs/model_configs.py +17 -1
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/configs/vram_management_module_maps.py +12 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/flow_match.py +15 -2
- diffsynth-2.0.8/diffsynth/models/ernie_image_dit.py +362 -0
- diffsynth-2.0.8/diffsynth/models/ernie_image_text_encoder.py +76 -0
- diffsynth-2.0.8/diffsynth/pipelines/ernie_image.py +266 -0
- diffsynth-2.0.8/diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py +21 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth.egg-info/PKG-INFO +1 -1
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth.egg-info/SOURCES.txt +4 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/pyproject.toml +1 -1
- {diffsynth-2.0.7 → diffsynth-2.0.8}/LICENSE +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/configs/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/attention/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/attention/attention.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/data/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/data/operators.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/data/unified_dataset.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/device/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/device/npu_compatible_device.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/gradient/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/gradient/gradient_checkpoint.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/loader/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/loader/config.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/loader/file.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/loader/model.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/npu_patch/npu_fused_operator.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/vram/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/vram/disk_map.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/vram/initialization.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/core/vram/layers.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/base_pipeline.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/logger.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/loss.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/parsers.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/runner.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/diffusion/training_module.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/anima_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/dinov3_image_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux2_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux2_text_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux2_vae.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_controlnet.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_infiniteyou.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_ipadapter.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_lora_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_lora_patcher.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_text_encoder_clip.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_text_encoder_t5.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_vae.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/flux_value_control.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/general_modules.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/longcat_video_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/ltx2_audio_vae.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/ltx2_common.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/ltx2_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/ltx2_text_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/ltx2_upsampler.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/ltx2_video_vae.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/model_loader.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/mova_audio_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/mova_audio_vae.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/mova_dual_tower_bridge.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/nexus_gen.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/nexus_gen_ar_model.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/nexus_gen_projector.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/qwen_image_controlnet.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/qwen_image_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/qwen_image_image2lora.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/qwen_image_text_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/qwen_image_vae.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/sd_text_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/siglip2_image_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/step1x_connector.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/step1x_text_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_animate_adapter.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_camera_controller.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_dit_s2v.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_image_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_mot.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_motion_controller.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_text_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_vace.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wan_video_vae.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wantodance.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/wav2vec.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/z_image_controlnet.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/z_image_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/z_image_image2lora.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/models/z_image_text_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/anima_image.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/flux2_image.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/flux_image.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/ltx2_audio_video.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/mova_audio_video.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/qwen_image.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/wan_video.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/pipelines/z_image.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/controlnet/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/controlnet/annotator.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/controlnet/controlnet_input.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/data/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/data/audio.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/data/audio_video.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/data/media_io_ltx2.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/lora/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/lora/flux.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/lora/general.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/lora/merge.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/lora/reset_rank.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/ses/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/ses/ses.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/anima_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux2_text_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_controlnet.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_infiniteyou.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_ipadapter.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/flux_vae.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/ltx2_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/ltx2_video_vae.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/nexus_gen.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/nexus_gen_projector.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/step1x_connector.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wan_video_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wan_video_mot.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wan_video_vace.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wan_video_vae.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/z_image_dit.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/state_dict_converters/z_image_text_encoder.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/xfuser/__init__.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/utils/xfuser/xdit_context_parallel.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth/version.py +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth.egg-info/dependency_links.txt +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth.egg-info/requires.txt +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/diffsynth.egg-info/top_level.txt +0 -0
- {diffsynth-2.0.7 → diffsynth-2.0.8}/setup.cfg +0 -0
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
[](https://github.com/modelscope/DiffSynth-Studio/issues)
|
|
8
8
|
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
|
|
9
9
|
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
|
10
|
+
[](https://discord.gg/Mm9suEeUDc)
|
|
10
11
|
|
|
11
12
|
[切换到中文版](./README_zh.md)
|
|
12
13
|
|
|
@@ -32,6 +33,7 @@ We believe that a well-developed open-source code framework can lower the thresh
|
|
|
32
33
|
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
|
|
33
34
|
|
|
34
35
|
> Currently, this project has a limited number of developers, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). As a result, new feature development progresses relatively slowly, and our capacity to respond to and resolve issues is limited. We apologize for this and appreciate developers' understanding.
|
|
36
|
+
|
|
35
37
|
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
|
|
36
38
|
|
|
37
39
|
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features include text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
|
|
@@ -875,6 +877,67 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
|
|
|
875
877
|
|
|
876
878
|
</details>
|
|
877
879
|
|
|
880
|
+
#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
|
|
881
|
+
|
|
882
|
+
<details>
|
|
883
|
+
|
|
884
|
+
<summary>Quick Start</summary>
|
|
885
|
+
|
|
886
|
+
Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
|
|
887
|
+
|
|
888
|
+
```python
|
|
889
|
+
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
|
890
|
+
import torch
|
|
891
|
+
|
|
892
|
+
vram_config = {
|
|
893
|
+
"offload_dtype": torch.bfloat16,
|
|
894
|
+
"offload_device": "cpu",
|
|
895
|
+
"onload_dtype": torch.bfloat16,
|
|
896
|
+
"onload_device": "cpu",
|
|
897
|
+
"preparing_dtype": torch.bfloat16,
|
|
898
|
+
"preparing_device": "cuda",
|
|
899
|
+
"computation_dtype": torch.bfloat16,
|
|
900
|
+
"computation_device": "cuda",
|
|
901
|
+
}
|
|
902
|
+
pipe = ErnieImagePipeline.from_pretrained(
|
|
903
|
+
torch_dtype=torch.bfloat16,
|
|
904
|
+
device='cuda',
|
|
905
|
+
model_configs=[
|
|
906
|
+
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
|
907
|
+
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
|
908
|
+
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
909
|
+
],
|
|
910
|
+
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
|
911
|
+
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
912
|
+
)
|
|
913
|
+
|
|
914
|
+
image = pipe(
|
|
915
|
+
prompt="一只黑白相间的中华田园犬",
|
|
916
|
+
negative_prompt="",
|
|
917
|
+
height=1024,
|
|
918
|
+
width=1024,
|
|
919
|
+
seed=42,
|
|
920
|
+
num_inference_steps=50,
|
|
921
|
+
cfg_scale=4.0,
|
|
922
|
+
)
|
|
923
|
+
image.save("output.jpg")
|
|
924
|
+
```
|
|
925
|
+
|
|
926
|
+
</details>
|
|
927
|
+
|
|
928
|
+
<details>
|
|
929
|
+
|
|
930
|
+
<summary>Examples</summary>
|
|
931
|
+
|
|
932
|
+
Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
|
|
933
|
+
|
|
934
|
+
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
|
935
|
+
|-|-|-|-|-|-|-|
|
|
936
|
+
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|
|
937
|
+
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
|
|
938
|
+
|
|
939
|
+
</details>
|
|
940
|
+
|
|
878
941
|
## Innovative Achievements
|
|
879
942
|
|
|
880
943
|
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
|
|
@@ -1029,3 +1092,9 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
|
|
|
1029
1092
|
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
|
1030
1093
|
|
|
1031
1094
|
</details>
|
|
1095
|
+
|
|
1096
|
+
## Contact Us
|
|
1097
|
+
|
|
1098
|
+
|Discord: https://discord.gg/Mm9suEeUDc|
|
|
1099
|
+
|-|
|
|
1100
|
+
|<img width="160" height="160" alt="Image" src="https://github.com/user-attachments/assets/29bdc97b-e35d-4fea-88d6-32e35182e458" />|
|
|
@@ -541,6 +541,22 @@ flux2_series = [
|
|
|
541
541
|
},
|
|
542
542
|
]
|
|
543
543
|
|
|
544
|
+
ernie_image_series = [
|
|
545
|
+
{
|
|
546
|
+
# Example: ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
|
547
|
+
"model_hash": "584c13713849f1af4e03d5f1858b8b7b",
|
|
548
|
+
"model_name": "ernie_image_dit",
|
|
549
|
+
"model_class": "diffsynth.models.ernie_image_dit.ErnieImageDiT",
|
|
550
|
+
},
|
|
551
|
+
{
|
|
552
|
+
# Example: ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors")
|
|
553
|
+
"model_hash": "404ed9f40796a38dd34c1620f1920207",
|
|
554
|
+
"model_name": "ernie_image_text_encoder",
|
|
555
|
+
"model_class": "diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder",
|
|
556
|
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.ernie_image_text_encoder.ErnieImageTextEncoderStateDictConverter",
|
|
557
|
+
},
|
|
558
|
+
]
|
|
559
|
+
|
|
544
560
|
z_image_series = [
|
|
545
561
|
{
|
|
546
562
|
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
|
|
@@ -884,4 +900,4 @@ mova_series = [
|
|
|
884
900
|
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
|
|
885
901
|
},
|
|
886
902
|
]
|
|
887
|
-
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series
|
|
903
|
+
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series
|
|
@@ -267,6 +267,18 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
|
|
267
267
|
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
268
268
|
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
269
269
|
},
|
|
270
|
+
"diffsynth.models.ernie_image_dit.ErnieImageDiT": {
|
|
271
|
+
"diffsynth.models.ernie_image_dit.ErnieImageRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
272
|
+
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
|
273
|
+
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
274
|
+
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
275
|
+
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
276
|
+
},
|
|
277
|
+
"diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder": {
|
|
278
|
+
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
|
279
|
+
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
280
|
+
"transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
281
|
+
},
|
|
270
282
|
}
|
|
271
283
|
|
|
272
284
|
def QwenImageTextEncoder_Module_Map_Updater():
|
|
@@ -4,7 +4,7 @@ from typing_extensions import Literal
|
|
|
4
4
|
|
|
5
5
|
class FlowMatchScheduler():
|
|
6
6
|
|
|
7
|
-
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
|
|
7
|
+
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image"] = "FLUX.1"):
|
|
8
8
|
self.set_timesteps_fn = {
|
|
9
9
|
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
|
10
10
|
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
|
@@ -13,6 +13,7 @@ class FlowMatchScheduler():
|
|
|
13
13
|
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
|
14
14
|
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
|
|
15
15
|
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
|
|
16
|
+
"ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image,
|
|
16
17
|
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
|
17
18
|
self.num_train_timesteps = 1000
|
|
18
19
|
|
|
@@ -129,6 +130,18 @@ class FlowMatchScheduler():
|
|
|
129
130
|
timesteps = sigmas * num_train_timesteps
|
|
130
131
|
return sigmas, timesteps
|
|
131
132
|
|
|
133
|
+
@staticmethod
|
|
134
|
+
def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0, shift=3.0):
|
|
135
|
+
sigma_min = 0.0
|
|
136
|
+
sigma_max = 1.0
|
|
137
|
+
num_train_timesteps = 1000
|
|
138
|
+
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
|
139
|
+
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
|
140
|
+
if shift is not None and shift != 1.0:
|
|
141
|
+
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
|
142
|
+
timesteps = sigmas * num_train_timesteps
|
|
143
|
+
return sigmas, timesteps
|
|
144
|
+
|
|
132
145
|
@staticmethod
|
|
133
146
|
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
|
|
134
147
|
sigma_min = 0.0
|
|
@@ -185,7 +198,7 @@ class FlowMatchScheduler():
|
|
|
185
198
|
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
|
|
186
199
|
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
|
|
187
200
|
self.linear_timesteps_weights = bsmntw_weighing
|
|
188
|
-
|
|
201
|
+
|
|
189
202
|
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
|
|
190
203
|
self.sigmas, self.timesteps = self.set_timesteps_fn(
|
|
191
204
|
num_inference_steps=num_inference_steps,
|
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Ernie-Image DiT for DiffSynth-Studio.
|
|
3
|
+
|
|
4
|
+
Refactored from diffusers ErnieImageTransformer2DModel to use DiffSynth core modules.
|
|
5
|
+
Default parameters from actual checkpoint config.json (PaddlePaddle/ERNIE-Image transformer).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
from typing import Optional, Tuple
|
|
12
|
+
|
|
13
|
+
from ..core.attention import attention_forward
|
|
14
|
+
from ..core.gradient import gradient_checkpoint_forward
|
|
15
|
+
from .flux2_dit import Timesteps, TimestepEmbedding
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
|
19
|
+
assert dim % 2 == 0
|
|
20
|
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
|
21
|
+
omega = 1.0 / (theta ** scale)
|
|
22
|
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
|
23
|
+
return out.float()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ErnieImageEmbedND3(nn.Module):
|
|
27
|
+
def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]):
|
|
28
|
+
super().__init__()
|
|
29
|
+
self.dim = dim
|
|
30
|
+
self.theta = theta
|
|
31
|
+
self.axes_dim = list(axes_dim)
|
|
32
|
+
|
|
33
|
+
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
|
34
|
+
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
|
|
35
|
+
emb = emb.unsqueeze(2)
|
|
36
|
+
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ErnieImagePatchEmbedDynamic(nn.Module):
|
|
40
|
+
def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
|
|
41
|
+
super().__init__()
|
|
42
|
+
self.patch_size = patch_size
|
|
43
|
+
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
|
|
44
|
+
|
|
45
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
46
|
+
x = self.proj(x)
|
|
47
|
+
batch_size, dim, height, width = x.shape
|
|
48
|
+
return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ErnieImageSingleStreamAttnProcessor:
|
|
52
|
+
def __call__(
|
|
53
|
+
self,
|
|
54
|
+
attn: "ErnieImageAttention",
|
|
55
|
+
hidden_states: torch.Tensor,
|
|
56
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
57
|
+
freqs_cis: Optional[torch.Tensor] = None,
|
|
58
|
+
) -> torch.Tensor:
|
|
59
|
+
query = attn.to_q(hidden_states)
|
|
60
|
+
key = attn.to_k(hidden_states)
|
|
61
|
+
value = attn.to_v(hidden_states)
|
|
62
|
+
|
|
63
|
+
query = query.unflatten(-1, (attn.heads, -1))
|
|
64
|
+
key = key.unflatten(-1, (attn.heads, -1))
|
|
65
|
+
value = value.unflatten(-1, (attn.heads, -1))
|
|
66
|
+
|
|
67
|
+
if attn.norm_q is not None:
|
|
68
|
+
query = attn.norm_q(query)
|
|
69
|
+
if attn.norm_k is not None:
|
|
70
|
+
key = attn.norm_k(key)
|
|
71
|
+
|
|
72
|
+
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
|
73
|
+
rot_dim = freqs_cis.shape[-1]
|
|
74
|
+
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
|
|
75
|
+
cos_ = torch.cos(freqs_cis).to(x.dtype)
|
|
76
|
+
sin_ = torch.sin(freqs_cis).to(x.dtype)
|
|
77
|
+
x1, x2 = x.chunk(2, dim=-1)
|
|
78
|
+
x_rotated = torch.cat((-x2, x1), dim=-1)
|
|
79
|
+
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
|
|
80
|
+
|
|
81
|
+
if freqs_cis is not None:
|
|
82
|
+
query = apply_rotary_emb(query, freqs_cis)
|
|
83
|
+
key = apply_rotary_emb(key, freqs_cis)
|
|
84
|
+
|
|
85
|
+
if attention_mask is not None and attention_mask.ndim == 2:
|
|
86
|
+
attention_mask = attention_mask[:, None, None, :]
|
|
87
|
+
|
|
88
|
+
hidden_states = attention_forward(
|
|
89
|
+
query, key, value,
|
|
90
|
+
q_pattern="b s n d",
|
|
91
|
+
k_pattern="b s n d",
|
|
92
|
+
v_pattern="b s n d",
|
|
93
|
+
out_pattern="b s n d",
|
|
94
|
+
attn_mask=attention_mask,
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
hidden_states = hidden_states.flatten(2, 3)
|
|
98
|
+
hidden_states = hidden_states.to(query.dtype)
|
|
99
|
+
output = attn.to_out[0](hidden_states)
|
|
100
|
+
|
|
101
|
+
return output
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class ErnieImageAttention(nn.Module):
|
|
105
|
+
def __init__(
|
|
106
|
+
self,
|
|
107
|
+
query_dim: int,
|
|
108
|
+
heads: int = 8,
|
|
109
|
+
dim_head: int = 64,
|
|
110
|
+
dropout: float = 0.0,
|
|
111
|
+
bias: bool = False,
|
|
112
|
+
qk_norm: str = "rms_norm",
|
|
113
|
+
out_bias: bool = True,
|
|
114
|
+
eps: float = 1e-5,
|
|
115
|
+
out_dim: int = None,
|
|
116
|
+
elementwise_affine: bool = True,
|
|
117
|
+
):
|
|
118
|
+
super().__init__()
|
|
119
|
+
|
|
120
|
+
self.head_dim = dim_head
|
|
121
|
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
|
122
|
+
self.query_dim = query_dim
|
|
123
|
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
|
124
|
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
|
125
|
+
|
|
126
|
+
self.use_bias = bias
|
|
127
|
+
self.dropout = dropout
|
|
128
|
+
|
|
129
|
+
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
|
130
|
+
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
|
131
|
+
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
|
132
|
+
|
|
133
|
+
if qk_norm == "layer_norm":
|
|
134
|
+
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
|
135
|
+
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
|
136
|
+
elif qk_norm == "rms_norm":
|
|
137
|
+
self.norm_q = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
|
138
|
+
self.norm_k = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
|
139
|
+
else:
|
|
140
|
+
raise ValueError(
|
|
141
|
+
f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'."
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
self.to_out = nn.ModuleList([])
|
|
145
|
+
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
|
146
|
+
|
|
147
|
+
self.processor = ErnieImageSingleStreamAttnProcessor()
|
|
148
|
+
|
|
149
|
+
def forward(
|
|
150
|
+
self,
|
|
151
|
+
hidden_states: torch.Tensor,
|
|
152
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
153
|
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
|
154
|
+
) -> torch.Tensor:
|
|
155
|
+
return self.processor(self, hidden_states, attention_mask, image_rotary_emb)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class ErnieImageFeedForward(nn.Module):
|
|
159
|
+
def __init__(self, hidden_size: int, ffn_hidden_size: int):
|
|
160
|
+
super().__init__()
|
|
161
|
+
self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
|
|
162
|
+
self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
|
|
163
|
+
self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
|
|
164
|
+
|
|
165
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
166
|
+
return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class ErnieImageRMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel scale.

    The variance is accumulated in float32 for numerical stability and the
    result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        # Mean of squares over the last dimension, computed in fp32.
        mean_sq = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        normed = hidden_states * inv_rms
        return (normed * self.weight).to(orig_dtype)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class ErnieImageSharedAdaLNBlock(nn.Module):
    """Transformer block with shared AdaLN modulation: self-attention + gated MLP.

    Operates on sequence-first tensors [S, B, H] (inferred from the permutes
    in ``forward`` — confirm against the caller); the attention sub-module is
    called batch-first, hence the permutes around it.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        ffn_hidden_size: int,
        eps: float = 1e-6,
        qk_layernorm: bool = True,
    ):
        super().__init__()
        # Pre-attention RMSNorm; its output is modulated by shift/scale in forward.
        self.adaLN_sa_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
        self.self_attention = ErnieImageAttention(
            query_dim=hidden_size,
            dim_head=hidden_size // num_heads,
            heads=num_heads,
            # Optional RMS q/k normalization inside attention.
            qk_norm="rms_norm" if qk_layernorm else None,
            eps=eps,
            bias=False,
            out_bias=False,
        )
        # Pre-MLP RMSNorm, modulated the same way.
        self.adaLN_mlp_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
        self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)

    def forward(
        self,
        x: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        temb: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply modulated self-attention then modulated MLP, each with a residual.

        Args:
            x: Hidden states, sequence-first [S, B, H] — assumed; TODO confirm.
            rotary_pos_emb: Rotary position embedding forwarded to attention.
            temb: Six modulation tensors (shift/scale/gate for attention,
                then shift/scale/gate for the MLP), already broadcast to x.
            attention_mask: Optional mask forwarded to attention.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
        residual = x
        x = self.adaLN_sa_ln(x)
        # AdaLN modulation is computed in fp32 then cast back to the input dtype.
        x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
        # [S, B, H] -> [B, S, H] for the attention module, and back afterwards.
        x_bsh = x.permute(1, 0, 2)
        attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
        attn_out = attn_out.permute(1, 0, 2)
        # Gated residual connection (gate applied in fp32).
        x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
        residual = x
        x = self.adaLN_mlp_ln(x)
        x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
        return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class ErnieImageAdaLNContinuous(nn.Module):
    """Final adaptive LayerNorm: LayerNorm without affine parameters, modulated
    by a per-batch (scale, shift) pair projected from a conditioning vector."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps)
        self.linear = nn.Linear(hidden_size, hidden_size * 2)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        # Project conditioning to per-feature modulation, split into scale/shift.
        modulation = self.linear(conditioning)
        scale, shift = torch.chunk(modulation, 2, dim=-1)
        normed = self.norm(x)
        # unsqueeze(0) broadcasts the [B, H] modulation over the leading axis of x.
        return normed * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class ErnieImageDiT(nn.Module):
    """
    Ernie-Image DiT model for DiffSynth-Studio.

    Architecture: SharedAdaLN + RoPE 3D + Joint Image-Text Attention.
    Internal format: [S, B, H] for transformer blocks, [B, S, H] for attention.
    """

    def __init__(
        self,
        hidden_size: int = 4096,
        num_attention_heads: int = 32,
        num_layers: int = 36,
        ffn_hidden_size: int = 12288,
        in_channels: int = 128,
        out_channels: int = 128,
        patch_size: int = 1,
        text_in_dim: int = 3072,
        rope_theta: int = 256,
        rope_axes_dim: Tuple[int, int, int] = (32, 48, 48),
        eps: float = 1e-6,
        qk_layernorm: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.num_layers = num_layers
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.text_in_dim = text_in_dim

        # Latent patch embedder (latent image -> token sequence).
        self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size)
        # Only project text embeddings when their width differs from hidden_size.
        self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None
        self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(hidden_size, hidden_size)
        # 3-axis rotary embedding (axes_dim sums to head_dim — TODO confirm).
        self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
        # Shared AdaLN produces the 6 modulation tensors used by every layer.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
        # Zero-init so modulation starts as identity (shift=scale=gate=0).
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)
        self.layers = nn.ModuleList([
            ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm)
            for _ in range(num_layers)
        ])
        self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps)
        self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
        # Zero-init output head so the untrained model predicts zeros.
        nn.init.zeros_(self.final_linear.weight)
        nn.init.zeros_(self.final_linear.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        text_bth: torch.Tensor,
        text_lens: torch.Tensor,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ) -> torch.Tensor:
        """Denoise one step.

        Args:
            hidden_states: Latent image, [B, C, H, W].
            timestep: Diffusion timestep(s), one per batch element.
            text_bth: Text embeddings, batch-first [B, T, text_in_dim] —
                inferred from the [1] index and transpose below.
            text_lens: Per-sample valid text lengths, [B].
            use_gradient_checkpointing: Checkpoint each layer during training.
            use_gradient_checkpointing_offload: Forwarded to the checkpoint helper.

        Returns:
            Predicted latent, [B, out_channels, H, W].
        """
        device, dtype = hidden_states.device, hidden_states.dtype
        B, C, H, W = hidden_states.shape
        p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
        N_img = Hp * Wp

        # Patchify and go sequence-first: [S_img, B, H].
        img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous()

        if self.text_proj is not None and text_bth.numel() > 0:
            text_bth = self.text_proj(text_bth)
        Tmax = text_bth.shape[1]
        text_sbh = text_bth.transpose(0, 1).contiguous()

        # Joint sequence: image tokens first, then text tokens.
        x = torch.cat([img_sbh, text_sbh], dim=0)
        S = x.shape[0]

        # 3D RoPE ids: text tokens use (position, 0, 0); image tokens use
        # (text_lens, y, x) — the first axis offsets images past the text.
        text_ids = torch.cat([
            torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1),
            torch.zeros((B, Tmax, 2), device=device)
        ], dim=-1) if Tmax > 0 else torch.zeros((B, 0, 3), device=device)
        grid_yx = torch.stack(
            torch.meshgrid(torch.arange(Hp, device=device, dtype=torch.float32),
                           torch.arange(Wp, device=device, dtype=torch.float32), indexing="ij"),
            dim=-1
        ).reshape(-1, 2)
        image_ids = torch.cat([
            text_lens.float().view(B, 1, 1).expand(-1, N_img, -1),
            grid_yx.view(1, N_img, 2).expand(B, -1, -1)
        ], dim=-1)
        # Image ids first, matching the token order in x.
        rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))

        # Mask: all image tokens valid; text tokens valid up to text_lens.
        valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool)
        attention_mask = torch.cat([
            torch.ones((B, N_img), device=device, dtype=torch.bool),
            valid_text
        ], dim=1)[:, None, None, :]

        # Timestep conditioning -> shared AdaLN modulation, broadcast over S.
        sample = self.time_proj(timestep.to(dtype))
        sample = sample.to(self.time_embedding.linear_1.weight.dtype)
        c = self.time_embedding(sample)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
            t.unsqueeze(0).expand(S, -1, -1).contiguous()
            for t in self.adaLN_modulation(c).chunk(6, dim=-1)
        ]

        for layer in self.layers:
            temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
            if torch.is_grad_enabled() and use_gradient_checkpointing:
                x = gradient_checkpoint_forward(
                    layer,
                    use_gradient_checkpointing,
                    use_gradient_checkpointing_offload,
                    x,
                    rotary_pos_emb,
                    temb,
                    attention_mask,
                )
            else:
                x = layer(x, rotary_pos_emb, temb, attention_mask)

        x = self.final_norm(x, c).type_as(x)
        # Keep only the image tokens (first N_img of the sequence) and unpatchify.
        patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous()
        output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W)

        return output
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Ernie-Image TextEncoder for DiffSynth-Studio.
|
|
3
|
+
|
|
4
|
+
Wraps transformers Ministral3Model to output text embeddings.
|
|
5
|
+
Pattern: lazy import + manual config dict + torch.nn.Module wrapper.
|
|
6
|
+
Only loads the text (language) model, ignoring vision components.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ErnieImageTextEncoder(torch.nn.Module):
    """
    Text encoder using Ministral3Model (transformers).
    Only the text_config portion of the full Mistral3Model checkpoint.
    Uses the base model (no lm_head) since the checkpoint only has embeddings.
    """

    def __init__(self):
        super().__init__()
        # Lazy import keeps transformers optional until this encoder is built.
        from transformers import Ministral3Config, Ministral3Model

        # Hard-coded text config mirroring the released checkpoint's
        # text_config (values are fixed here rather than read from disk).
        text_config = {
            "attention_dropout": 0.0,
            "bos_token_id": 1,
            "dtype": "bfloat16",
            "eos_token_id": 2,
            "head_dim": 128,
            "hidden_act": "silu",
            "hidden_size": 3072,
            "initializer_range": 0.02,
            "intermediate_size": 9216,
            "max_position_embeddings": 262144,
            "model_type": "ministral3",
            "num_attention_heads": 32,
            "num_hidden_layers": 26,
            "num_key_value_heads": 8,
            "pad_token_id": 11,
            "rms_norm_eps": 1e-05,
            "rope_parameters": {
                "beta_fast": 32.0,
                "beta_slow": 1.0,
                "factor": 16.0,
                "llama_4_scaling_beta": 0.1,
                "mscale": 1.0,
                "mscale_all_dim": 1.0,
                "original_max_position_embeddings": 16384,
                "rope_theta": 1000000.0,
                "rope_type": "yarn",
                "type": "yarn",
            },
            "sliding_window": None,
            "tie_word_embeddings": True,
            "use_cache": True,
            "vocab_size": 131072,
        }
        config = Ministral3Config(**text_config)
        self.model = Ministral3Model(config)
        self.config = config

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        **kwargs,
    ):
        """Encode text and return all hidden states.

        Returns:
            A 1-tuple whose single element is the model's ``hidden_states``
            (per-layer hidden states, enabled via ``output_hidden_states=True``).
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )
        # Wrap in a tuple so downstream code can index the hidden states.
        return (outputs.hidden_states,)
|