hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +114 -125
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +78 -46
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,285 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: hcpdiff
|
3
|
+
Version: 2.1
|
4
|
+
Summary: A universal Diffusion toolbox
|
5
|
+
Home-page: https://github.com/IrisRainbowNeko/HCP-Diffusion
|
6
|
+
Author: Ziyi Dong
|
7
|
+
Author-email: rainbow-neko@outlook.com
|
8
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
9
|
+
Classifier: Operating System :: OS Independent
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
11
|
+
Classifier: Programming Language :: Python :: 3.8
|
12
|
+
Classifier: Programming Language :: Python :: 3.9
|
13
|
+
Classifier: Programming Language :: Python :: 3.10
|
14
|
+
Classifier: Programming Language :: Python :: 3.11
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
17
|
+
Requires-Python: >=3.8
|
18
|
+
Description-Content-Type: text/markdown
|
19
|
+
License-File: LICENSE
|
20
|
+
Requires-Dist: rainbowneko
|
21
|
+
Requires-Dist: diffusers
|
22
|
+
Requires-Dist: matplotlib
|
23
|
+
Requires-Dist: pyarrow
|
24
|
+
Requires-Dist: transformers>=4.25.1
|
25
|
+
Requires-Dist: pytorch-msssim
|
26
|
+
Requires-Dist: lmdb
|
27
|
+
Dynamic: author
|
28
|
+
Dynamic: author-email
|
29
|
+
Dynamic: classifier
|
30
|
+
Dynamic: description
|
31
|
+
Dynamic: description-content-type
|
32
|
+
Dynamic: home-page
|
33
|
+
Dynamic: license-file
|
34
|
+
Dynamic: requires-dist
|
35
|
+
Dynamic: requires-python
|
36
|
+
Dynamic: summary
|
37
|
+
|
38
|
+
# HCP-Diffusion V2
|
39
|
+
|
40
|
+
[PyPI](https://pypi.org/project/hcpdiff/)
|
41
|
+
[GitHub stars](https://github.com/7eu7d7/HCP-Diffusion/stargazers)
|
42
|
+
[License](https://github.com/7eu7d7/HCP-Diffusion/blob/master/LICENSE)
|
43
|
+
[Codecov](https://codecov.io/gh/7eu7d7/HCP-Diffusion)
|
44
|
+
[GitHub issues](https://github.com/7eu7d7/HCP-Diffusion/issues)
|
45
|
+
|
46
|
+
[📘中文说明](./README_cn.md)
|
47
|
+
|
48
|
+
[📘English document](https://hcpdiff.readthedocs.io/en/latest/)
|
49
|
+
[📘中文文档](https://hcpdiff.readthedocs.io/zh_CN/latest/)
|
50
|
+
|
51
|
+
The old HCP-Diffusion V1 is available on the [main branch](https://github.com/IrisRainbowNeko/HCP-Diffusion/tree/main).
|
52
|
+
|
53
|
+
## Introduction
|
54
|
+
|
55
|
+
**HCP-Diffusion** is a Diffusion model toolbox built on top of the [🐱 RainbowNeko Engine](https://github.com/IrisRainbowNeko/RainbowNekoEngine).
|
56
|
+
It features a clean code structure and a flexible **Python-based configuration file**, making it easier to conduct and manage complex experiments. It includes a wide variety of training components, and compared to existing frameworks, it's more extensible, flexible, and user-friendly.
|
57
|
+
|
58
|
+
HCP-Diffusion allows you to use a single `.py` config file to unify training workflows across popular methods and model architectures, including Prompt-tuning (Textual Inversion), DreamArtist, Fine-tuning, DreamBooth, LoRA, ControlNet, and more.
|
59
|
+
Different techniques can also be freely combined.
|
60
|
+
|
61
|
+
This framework also implements **DreamArtist++**, an upgraded version of DreamArtist based on LoRA. It achieves strong generalization and controllability using just a single training image.
|
62
|
+
Compared to the original DreamArtist, it offers better stability, image quality, controllability, and faster training.
|
63
|
+
|
64
|
+
---
|
65
|
+
|
66
|
+
## Installation
|
67
|
+
|
68
|
+
Install via pip:
|
69
|
+
|
70
|
+
```bash
|
71
|
+
pip install hcpdiff
|
72
|
+
# Initialize configuration
|
73
|
+
hcpinit
|
74
|
+
```
|
75
|
+
|
76
|
+
Install from source:
|
77
|
+
|
78
|
+
```bash
|
79
|
+
git clone https://github.com/7eu7d7/HCP-Diffusion.git
|
80
|
+
cd HCP-Diffusion
|
81
|
+
pip install -e .
|
82
|
+
# Initialize configuration
|
83
|
+
hcpinit
|
84
|
+
```
|
85
|
+
|
86
|
+
Use xFormers to reduce memory usage and accelerate training:
|
87
|
+
|
88
|
+
```bash
|
89
|
+
# Choose the appropriate xformers version for your PyTorch version
|
90
|
+
pip install xformers==?
|
91
|
+
```
|
92
|
+
|
93
|
+
## 🚀 Python Configuration Files
|
94
|
+
RainbowNeko Engine supports configuration files written in a Python-like syntax. This allows users to call functions and classes directly within the configuration file, with function parameters inheritable from parent configuration files. The framework automatically handles the formatting of these configuration files.
|
95
|
+
|
96
|
+
For example, consider the following configuration file:
|
97
|
+
```python
|
98
|
+
dict(
|
99
|
+
layer=Linear(in_features=4, out_features=4)
|
100
|
+
)
|
101
|
+
```
|
102
|
+
During parsing, this will be automatically compiled into:
|
103
|
+
```python
|
104
|
+
dict(
|
105
|
+
layer=dict(_target_=Linear, in_features=4, out_features=4)
|
106
|
+
)
|
107
|
+
```
|
108
|
+
After parsing, the framework will instantiate the components accordingly. This means users can write configuration files using familiar Python syntax.
|
109
|
+
|
110
|
+
---
|
111
|
+
|
112
|
+
## ✨ Features
|
113
|
+
|
114
|
+
<details>
|
115
|
+
<summary>Features</summary>
|
116
|
+
|
117
|
+
### 📦 Model Support
|
118
|
+
|
119
|
+
| Model Name | Status |
|
120
|
+
|--------------------------|-------------|
|
121
|
+
| Stable Diffusion 1.5 | ✅ Supported |
|
122
|
+
| Stable Diffusion XL (SDXL)| ✅ Supported |
|
123
|
+
| PixArt | ✅ Supported |
|
124
|
+
| FLUX | 🚧 In Development |
|
125
|
+
| Stable Diffusion 3 (SD3) | 🚧 In Development |
|
126
|
+
|
127
|
+
---
|
128
|
+
|
129
|
+
### 🧠 Fine-Tuning Capabilities
|
130
|
+
|
131
|
+
| Feature | Description/Support |
|
132
|
+
|----------------------------------|---------------------|
|
133
|
+
| LoRA Layer-wise Configuration | ✅ Supported (including Conv2d) |
|
134
|
+
| Layer-wise Fine-Tuning | ✅ Supported |
|
135
|
+
| Multi-token Prompt-Tuning | ✅ Supported |
|
136
|
+
| Layer-wise Model Merging | ✅ Supported |
|
137
|
+
| Custom Optimizers | ✅ Supported (Lion, DAdaptation, pytorch-optimizer, etc.) |
|
138
|
+
| Custom LR Schedulers | ✅ Supported |
|
139
|
+
|
140
|
+
---
|
141
|
+
|
142
|
+
### 🧩 Extension Method Support
|
143
|
+
|
144
|
+
| Method | Status |
|
145
|
+
|--------------------------------|-------------|
|
146
|
+
| ControlNet (including training)| ✅ Supported |
|
147
|
+
| DreamArtist / DreamArtist++ | ✅ Supported |
|
148
|
+
| Token Attention Adjustment | ✅ Supported |
|
149
|
+
| Max Sentence Length Extension | ✅ Supported |
|
150
|
+
| Textual Inversion (Custom Tokens)| ✅ Supported |
|
151
|
+
| CLIP Skip | ✅ Supported |
|
152
|
+
|
153
|
+
---
|
154
|
+
|
155
|
+
### 🚀 Training Acceleration
|
156
|
+
|
157
|
+
| Tool/Library | Supported Modules |
|
158
|
+
|---------------------------------------------------|---------------------------|
|
159
|
+
| [🤗 Accelerate](https://github.com/huggingface/accelerate) | ✅ Supported |
|
160
|
+
| [Colossal-AI](https://github.com/hpcaitech/ColossalAI) | ✅ Supported |
|
161
|
+
| [xFormers](https://github.com/facebookresearch/xformers) | ✅ Supported (UNet and text encoder) |
|
162
|
+
|
163
|
+
---
|
164
|
+
|
165
|
+
### 🗂 Dataset Support
|
166
|
+
|
167
|
+
| Feature | Description |
|
168
|
+
|----------------------------------|-------------|
|
169
|
+
| Aspect Ratio Bucket (ARB) | ✅ Auto-clustering supported |
|
170
|
+
| Multi-source / Multi-dataset | ✅ Supported |
|
171
|
+
| LMDB | ✅ Supported |
|
172
|
+
| webdataset | 🚧 In Development |
|
173
|
+
| Local Attention Enhancement | ✅ Supported |
|
174
|
+
| Tag Shuffling & Dropout | ✅ Multiple tag editing strategies |
|
175
|
+
|
176
|
+
---
|
177
|
+
|
178
|
+
### 📉 Supported Loss Functions
|
179
|
+
|
180
|
+
| Loss Type | Description |
|
181
|
+
|------------|-------------|
|
182
|
+
| Min-SNR | ✅ Supported |
|
183
|
+
| SSIM | ✅ Supported |
|
184
|
+
| GWLoss | ✅ Supported |
|
185
|
+
|
186
|
+
---
|
187
|
+
|
188
|
+
### 🌫 Supported Diffusion Strategies
|
189
|
+
|
190
|
+
| Strategy Type | Status |
|
191
|
+
|------------------|--------------|
|
192
|
+
| DDPM | ✅ Supported |
|
193
|
+
| EDM | ✅ Supported |
|
194
|
+
| Flow Matching | ✅ Supported |
|
195
|
+
|
196
|
+
---
|
197
|
+
|
198
|
+
### 🧠 Automatic Evaluation (Step Selection Assistant)
|
199
|
+
|
200
|
+
| Feature | Description/Status |
|
201
|
+
|------------------|------------------------------------------|
|
202
|
+
| Image Preview | ✅ Supported (workflow preview) |
|
203
|
+
| FID | 🚧 In Development |
|
204
|
+
| CLIP Score | 🚧 In Development |
|
205
|
+
| CCIP Score | 🚧 In Development |
|
206
|
+
| Corrupt Score | 🚧 In Development |
|
207
|
+
|
208
|
+
</details>
|
209
|
+
|
210
|
+
---
|
211
|
+
|
212
|
+
## Getting Started
|
213
|
+
|
214
|
+
### Training
|
215
|
+
|
216
|
+
HCP-Diffusion provides training scripts based on 🤗 Accelerate.
|
217
|
+
|
218
|
+
```bash
|
219
|
+
# Multi-GPU training, configure GPUs in cfgs/launcher/multi.yaml
|
220
|
+
hcp_train --cfg cfgs/train/py/your_config.py
|
221
|
+
|
222
|
+
# Single-GPU training, configure GPU in cfgs/launcher/single.yaml
|
223
|
+
hcp_train_1gpu --cfg cfgs/train/py/your_config.py
|
224
|
+
```
|
225
|
+
|
226
|
+
You can also override config items via the command line:
|
227
|
+
|
228
|
+
```bash
|
229
|
+
# Override base model path
|
230
|
+
hcp_train --cfg cfgs/train/py/your_config.py model.wrapper.models.ckpt_path=pretrained_model_path
|
231
|
+
```
|
232
|
+
|
233
|
+
### Image Generation
|
234
|
+
|
235
|
+
Use the workflow defined in the Python config to generate images:
|
236
|
+
|
237
|
+
```bash
|
238
|
+
hcp_run --cfg cfgs/workflow/text2img.py
|
239
|
+
```
|
240
|
+
|
241
|
+
Or override parameters via the command line:
|
242
|
+
|
243
|
+
```bash
|
244
|
+
hcp_run --cfg cfgs/workflow/text2img_cli.py \
|
245
|
+
pretrained_model=pretrained_model_path \
|
246
|
+
prompt='positive_prompt' \
|
247
|
+
negative_prompt='negative_prompt' \
|
248
|
+
seed=42
|
249
|
+
```
|
250
|
+
|
251
|
+
### Tutorials
|
252
|
+
|
253
|
+
🚧 In Development
|
254
|
+
|
255
|
+
---
|
256
|
+
|
257
|
+
## Contributing
|
258
|
+
|
259
|
+
We welcome contributions to support more models and features.
|
260
|
+
|
261
|
+
---
|
262
|
+
|
263
|
+
## Team
|
264
|
+
|
265
|
+
Maintained by [HCP-Lab at Sun Yat-sen University](https://www.sysu-hcp.net/).
|
266
|
+
|
267
|
+
---
|
268
|
+
|
269
|
+
## Citation
|
270
|
+
|
271
|
+
```bibtex
|
272
|
+
@article{DBLP:journals/corr/abs-2211-11337,
|
273
|
+
author = {Ziyi Dong and
|
274
|
+
Pengxu Wei and
|
275
|
+
Liang Lin},
|
276
|
+
title = {DreamArtist: Towards Controllable One-Shot Text-to-Image Generation
|
277
|
+
via Positive-Negative Prompt-Tuning},
|
278
|
+
journal = {CoRR},
|
279
|
+
volume = {abs/2211.11337},
|
280
|
+
year = {2022},
|
281
|
+
doi = {10.48550/arXiv.2211.11337},
|
282
|
+
eprinttype = {arXiv},
|
283
|
+
eprint = {2211.11337},
|
284
|
+
}
|
285
|
+
```
|
@@ -0,0 +1,114 @@
|
|
1
|
+
hcpdiff/__init__.py,sha256=dwNwrEgvG4g60fGMG6b50K3q3AWD1XCfzlIgbxkSUpE,177
|
2
|
+
hcpdiff/train_colo.py,sha256=EsuNSzLBvGTZWU_LEk0JpP-F5eNW0lwkawIRAX38jmE,9250
|
3
|
+
hcpdiff/train_deepspeed.py,sha256=PwyNukWi0of6TXy_VRDgBQSMLCZBhipO5g3Lq0nCYNk,2988
|
4
|
+
hcpdiff/trainer_ac.py,sha256=6KAzo54in7ZRHud_rHjJdwRRZ4uWtc0B4SxVCxgcrmM,2990
|
5
|
+
hcpdiff/trainer_ac_single.py,sha256=0PIC5EScqcxp49EaeIWq4KS5K_09OZfKajqbFu-hUb8,1108
|
6
|
+
hcpdiff/ckpt_manager/__init__.py,sha256=LfMwz9R4jV4xpiSFt5vhpwaF7-8UHEZ_iDoW-3QGvt0,239
|
7
|
+
hcpdiff/ckpt_manager/ckpt.py,sha256=Pa3uXQbCi2T99mpV5fYddQ-OGHcpk8r1ll-0lmP_WXk,965
|
8
|
+
hcpdiff/ckpt_manager/loader.py,sha256=Ch1xsZmseq4nyPhpox9-nebN-dZB4k0rqBEHos-ZLso,3245
|
9
|
+
hcpdiff/ckpt_manager/format/__init__.py,sha256=a3cdKkOTDgdVbDQwSC4mlxOigjX2hBvRb5_X7E3TQWs,237
|
10
|
+
hcpdiff/ckpt_manager/format/diffusers.py,sha256=T81WN95Nj1il9DfQp9iioVn0uqFEWOlmdIYs2beNOFU,3769
|
11
|
+
hcpdiff/ckpt_manager/format/emb.py,sha256=FrqfTfJ8H7f0Zw17NTWCP2AJtpsJI5oXR5IAd4NekhU,680
|
12
|
+
hcpdiff/ckpt_manager/format/lora_webui.py,sha256=j7SpXnSx_Ys8tnWBgojuB1HEJIm46lhCBuNNYLhaF9w,9824
|
13
|
+
hcpdiff/ckpt_manager/format/sd_single.py,sha256=LpCAL_7nAVooCHTFznVVsNMku1G3C77NBORxxr8GDtQ,2328
|
14
|
+
hcpdiff/data/__init__.py,sha256=-z47HsEQSubc-AfriVComMACbQXlXTWAKMOPBkATHxA,258
|
15
|
+
hcpdiff/data/dataset.py,sha256=1k4GldW13eVyqK_9hrQniqr3_XYAapnWF7iXl_1GXGg,877
|
16
|
+
hcpdiff/data/cache/__init__.py,sha256=ToCmokYH6DghlSwm7HJFirPRIWJ0LkgzqVOYlgoAkQw,25
|
17
|
+
hcpdiff/data/cache/vae.py,sha256=gB89zs4CdNlvukDXhVYU9QZrY6VTFUWfzjeF2psNQ50,4070
|
18
|
+
hcpdiff/data/handler/__init__.py,sha256=D1HyqY0qfrUHgf25itpYj57JUvgn06G6EQ9d2vRRtys,236
|
19
|
+
hcpdiff/data/handler/controlnet.py,sha256=bRDMD9BP8-VaG5VrxzvcFKfkqeTbChNfrJSZ3vXbQgY,658
|
20
|
+
hcpdiff/data/handler/diffusion.py,sha256=8n60UYdGNR08xw45HoI4EB5AaIui03tSGNDfjazO-5w,3516
|
21
|
+
hcpdiff/data/handler/text.py,sha256=gOzqB2oEkEUbiuy0kZWduo0c-w4Buu60KI6q6Nyl3aM,4208
|
22
|
+
hcpdiff/data/source/__init__.py,sha256=AB1VicA272KjTm-Q5L6XvDM8CLQhVPylAPuPMtpfw4g,158
|
23
|
+
hcpdiff/data/source/folder_class.py,sha256=bs4qPMTzwcnT6ZFlT3tpi9sclsRF9a2MBA1pQD-9EYs,961
|
24
|
+
hcpdiff/data/source/text2img.py,sha256=MWXqAEbzmK6pkBY40t9u37ngY25mgdKQ2idwNld8-bo,1826
|
25
|
+
hcpdiff/data/source/text2img_cond.py,sha256=yj1KpARA2rkjENutnnzC4uDkcU2Rye21FL2VdC25Hac,585
|
26
|
+
hcpdiff/diffusion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
27
|
+
hcpdiff/diffusion/noise/__init__.py,sha256=seBpOtd0YsU53PqMn7Nyl_RtwoC-ONEIOX7v2XLGpZQ,93
|
28
|
+
hcpdiff/diffusion/noise/pyramid_noise.py,sha256=KbpyMT1BHNIaAa7g5eECDkTttOMoMWVFmbP-ekBsuEY,1693
|
29
|
+
hcpdiff/diffusion/noise/zero_terminal.py,sha256=EfVOaqrTCfw11AolDBl0LIOey3uQT1bDw2XKr2Bm434,1532
|
30
|
+
hcpdiff/diffusion/sampler/__init__.py,sha256=pSHsKpLjscY5yLbdzHeBUeK9nFDuVeMIIeA_k6FQFdY,158
|
31
|
+
hcpdiff/diffusion/sampler/base.py,sha256=2AuPVT2ZSXYt2etZmHMyNKuGlT5zn6KIkoMz4m5PGcs,2577
|
32
|
+
hcpdiff/diffusion/sampler/ddpm.py,sha256=raqSuKsEPN1AEqRVCuBdMAOnKDoeJTRO17wtLBNJCf4,523
|
33
|
+
hcpdiff/diffusion/sampler/diffusers.py,sha256=XIu-oIlT4LnAYI2-yyIoNeIMeSQe5YpeEvn-FkGVFnE,2684
|
34
|
+
hcpdiff/diffusion/sampler/edm.py,sha256=5W4pv8hxQsPpJGiFBgElZxR3C9S8kWAhzGKejEVwq3I,753
|
35
|
+
hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py,sha256=kmIoWgsWqij6b7KYon3UOCSC07sRJCo-DPR6qkJwUd0,184
|
36
|
+
hcpdiff/diffusion/sampler/sigma_scheduler/base.py,sha256=9-JI-jwf7xZoQUtrU0qfbjkhNZT8a_tmapLtwVbFUx0,381
|
37
|
+
hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py,sha256=2PMIpg2K6CVoxew1y1pIqvCHbdggC-m3amrOYk15OdQ,8107
|
38
|
+
hcpdiff/diffusion/sampler/sigma_scheduler/edm.py,sha256=fOPB3lgnS9uVo4oW26Fur_nc8X_wQ6mmUcbkKhnoQjs,1900
|
39
|
+
hcpdiff/easy/__init__.py,sha256=-emoyCOZlLCu3KNMI8L4qapUEtEYFSoiGU6-rKv1at4,149
|
40
|
+
hcpdiff/easy/sampler.py,sha256=dQSBkeGh71O0DAmZLhTHTbk1bY7XzyUCeW1oJO14A4I,1250
|
41
|
+
hcpdiff/easy/cfg/__init__.py,sha256=aVDEDPxHdX5n-aFkP_4ic8ZhQfSeKu8lZOkgW_4m398,221
|
42
|
+
hcpdiff/easy/cfg/sd15_train.py,sha256=LRCJLHNU0JEd1m3MC_NFWUCw5LmwztiLiJlV7u_DeKM,6493
|
43
|
+
hcpdiff/easy/cfg/sdxl_train.py,sha256=R0wolSVOrRlI9A-vAfz592SzSnwuDd4ku1oc5yRKrfU,4038
|
44
|
+
hcpdiff/easy/cfg/t2i.py,sha256=6Pyy4werXNalwoBBHVMBLBg67kMS85Heb7R3t26GJqQ,6871
|
45
|
+
hcpdiff/easy/model/__init__.py,sha256=CA-7r3R2Jgweekk1XNByFYttLolbWyUV2bCnXygcD8w,133
|
46
|
+
hcpdiff/easy/model/cnet.py,sha256=m0NTH9V1kLzb5GybwBrSNT0KvTcRpPfGkzUeMz9jZZQ,1084
|
47
|
+
hcpdiff/easy/model/loader.py,sha256=Tdx-lhQEYf2NYjVM1A5B8x6ZZpJKcXUkFIPIbr7h7XM,3456
|
48
|
+
hcpdiff/evaluate/__init__.py,sha256=CtNzi8xdUWZBDBcP5TZTDMcRyykaOJhBIxTJgVuMabo,35
|
49
|
+
hcpdiff/evaluate/previewer.py,sha256=QiYYiEJBKP06uL3wKLhnpIGUZuAkr9BuHxTE29obpXI,2432
|
50
|
+
hcpdiff/loss/__init__.py,sha256=2dwPczSiv3rB5fzOeYbl5ZHpMU-qXOQlXeOiXdxcxwM,173
|
51
|
+
hcpdiff/loss/base.py,sha256=3bvgMbwyPOEA9iSkv0hRHw4VnKjkUCZAENNnDMFilYM,1780
|
52
|
+
hcpdiff/loss/gw.py,sha256=0yi1kozuII3xZA6FnjOhINtvScWt1MyBZLBtMKmgojM,1224
|
53
|
+
hcpdiff/loss/ssim.py,sha256=YofadvBkc6sklxBUx1p3ADw5OHOZPK3kaHz8FH5a6m4,1281
|
54
|
+
hcpdiff/loss/vlb.py,sha256=s78iBnXUiDWfGf7mYmhUnHqxqea5gSByKOoqBrX6bzU,3222
|
55
|
+
hcpdiff/loss/weighting.py,sha256=UR2PyZ1JTNOydXMw4e1Fh52XmtwKaviPvddcmVKCTlI,2242
|
56
|
+
hcpdiff/models/__init__.py,sha256=eQS7DPiGLiE1MFRkZj_17IY3IsfDUVcYpcOmhHb5B9o,472
|
57
|
+
hcpdiff/models/cfg_context.py,sha256=e2B3K1KwJhzbD6xdJUOyNtl_XgQ0296XI3FHw3gvZF4,1502
|
58
|
+
hcpdiff/models/container.py,sha256=z3p5TmQhxdzXSIfofz55_bmEhSsgUJsy1o9EcDs8Oeo,696
|
59
|
+
hcpdiff/models/controlnet.py,sha256=VIkUzJCVpCqqQOtRSLQPfbcDy9CsXutxLeZB6PdZfA0,7809
|
60
|
+
hcpdiff/models/lora_base.py,sha256=LGwBD9KP6qf4pgTx24i5-JLo4rDBQ6jFfterQKBjTbE,6758
|
61
|
+
hcpdiff/models/lora_base_patch.py,sha256=WW3CULnROTxKXyynJiqirhHYCKN5JtxLhVpT5b7AUQg,6532
|
62
|
+
hcpdiff/models/lora_layers.py,sha256=O9W_Ue71lHj7Y_GbpioF4Hc3h2-z_zOqck93VYUra6s,7777
|
63
|
+
hcpdiff/models/lora_layers_patch.py,sha256=GYFYsJD2VSLZfdnLma9CmQEHz09HROFJcc4wc_gs9f0,8198
|
64
|
+
hcpdiff/models/text_emb_ex.py,sha256=a5QImxzvj0zWR12qXOPP9kmpESl8J9VLabA0W9D_i_c,7867
|
65
|
+
hcpdiff/models/textencoder_ex.py,sha256=JrTQ30Avx8tPbdr-Q6K5BvEWCEdsu8Z7eSOzMqpUuzg,8270
|
66
|
+
hcpdiff/models/tokenizer_ex.py,sha256=zKUn4BY7b3yXwK9PWkZtQKJPyKYwUc07E-hwB9NQybs,2446
|
67
|
+
hcpdiff/models/compose/__init__.py,sha256=lTNFTGg5csqvUuys22RqgjmWlk_7Okw6ZTsnTi1pqCg,217
|
68
|
+
hcpdiff/models/compose/compose_hook.py,sha256=FfDSfn5FuLFGM80HMUwiUopy1P4xDbvKSBDuA6QK2So,6112
|
69
|
+
hcpdiff/models/compose/compose_textencoder.py,sha256=tiFoStKOIEH9YzsZQrLki4gra18kMy3wSzSUrVQG1sk,6607
|
70
|
+
hcpdiff/models/compose/compose_tokenizer.py,sha256=g3l0pOFv6p7Iigxm6Pqt_iTUXBlO1_SWAQOt0m54IoE,3033
|
71
|
+
hcpdiff/models/compose/sdxl_composer.py,sha256=NtMGaFGZTfKsPJSVi2yT-UM6K1WKWtk99XxVmTcKlk8,2164
|
72
|
+
hcpdiff/models/wrapper/__init__.py,sha256=HbGQmFnfccr-dtvZKjEv-pmR4cCnF4fwGLKS3tuG_OY,135
|
73
|
+
hcpdiff/models/wrapper/pixart.py,sha256=nRUvHSHn4TYg_smC0xpeW-GtUgXss-MuaVPTHpMozDE,1147
|
74
|
+
hcpdiff/models/wrapper/sd.py,sha256=D7VDI4OmbLTk9mzYta-C4LJjWfZmuBiDub4t8v1-M9o,11711
|
75
|
+
hcpdiff/models/wrapper/utils.py,sha256=NyebMoAPnrgcTHbiIocSD-eGdGdD-V1G_TQuWsRWufw,665
|
76
|
+
hcpdiff/parser/__init__.py,sha256=-2dDZ2Ii4zoGQqDTme94q4PpJbBiV6HS5BsDASz4Xbo,33
|
77
|
+
hcpdiff/parser/embpt.py,sha256=LgwZ0f0tLn3DrTo5ZpSCsZcA5330UpiW_sK96yEPmOM,1307
|
78
|
+
hcpdiff/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
79
|
+
hcpdiff/tools/convert_caption_txt2json.py,sha256=tbBgIphJWvXUoXjtwsnLX2w9IZEY3jTgxbTvUMgukbM,945
|
80
|
+
hcpdiff/tools/convert_old_lora.py,sha256=yIP9RGcyQbwT2NNAZtTLgBXs6XJOHRvoHQep0SdqDho,453
|
81
|
+
hcpdiff/tools/create_embedding.py,sha256=qbBx3mFPeVcVSeMWn5_-3Dq7EqWn0UTOQg6OurIwDeo,4577
|
82
|
+
hcpdiff/tools/dataset_generator.py,sha256=CG28IP1w0XAlEG1xPB0jQi40hVaFGxLe76tnV6hya-E,4639
|
83
|
+
hcpdiff/tools/diffusers2sd.py,sha256=MdZirlyztaaO4UAYU_AJa25eTFKfJa08nbHbdcqOdhs,13300
|
84
|
+
hcpdiff/tools/download_hf_model.py,sha256=3XZE-GYotE0W5ZQNk78ljZgolYhjbNKxfU-tB9I0yDY,947
|
85
|
+
hcpdiff/tools/embedding_convert.py,sha256=JN56dVFcMng8kq9YGS02_Umco71UjDYvwQGeqbhsJXU,1498
|
86
|
+
hcpdiff/tools/gen_from_ptlist.py,sha256=j4_wNd9JmsWUOcgkP4q4P1gB1SVnzy0VxY4pJqXM07Q,3105
|
87
|
+
hcpdiff/tools/init_proj.py,sha256=XrXxxhIaItywG7HsrloJo-x8w9suZiY35daelzZvjrg,194
|
88
|
+
hcpdiff/tools/lora_convert.py,sha256=So14WvSVIm6rU4m1XCajFXDnhq7abpZS95SLbaoyBFU,10058
|
89
|
+
hcpdiff/tools/save_model.py,sha256=gbfYi_EfEBZEUcDjle6MDHA19sQWY0zA8_y_LMzHQ7M,428
|
90
|
+
hcpdiff/tools/sd2diffusers.py,sha256=vB6OnBLw60sJkdpVZcYEPtKAZW1h8ErbSGSRq0uAiIk,16855
|
91
|
+
hcpdiff/utils/__init__.py,sha256=VOLhdNq2mRyqmWxrssIWSZtR_PQ8rFwo2u0uq6GbLHA,45
|
92
|
+
hcpdiff/utils/colo_utils.py,sha256=JyLUvVnISa48CnryNLrgVxMo-jxu2UhBq70eYPrkjuI,837
|
93
|
+
hcpdiff/utils/inpaint_pipe.py,sha256=CRy1MUlPmHifCAbZnKOP0qbLp2grn7ZbVeaB2qIA4ig,42862
|
94
|
+
hcpdiff/utils/net_utils.py,sha256=gdwLYDNKV2t3SP0jBIO3d0HtY6E7jRaf_rmPT8gKZZE,9762
|
95
|
+
hcpdiff/utils/pipe_hook.py,sha256=-UDX3FtZGl-bxSk13gdbPXc1OvtbCcpk_fvKxLQo3Ag,31987
|
96
|
+
hcpdiff/utils/utils.py,sha256=hZnZP1IETgVpScxES0yIuRfc34TnzvAqmgOTK_56ssw,4976
|
97
|
+
hcpdiff/workflow/__init__.py,sha256=t7Zyc0XFORdNvcwHp9AsCtEkhJ3l7Hm41ugngIL0Sag,867
|
98
|
+
hcpdiff/workflow/diffusion.py,sha256=yrl2cXE2d2FNeVzYZDRQNLjy5-QnVgOWioIHSmszk2Y,8662
|
99
|
+
hcpdiff/workflow/fast.py,sha256=kZt7bKrvpFInSn7GzbkTkpoCSM0Z6IbDjgaDvcbFYf8,1024
|
100
|
+
hcpdiff/workflow/flow.py,sha256=FFbFFOAXT4c31L5bHBEB_qeVGuBQDLYhq8kTD1chGNo,2548
|
101
|
+
hcpdiff/workflow/io.py,sha256=aTrMR3s44apVJpnSyvZIabW2Op0tslk_Z9JFJl5svm0,2635
|
102
|
+
hcpdiff/workflow/model.py,sha256=1gj5yOTefYTnGXVR6JPAfxIwuB69YwN6E-BontRcuyQ,2913
|
103
|
+
hcpdiff/workflow/text.py,sha256=FSFUm_zEeZjMeg0qRXZAPplnJkg2pR_2FA3XljpoN2w,5110
|
104
|
+
hcpdiff/workflow/utils.py,sha256=xojaMG4lHsymslc8df5uiVXmmBVWpn_Phqka8qzJEWw,2226
|
105
|
+
hcpdiff/workflow/vae.py,sha256=cingDPkIOc4qGpOwwhXJK4EQbGoIxO583pm6gGov5t8,3118
|
106
|
+
hcpdiff/workflow/daam/__init__.py,sha256=ySIDaxloN-D3qM7OuVaG1BR3D-CibDoXYpoTgw0zUhU,59
|
107
|
+
hcpdiff/workflow/daam/act.py,sha256=tHbsFWTYYU4bvcZOo1Bpi_z6ofpJatRYccl4vvf8wIA,2756
|
108
|
+
hcpdiff/workflow/daam/hook.py,sha256=z9f9mBjKW21xuUZ-iQxQ0HbWOBXtZrisFB0VNMq6d0U,4383
|
109
|
+
hcpdiff-2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
110
|
+
hcpdiff-2.1.dist-info/METADATA,sha256=NpBZuj23d1gTKPQhJ0TBRV8QsfICa4LCGSk6PJNniSw,9248
|
111
|
+
hcpdiff-2.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
112
|
+
hcpdiff-2.1.dist-info/entry_points.txt,sha256=86wPOMzsfWWflTJ-sQPLc7WG5Vtu0kGYBH9C_vR3ur8,207
|
113
|
+
hcpdiff-2.1.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
|
114
|
+
hcpdiff-2.1.dist-info/RECORD,,
|
hcpdiff/ckpt_manager/base.py
DELETED
@@ -1,16 +0,0 @@
|
|
1
|
-
from diffusers import StableDiffusionPipeline
|
2
|
-
from diffusers.models.lora import LoRACompatibleLinear
|
3
|
-
|
4
|
-
class CkptManagerBase:
|
5
|
-
def __init__(self, **kwargs):
|
6
|
-
pass
|
7
|
-
|
8
|
-
def set_save_dir(self, save_dir, emb_dir=None):
|
9
|
-
raise NotImplementedError()
|
10
|
-
|
11
|
-
def save(self, step, unet, TE, lora_unet, lora_TE, all_plugin_unet, all_plugin_TE, embs, pipe: StableDiffusionPipeline, **kwargs):
|
12
|
-
raise NotImplementedError()
|
13
|
-
|
14
|
-
@classmethod
|
15
|
-
def load(cls, pretrained_model, **kwargs) -> StableDiffusionPipeline:
|
16
|
-
raise NotImplementedError
|
@@ -1,45 +0,0 @@
|
|
1
|
-
from .base import CkptManagerBase
|
2
|
-
import os
|
3
|
-
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
|
4
|
-
from hcpdiff.models.plugin import BasePluginBlock
|
5
|
-
|
6
|
-
|
7
|
-
class CkptManagerDiffusers(CkptManagerBase):
|
8
|
-
|
9
|
-
def set_save_dir(self, save_dir, emb_dir=None):
|
10
|
-
os.makedirs(save_dir, exist_ok=True)
|
11
|
-
self.save_dir = save_dir
|
12
|
-
self.emb_dir = emb_dir
|
13
|
-
|
14
|
-
def save(self, step, unet, TE, lora_unet, lora_TE, all_plugin_unet, all_plugin_TE, embs, pipe: StableDiffusionPipeline, **kwargs):
|
15
|
-
def state_dict_unet(*args, model=unet, **kwargs):
|
16
|
-
plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
|
17
|
-
model_sd = {}
|
18
|
-
for k, v in model.state_dict_().items():
|
19
|
-
for name in plugin_names:
|
20
|
-
if k.startswith(name):
|
21
|
-
break
|
22
|
-
else:
|
23
|
-
model_sd[k] = v
|
24
|
-
return model_sd
|
25
|
-
unet.state_dict_ = unet.state_dict
|
26
|
-
unet.state_dict = state_dict_unet
|
27
|
-
|
28
|
-
def state_dict_TE(*args, model=TE, **kwargs):
|
29
|
-
plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
|
30
|
-
model_sd = {}
|
31
|
-
for k, v in model.state_dict_().items():
|
32
|
-
for name in plugin_names:
|
33
|
-
if k.startswith(name):
|
34
|
-
break
|
35
|
-
else:
|
36
|
-
model_sd[k] = v
|
37
|
-
return model_sd
|
38
|
-
TE.state_dict_ = TE.state_dict
|
39
|
-
TE.state_dict = state_dict_TE
|
40
|
-
|
41
|
-
pipe.save_pretrained(os.path.join(self.save_dir, f"model-{step}"), **kwargs)
|
42
|
-
|
43
|
-
@classmethod
|
44
|
-
def load(cls, pretrained_model, **kwargs) -> StableDiffusionPipeline:
|
45
|
-
return StableDiffusionPipeline.from_pretrained(pretrained_model, **kwargs)
|
hcpdiff/ckpt_manager/ckpt_pkl.py
DELETED
@@ -1,138 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
ckpt_pkl.py
|
3
|
-
====================
|
4
|
-
:Name: save model with torch
|
5
|
-
:Author: Dong Ziyi
|
6
|
-
:Affiliation: HCP Lab, SYSU
|
7
|
-
:Created: 8/04/2023
|
8
|
-
:Licence: MIT
|
9
|
-
"""
|
10
|
-
|
11
|
-
from typing import Dict
|
12
|
-
import os
|
13
|
-
|
14
|
-
import torch
|
15
|
-
from torch import nn
|
16
|
-
|
17
|
-
from hcpdiff.models.lora_base import LoraBlock, LoraGroup, split_state
|
18
|
-
from hcpdiff.models.plugin import PluginGroup, BasePluginBlock
|
19
|
-
from hcpdiff.utils.net_utils import save_emb
|
20
|
-
from .base import CkptManagerBase
|
21
|
-
|
22
|
-
class CkptManagerPKL(CkptManagerBase):
|
23
|
-
def __init__(self, plugin_from_raw=False, **kwargs):
|
24
|
-
self.plugin_from_raw = plugin_from_raw
|
25
|
-
|
26
|
-
def set_save_dir(self, save_dir, emb_dir=None):
|
27
|
-
os.makedirs(save_dir, exist_ok=True)
|
28
|
-
self.save_dir = save_dir
|
29
|
-
self.emb_dir = emb_dir
|
30
|
-
|
31
|
-
def exclude_state(self, state, key):
|
32
|
-
if key is None:
|
33
|
-
return state
|
34
|
-
else:
|
35
|
-
return {k: v for k, v in state.items() if key not in k}
|
36
|
-
|
37
|
-
def save_model(self, model: nn.Module, name, step, model_ema=None, exclude_key=None):
|
38
|
-
sd_model = {
|
39
|
-
'base': self.exclude_state(LoraBlock.extract_trainable_state_without_lora(model), exclude_key),
|
40
|
-
}
|
41
|
-
if model_ema is not None:
|
42
|
-
sd_ema, sd_ema_lora = split_state(model_ema.state_dict())
|
43
|
-
sd_model['base_ema'] = self.exclude_state(sd_ema, exclude_key)
|
44
|
-
self._save_ckpt(sd_model, name, step)
|
45
|
-
|
46
|
-
def save_plugins(self, host_model: nn.Module, plugins: Dict[str, PluginGroup], name:str, step:int, model_ema=None):
|
47
|
-
if len(plugins)>0:
|
48
|
-
sd_plugin={}
|
49
|
-
for plugin_name, plugin in plugins.items():
|
50
|
-
sd_plugin['plugin'] = plugin.state_dict(host_model if self.plugin_from_raw else None)
|
51
|
-
if model_ema is not None:
|
52
|
-
sd_plugin['plugin_ema'] = plugin.state_dict(model_ema)
|
53
|
-
self._save_ckpt(sd_plugin, f'{name}-{plugin_name}', step)
|
54
|
-
|
55
|
-
def save_model_with_lora(self, model: nn.Module, lora_blocks: LoraGroup, name:str, step:int, model_ema=None,
|
56
|
-
exclude_key=None):
|
57
|
-
sd_model = {
|
58
|
-
'base': self.exclude_state(BasePluginBlock.extract_state_without_plugin(model, trainable=True), exclude_key),
|
59
|
-
} if model is not None else {}
|
60
|
-
if (lora_blocks is not None) and (not lora_blocks.empty()):
|
61
|
-
sd_model['lora'] = lora_blocks.state_dict(model if self.plugin_from_raw else None)
|
62
|
-
|
63
|
-
if model_ema is not None:
|
64
|
-
ema_state = model_ema.state_dict()
|
65
|
-
if model is not None:
|
66
|
-
sd_ema = {k:ema_state[k] for k in sd_model['base'].keys()}
|
67
|
-
sd_model['base_ema'] = self.exclude_state(sd_ema, exclude_key)
|
68
|
-
if (lora_blocks is not None) and (not lora_blocks.empty()):
|
69
|
-
sd_model['lora_ema'] = lora_blocks.state_dict(model_ema)
|
70
|
-
|
71
|
-
self._save_ckpt(sd_model, name, step)
|
72
|
-
|
73
|
-
def _save_ckpt(self, sd_model, name=None, step=None, save_path=None):
|
74
|
-
if save_path is None:
|
75
|
-
save_path = os.path.join(self.save_dir, f"{name}-{step}.ckpt")
|
76
|
-
torch.save(sd_model, save_path)
|
77
|
-
|
78
|
-
def load_ckpt(self, ckpt_path, map_location='cpu'):
|
79
|
-
return torch.load(ckpt_path, map_location=map_location)
|
80
|
-
|
81
|
-
def load_ckpt_to_model(self, model: nn.Module, ckpt_path, model_ema=None):
|
82
|
-
sd = self.load_ckpt(ckpt_path)
|
83
|
-
if 'base' in sd:
|
84
|
-
model.load_state_dict(sd['base'], strict=False)
|
85
|
-
if 'lora' in sd:
|
86
|
-
model.load_state_dict(sd['lora'], strict=False)
|
87
|
-
if 'plugin' in sd:
|
88
|
-
model.load_state_dict(sd['plugin'], strict=False)
|
89
|
-
|
90
|
-
if model_ema is not None:
|
91
|
-
if 'base' in sd:
|
92
|
-
model_ema.load_state_dict(sd['base_ema'])
|
93
|
-
if 'lora' in sd:
|
94
|
-
model_ema.load_state_dict(sd['lora_ema'])
|
95
|
-
if 'plugin' in sd:
|
96
|
-
model_ema.load_state_dict(sd['plugin_ema'])
|
97
|
-
|
98
|
-
def save_embedding(self, train_pts, step, replace):
|
99
|
-
for k, v in train_pts.items():
|
100
|
-
save_path = os.path.join(self.save_dir, f"{k}-{step}.pt")
|
101
|
-
save_emb(save_path, v.data, replace=True)
|
102
|
-
if replace:
|
103
|
-
save_emb(f'{k}.pt', v.data, replace=True)
|
104
|
-
|
105
|
-
def save(self, step, unet, TE, lora_unet, lora_TE, all_plugin_unet, all_plugin_TE, embs, pipe):
|
106
|
-
'''
|
107
|
-
|
108
|
-
:param step:
|
109
|
-
:param unet:
|
110
|
-
:param TE:
|
111
|
-
:param lora_unet: [pos, neg]
|
112
|
-
:param lora_TE: [pos, neg]
|
113
|
-
:param all_plugin_unet:
|
114
|
-
:param all_plugin_TE:
|
115
|
-
:param emb:
|
116
|
-
:param pipe:
|
117
|
-
:return:
|
118
|
-
'''
|
119
|
-
self.save_model_with_lora(unet, lora_unet[0], model_ema=getattr(self, 'ema_unet', None), name='unet', step=step)
|
120
|
-
self.save_plugins(unet, all_plugin_unet, name='unet', step=step, model_ema=getattr(self, 'ema_unet', None))
|
121
|
-
|
122
|
-
if TE is not None:
|
123
|
-
# exclude_key: embeddings should not save with text-encoder
|
124
|
-
self.save_model_with_lora(TE, lora_TE[0], model_ema=getattr(self, 'ema_text_encoder', None),
|
125
|
-
name='text_encoder', step=step, exclude_key='emb_ex.')
|
126
|
-
self.save_plugins(TE, all_plugin_TE, name='text_encoder', step=step,
|
127
|
-
model_ema=getattr(self, 'ema_text_encoder', None))
|
128
|
-
|
129
|
-
if lora_unet[1] is not None:
|
130
|
-
self.save_model_with_lora(None, lora_unet[1], name='unet-neg', step=step)
|
131
|
-
if lora_TE[1] is not None:
|
132
|
-
self.save_model_with_lora(None, lora_TE[1], name='text_encoder-neg', step=step)
|
133
|
-
|
134
|
-
self.save_embedding(embs, step, False)
|
135
|
-
|
136
|
-
@classmethod
|
137
|
-
def load(cls, pretrained_model):
|
138
|
-
raise NotImplementedError(f'{cls} dose not support load()')
|
@@ -1,64 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
ckpt_safetensors.py
|
3
|
-
====================
|
4
|
-
:Name: save model with safetensors
|
5
|
-
:Author: Dong Ziyi
|
6
|
-
:Affiliation: HCP Lab, SYSU
|
7
|
-
:Created: 8/04/2023
|
8
|
-
:Licence: MIT
|
9
|
-
"""
|
10
|
-
|
11
|
-
import os
|
12
|
-
import torch
|
13
|
-
from safetensors import safe_open
|
14
|
-
from safetensors.torch import save_file
|
15
|
-
|
16
|
-
from .ckpt_pkl import CkptManagerPKL
|
17
|
-
|
18
|
-
class CkptManagerSafe(CkptManagerPKL):
|
19
|
-
|
20
|
-
def _save_ckpt(self, sd_model, name=None, step=None, save_path=None):
|
21
|
-
if save_path is None:
|
22
|
-
save_path = os.path.join(self.save_dir, f"{name}-{step}.safetensors")
|
23
|
-
sd_unfold = self.unfold_dict(sd_model)
|
24
|
-
for k, v in sd_unfold.items():
|
25
|
-
if not v.is_contiguous():
|
26
|
-
sd_unfold[k] = v.contiguous()
|
27
|
-
save_file(sd_unfold, save_path)
|
28
|
-
|
29
|
-
def load_ckpt(self, ckpt_path, map_location='cpu'):
|
30
|
-
with safe_open(ckpt_path, framework="pt", device=map_location) as f:
|
31
|
-
sd_fold = self.fold_dict(f)
|
32
|
-
return sd_fold
|
33
|
-
|
34
|
-
@staticmethod
|
35
|
-
def unfold_dict(data, split_key=':'):
|
36
|
-
dict_unfold={}
|
37
|
-
|
38
|
-
def unfold(prefix, dict_fold):
|
39
|
-
for k,v in dict_fold.items():
|
40
|
-
k_new = k if prefix=='' else f'{prefix}{split_key}{k}'
|
41
|
-
if isinstance(v, dict):
|
42
|
-
unfold(k_new, v)
|
43
|
-
elif isinstance(v, list) or isinstance(v, tuple):
|
44
|
-
unfold(k_new, {i:d for i,d in enumerate(v)})
|
45
|
-
else:
|
46
|
-
dict_unfold[k_new]=v
|
47
|
-
|
48
|
-
unfold('', data)
|
49
|
-
return dict_unfold
|
50
|
-
|
51
|
-
@staticmethod
|
52
|
-
def fold_dict(safe_f, split_key=':'):
|
53
|
-
dict_fold = {}
|
54
|
-
|
55
|
-
for k in safe_f.keys():
|
56
|
-
k_list = k.split(split_key)
|
57
|
-
dict_last = dict_fold
|
58
|
-
for item in k_list[:-1]:
|
59
|
-
if item not in dict_last:
|
60
|
-
dict_last[item] = {}
|
61
|
-
dict_last = dict_last[item]
|
62
|
-
dict_last[k_list[-1]]=safe_f.get_tensor(k)
|
63
|
-
|
64
|
-
return dict_fold
|