hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (210) hide show
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/controlnet.py +16 -16
  58. hcpdiff/models/lora_base_patch.py +14 -25
  59. hcpdiff/models/lora_layers.py +3 -9
  60. hcpdiff/models/lora_layers_patch.py +14 -24
  61. hcpdiff/models/text_emb_ex.py +84 -6
  62. hcpdiff/models/textencoder_ex.py +54 -18
  63. hcpdiff/models/wrapper/__init__.py +3 -0
  64. hcpdiff/models/wrapper/pixart.py +19 -0
  65. hcpdiff/models/wrapper/sd.py +218 -0
  66. hcpdiff/models/wrapper/utils.py +20 -0
  67. hcpdiff/parser/__init__.py +1 -0
  68. hcpdiff/parser/embpt.py +32 -0
  69. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  70. hcpdiff/tools/dataset_generator.py +94 -0
  71. hcpdiff/tools/download_hf_model.py +24 -0
  72. hcpdiff/tools/init_proj.py +3 -21
  73. hcpdiff/tools/lora_convert.py +18 -17
  74. hcpdiff/tools/save_model.py +12 -0
  75. hcpdiff/tools/sd2diffusers.py +1 -1
  76. hcpdiff/train_colo.py +1 -1
  77. hcpdiff/train_deepspeed.py +1 -1
  78. hcpdiff/trainer_ac.py +79 -0
  79. hcpdiff/trainer_ac_single.py +31 -0
  80. hcpdiff/utils/__init__.py +0 -2
  81. hcpdiff/utils/inpaint_pipe.py +7 -2
  82. hcpdiff/utils/net_utils.py +29 -6
  83. hcpdiff/utils/pipe_hook.py +24 -7
  84. hcpdiff/utils/utils.py +21 -4
  85. hcpdiff/workflow/__init__.py +15 -10
  86. hcpdiff/workflow/daam/__init__.py +1 -0
  87. hcpdiff/workflow/daam/act.py +66 -0
  88. hcpdiff/workflow/daam/hook.py +109 -0
  89. hcpdiff/workflow/diffusion.py +114 -125
  90. hcpdiff/workflow/fast.py +31 -0
  91. hcpdiff/workflow/flow.py +67 -0
  92. hcpdiff/workflow/io.py +36 -130
  93. hcpdiff/workflow/model.py +46 -43
  94. hcpdiff/workflow/text.py +78 -46
  95. hcpdiff/workflow/utils.py +32 -12
  96. hcpdiff/workflow/vae.py +37 -38
  97. hcpdiff-2.1.dist-info/METADATA +285 -0
  98. hcpdiff-2.1.dist-info/RECORD +114 -0
  99. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  100. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  101. hcpdiff/ckpt_manager/base.py +0 -16
  102. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  103. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  104. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  105. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  106. hcpdiff/data/bucket.py +0 -358
  107. hcpdiff/data/caption_loader.py +0 -80
  108. hcpdiff/data/cond_dataset.py +0 -40
  109. hcpdiff/data/crop_info_dataset.py +0 -40
  110. hcpdiff/data/data_processor.py +0 -33
  111. hcpdiff/data/pair_dataset.py +0 -146
  112. hcpdiff/data/sampler.py +0 -54
  113. hcpdiff/data/source/base.py +0 -30
  114. hcpdiff/data/utils.py +0 -80
  115. hcpdiff/deprecated/__init__.py +0 -1
  116. hcpdiff/deprecated/cfg_converter.py +0 -81
  117. hcpdiff/deprecated/lora_convert.py +0 -31
  118. hcpdiff/infer_workflow.py +0 -57
  119. hcpdiff/loggers/__init__.py +0 -13
  120. hcpdiff/loggers/base_logger.py +0 -76
  121. hcpdiff/loggers/cli_logger.py +0 -40
  122. hcpdiff/loggers/preview/__init__.py +0 -1
  123. hcpdiff/loggers/preview/image_previewer.py +0 -149
  124. hcpdiff/loggers/tensorboard_logger.py +0 -30
  125. hcpdiff/loggers/wandb_logger.py +0 -31
  126. hcpdiff/loggers/webui_logger.py +0 -9
  127. hcpdiff/loss/min_snr_loss.py +0 -52
  128. hcpdiff/models/layers.py +0 -81
  129. hcpdiff/models/plugin.py +0 -348
  130. hcpdiff/models/wrapper.py +0 -75
  131. hcpdiff/noise/__init__.py +0 -3
  132. hcpdiff/noise/noise_base.py +0 -16
  133. hcpdiff/noise/pyramid_noise.py +0 -50
  134. hcpdiff/noise/zero_terminal.py +0 -44
  135. hcpdiff/train_ac.py +0 -566
  136. hcpdiff/train_ac_single.py +0 -39
  137. hcpdiff/utils/caption_tools.py +0 -105
  138. hcpdiff/utils/cfg_net_tools.py +0 -321
  139. hcpdiff/utils/cfg_resolvers.py +0 -16
  140. hcpdiff/utils/ema.py +0 -52
  141. hcpdiff/utils/img_size_tool.py +0 -248
  142. hcpdiff/vis/__init__.py +0 -3
  143. hcpdiff/vis/base_interface.py +0 -12
  144. hcpdiff/vis/disk_interface.py +0 -48
  145. hcpdiff/vis/webui_interface.py +0 -17
  146. hcpdiff/viser_fast.py +0 -138
  147. hcpdiff/visualizer.py +0 -265
  148. hcpdiff/visualizer_reloadable.py +0 -237
  149. hcpdiff/workflow/base.py +0 -59
  150. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  198. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  206. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  207. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  208. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  209. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,285 @@
1
+ Metadata-Version: 2.4
2
+ Name: hcpdiff
3
+ Version: 2.1
4
+ Summary: A universal Diffusion toolbox
5
+ Home-page: https://github.com/IrisRainbowNeko/HCP-Diffusion
6
+ Author: Ziyi Dong
7
+ Author-email: rainbow-neko@outlook.com
8
+ Classifier: License :: OSI Approved :: Apache Software License
9
+ Classifier: Operating System :: OS Independent
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.8
12
+ Classifier: Programming Language :: Python :: 3.9
13
+ Classifier: Programming Language :: Python :: 3.10
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Requires-Python: >=3.8
18
+ Description-Content-Type: text/markdown
19
+ License-File: LICENSE
20
+ Requires-Dist: rainbowneko
21
+ Requires-Dist: diffusers
22
+ Requires-Dist: matplotlib
23
+ Requires-Dist: pyarrow
24
+ Requires-Dist: transformers>=4.25.1
25
+ Requires-Dist: pytorch-msssim
26
+ Requires-Dist: lmdb
27
+ Dynamic: author
28
+ Dynamic: author-email
29
+ Dynamic: classifier
30
+ Dynamic: description
31
+ Dynamic: description-content-type
32
+ Dynamic: home-page
33
+ Dynamic: license-file
34
+ Dynamic: requires-dist
35
+ Dynamic: requires-python
36
+ Dynamic: summary
37
+
38
+ # HCP-Diffusion V2
39
+
40
+ [![PyPI](https://img.shields.io/pypi/v/hcpdiff)](https://pypi.org/project/hcpdiff/)
41
+ [![GitHub stars](https://img.shields.io/github/stars/7eu7d7/HCP-Diffusion)](https://github.com/7eu7d7/HCP-Diffusion/stargazers)
42
+ [![GitHub license](https://img.shields.io/github/license/7eu7d7/HCP-Diffusion)](https://github.com/7eu7d7/HCP-Diffusion/blob/master/LICENSE)
43
+ [![codecov](https://codecov.io/gh/7eu7d7/HCP-Diffusion/branch/main/graph/badge.svg)](https://codecov.io/gh/7eu7d7/HCP-Diffusion)
44
+ [![open issues](https://isitmaintained.com/badge/open/7eu7d7/HCP-Diffusion.svg)](https://github.com/7eu7d7/HCP-Diffusion/issues)
45
+
46
+ [📘中文说明](./README_cn.md)
47
+
48
+ [📘English document](https://hcpdiff.readthedocs.io/en/latest/)
49
+ [📘中文文档](https://hcpdiff.readthedocs.io/zh_CN/latest/)
50
+
51
+ Old HCP-Diffusion V1 at [main branch](https://github.com/IrisRainbowNeko/HCP-Diffusion/tree/main)
52
+
53
+ ## Introduction
54
+
55
+ **HCP-Diffusion** is a Diffusion model toolbox built on top of the [🐱 RainbowNeko Engine](https://github.com/IrisRainbowNeko/RainbowNekoEngine).
56
+ It features a clean code structure and a flexible **Python-based configuration file**, making it easier to conduct and manage complex experiments. It includes a wide variety of training components, and compared to existing frameworks, it's more extensible, flexible, and user-friendly.
57
+
58
+ HCP-Diffusion allows you to use a single `.py` config file to unify training workflows across popular methods and model architectures, including Prompt-tuning (Textual Inversion), DreamArtist, Fine-tuning, DreamBooth, LoRA, ControlNet, and more.
59
+ Different techniques can also be freely combined.
60
+
61
+ This framework also implements **DreamArtist++**, an upgraded version of DreamArtist based on LoRA. It enables high generalization and controllability with just a single image for training.
62
+ Compared to the original DreamArtist, it offers better stability, image quality, controllability, and faster training.
63
+
64
+ ---
65
+
66
+ ## Installation
67
+
68
+ Install via pip:
69
+
70
+ ```bash
71
+ pip install hcpdiff
72
+ # Initialize configuration
73
+ hcpinit
74
+ ```
75
+
76
+ Install from source:
77
+
78
+ ```bash
79
+ git clone https://github.com/7eu7d7/HCP-Diffusion.git
80
+ cd HCP-Diffusion
81
+ pip install -e .
82
+ # Initialize configuration
83
+ hcpinit
84
+ ```
85
+
86
+ Use xFormers to reduce memory usage and accelerate training:
87
+
88
+ ```bash
89
+ # Choose the appropriate xformers version for your PyTorch version
90
+ pip install xformers==&lt;version matching your PyTorch&gt;
91
+ ```
92
+
93
+ ## 🚀 Python Configuration Files
94
+ RainbowNeko Engine supports configuration files written in a Python-like syntax. This allows users to call functions and classes directly within the configuration file, with function parameters inheritable from parent configuration files. The framework automatically handles the formatting of these configuration files.
95
+
96
+ For example, consider the following configuration file:
97
+ ```python
98
+ dict(
99
+ layer=Linear(in_features=4, out_features=4)
100
+ )
101
+ ```
102
+ During parsing, this will be automatically compiled into:
103
+ ```python
104
+ dict(
105
+ layer=dict(_target_=Linear, in_features=4, out_features=4)
106
+ )
107
+ ```
108
+ After parsing, the framework will instantiate the components accordingly. This means users can write configuration files using familiar Python syntax.
109
+
110
+ ---
111
+
112
+ ## ✨ Features
113
+
114
+ <details>
115
+ <summary>Features</summary>
116
+
117
+ ### 📦 Model Support
118
+
119
+ | Model Name | Status |
120
+ |--------------------------|-------------|
121
+ | Stable Diffusion 1.5 | ✅ Supported |
122
+ | Stable Diffusion XL (SDXL)| ✅ Supported |
123
+ | PixArt | ✅ Supported |
124
+ | FLUX | 🚧 In Development |
125
+ | Stable Diffusion 3 (SD3) | 🚧 In Development |
126
+
127
+ ---
128
+
129
+ ### 🧠 Fine-Tuning Capabilities
130
+
131
+ | Feature | Description/Support |
132
+ |----------------------------------|---------------------|
133
+ | LoRA Layer-wise Configuration | ✅ Supported (including Conv2d) |
134
+ | Layer-wise Fine-Tuning | ✅ Supported |
135
+ | Multi-token Prompt-Tuning | ✅ Supported |
136
+ | Layer-wise Model Merging | ✅ Supported |
137
+ | Custom Optimizers | ✅ Supported (Lion, DAdaptation, pytorch-optimizer, etc.) |
138
+ | Custom LR Schedulers | ✅ Supported |
139
+
140
+ ---
141
+
142
+ ### 🧩 Extension Method Support
143
+
144
+ | Method | Status |
145
+ |--------------------------------|-------------|
146
+ | ControlNet (including training)| ✅ Supported |
147
+ | DreamArtist / DreamArtist++ | ✅ Supported |
148
+ | Token Attention Adjustment | ✅ Supported |
149
+ | Max Sentence Length Extension | ✅ Supported |
150
+ | Textual Inversion (Custom Tokens)| ✅ Supported |
151
+ | CLIP Skip | ✅ Supported |
152
+
153
+ ---
154
+
155
+ ### 🚀 Training Acceleration
156
+
157
+ | Tool/Library | Supported Modules |
158
+ |---------------------------------------------------|---------------------------|
159
+ | [🤗 Accelerate](https://github.com/huggingface/accelerate) | ✅ Supported |
160
+ | [Colossal-AI](https://github.com/hpcaitech/ColossalAI) | ✅ Supported |
161
+ | [xFormers](https://github.com/facebookresearch/xformers) | ✅ Supported (UNet and text encoder) |
162
+
163
+ ---
164
+
165
+ ### 🗂 Dataset Support
166
+
167
+ | Feature | Description |
168
+ |----------------------------------|-------------|
169
+ | Aspect Ratio Bucket (ARB) | ✅ Auto-clustering supported |
170
+ | Multi-source / Multi-dataset | ✅ Supported |
171
+ | LMDB | ✅ Supported |
172
+ | webdataset | 🚧 In Development |
173
+ | Local Attention Enhancement | ✅ Supported |
174
+ | Tag Shuffling & Dropout | ✅ Multiple tag editing strategies |
175
+
176
+ ---
177
+
178
+ ### 📉 Supported Loss Functions
179
+
180
+ | Loss Type | Description |
181
+ |------------|-------------|
182
+ | Min-SNR | ✅ Supported |
183
+ | SSIM | ✅ Supported |
184
+ | GWLoss | ✅ Supported |
185
+
186
+ ---
187
+
188
+ ### 🌫 Supported Diffusion Strategies
189
+
190
+ | Strategy Type | Status |
191
+ |------------------|--------------|
192
+ | DDPM | ✅ Supported |
193
+ | EDM | ✅ Supported |
194
+ | Flow Matching | ✅ Supported |
195
+
196
+ ---
197
+
198
+ ### 🧠 Automatic Evaluation (Step Selection Assistant)
199
+
200
+ | Feature | Description/Status |
201
+ |------------------|------------------------------------------|
202
+ | Image Preview | ✅ Supported (workflow preview) |
203
+ | FID | 🚧 In Development |
204
+ | CLIP Score | 🚧 In Development |
205
+ | CCIP Score | 🚧 In Development |
206
+ | Corrupt Score | 🚧 In Development |
207
+
208
+ </details>
209
+
210
+ ---
211
+
212
+ ## Getting Started
213
+
214
+ ### Training
215
+
216
+ HCP-Diffusion provides training scripts based on 🤗 Accelerate.
217
+
218
+ ```bash
219
+ # Multi-GPU training, configure GPUs in cfgs/launcher/multi.yaml
220
+ hcp_train --cfg cfgs/train/py/your_config.py
221
+
222
+ # Single-GPU training, configure GPU in cfgs/launcher/single.yaml
223
+ hcp_train_1gpu --cfg cfgs/train/py/your_config.py
224
+ ```
225
+
226
+ You can also override config items via command line:
227
+
228
+ ```bash
229
+ # Override base model path
230
+ hcp_train --cfg cfgs/train/py/your_config.py model.wrapper.models.ckpt_path=pretrained_model_path
231
+ ```
232
+
233
+ ### Image Generation
234
+
235
+ Use the workflow defined in the Python config to generate images:
236
+
237
+ ```bash
238
+ hcp_run --cfg cfgs/workflow/text2img.py
239
+ ```
240
+
241
+ Or override parameters via command line:
242
+
243
+ ```bash
244
+ hcp_run --cfg cfgs/workflow/text2img_cli.py \
245
+ pretrained_model=pretrained_model_path \
246
+ prompt='positive_prompt' \
247
+ negative_prompt='negative_prompt' \
248
+ seed=42
249
+ ```
250
+
251
+ ### Tutorials
252
+
253
+ 🚧 In Development
254
+
255
+ ---
256
+
257
+ ## Contributing
258
+
259
+ We welcome contributions to support more models and features.
260
+
261
+ ---
262
+
263
+ ## Team
264
+
265
+ Maintained by [HCP-Lab at Sun Yat-sen University](https://www.sysu-hcp.net/).
266
+
267
+ ---
268
+
269
+ ## Citation
270
+
271
+ ```bibtex
272
+ @article{DBLP:journals/corr/abs-2211-11337,
273
+ author = {Ziyi Dong and
274
+ Pengxu Wei and
275
+ Liang Lin},
276
+ title = {DreamArtist: Towards Controllable One-Shot Text-to-Image Generation
277
+ via Positive-Negative Prompt-Tuning},
278
+ journal = {CoRR},
279
+ volume = {abs/2211.11337},
280
+ year = {2022},
281
+ doi = {10.48550/arXiv.2211.11337},
282
+ eprinttype = {arXiv},
283
+ eprint = {2211.11337},
284
+ }
285
+ ```
@@ -0,0 +1,114 @@
1
+ hcpdiff/__init__.py,sha256=dwNwrEgvG4g60fGMG6b50K3q3AWD1XCfzlIgbxkSUpE,177
2
+ hcpdiff/train_colo.py,sha256=EsuNSzLBvGTZWU_LEk0JpP-F5eNW0lwkawIRAX38jmE,9250
3
+ hcpdiff/train_deepspeed.py,sha256=PwyNukWi0of6TXy_VRDgBQSMLCZBhipO5g3Lq0nCYNk,2988
4
+ hcpdiff/trainer_ac.py,sha256=6KAzo54in7ZRHud_rHjJdwRRZ4uWtc0B4SxVCxgcrmM,2990
5
+ hcpdiff/trainer_ac_single.py,sha256=0PIC5EScqcxp49EaeIWq4KS5K_09OZfKajqbFu-hUb8,1108
6
+ hcpdiff/ckpt_manager/__init__.py,sha256=LfMwz9R4jV4xpiSFt5vhpwaF7-8UHEZ_iDoW-3QGvt0,239
7
+ hcpdiff/ckpt_manager/ckpt.py,sha256=Pa3uXQbCi2T99mpV5fYddQ-OGHcpk8r1ll-0lmP_WXk,965
8
+ hcpdiff/ckpt_manager/loader.py,sha256=Ch1xsZmseq4nyPhpox9-nebN-dZB4k0rqBEHos-ZLso,3245
9
+ hcpdiff/ckpt_manager/format/__init__.py,sha256=a3cdKkOTDgdVbDQwSC4mlxOigjX2hBvRb5_X7E3TQWs,237
10
+ hcpdiff/ckpt_manager/format/diffusers.py,sha256=T81WN95Nj1il9DfQp9iioVn0uqFEWOlmdIYs2beNOFU,3769
11
+ hcpdiff/ckpt_manager/format/emb.py,sha256=FrqfTfJ8H7f0Zw17NTWCP2AJtpsJI5oXR5IAd4NekhU,680
12
+ hcpdiff/ckpt_manager/format/lora_webui.py,sha256=j7SpXnSx_Ys8tnWBgojuB1HEJIm46lhCBuNNYLhaF9w,9824
13
+ hcpdiff/ckpt_manager/format/sd_single.py,sha256=LpCAL_7nAVooCHTFznVVsNMku1G3C77NBORxxr8GDtQ,2328
14
+ hcpdiff/data/__init__.py,sha256=-z47HsEQSubc-AfriVComMACbQXlXTWAKMOPBkATHxA,258
15
+ hcpdiff/data/dataset.py,sha256=1k4GldW13eVyqK_9hrQniqr3_XYAapnWF7iXl_1GXGg,877
16
+ hcpdiff/data/cache/__init__.py,sha256=ToCmokYH6DghlSwm7HJFirPRIWJ0LkgzqVOYlgoAkQw,25
17
+ hcpdiff/data/cache/vae.py,sha256=gB89zs4CdNlvukDXhVYU9QZrY6VTFUWfzjeF2psNQ50,4070
18
+ hcpdiff/data/handler/__init__.py,sha256=D1HyqY0qfrUHgf25itpYj57JUvgn06G6EQ9d2vRRtys,236
19
+ hcpdiff/data/handler/controlnet.py,sha256=bRDMD9BP8-VaG5VrxzvcFKfkqeTbChNfrJSZ3vXbQgY,658
20
+ hcpdiff/data/handler/diffusion.py,sha256=8n60UYdGNR08xw45HoI4EB5AaIui03tSGNDfjazO-5w,3516
21
+ hcpdiff/data/handler/text.py,sha256=gOzqB2oEkEUbiuy0kZWduo0c-w4Buu60KI6q6Nyl3aM,4208
22
+ hcpdiff/data/source/__init__.py,sha256=AB1VicA272KjTm-Q5L6XvDM8CLQhVPylAPuPMtpfw4g,158
23
+ hcpdiff/data/source/folder_class.py,sha256=bs4qPMTzwcnT6ZFlT3tpi9sclsRF9a2MBA1pQD-9EYs,961
24
+ hcpdiff/data/source/text2img.py,sha256=MWXqAEbzmK6pkBY40t9u37ngY25mgdKQ2idwNld8-bo,1826
25
+ hcpdiff/data/source/text2img_cond.py,sha256=yj1KpARA2rkjENutnnzC4uDkcU2Rye21FL2VdC25Hac,585
26
+ hcpdiff/diffusion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
+ hcpdiff/diffusion/noise/__init__.py,sha256=seBpOtd0YsU53PqMn7Nyl_RtwoC-ONEIOX7v2XLGpZQ,93
28
+ hcpdiff/diffusion/noise/pyramid_noise.py,sha256=KbpyMT1BHNIaAa7g5eECDkTttOMoMWVFmbP-ekBsuEY,1693
29
+ hcpdiff/diffusion/noise/zero_terminal.py,sha256=EfVOaqrTCfw11AolDBl0LIOey3uQT1bDw2XKr2Bm434,1532
30
+ hcpdiff/diffusion/sampler/__init__.py,sha256=pSHsKpLjscY5yLbdzHeBUeK9nFDuVeMIIeA_k6FQFdY,158
31
+ hcpdiff/diffusion/sampler/base.py,sha256=2AuPVT2ZSXYt2etZmHMyNKuGlT5zn6KIkoMz4m5PGcs,2577
32
+ hcpdiff/diffusion/sampler/ddpm.py,sha256=raqSuKsEPN1AEqRVCuBdMAOnKDoeJTRO17wtLBNJCf4,523
33
+ hcpdiff/diffusion/sampler/diffusers.py,sha256=XIu-oIlT4LnAYI2-yyIoNeIMeSQe5YpeEvn-FkGVFnE,2684
34
+ hcpdiff/diffusion/sampler/edm.py,sha256=5W4pv8hxQsPpJGiFBgElZxR3C9S8kWAhzGKejEVwq3I,753
35
+ hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py,sha256=kmIoWgsWqij6b7KYon3UOCSC07sRJCo-DPR6qkJwUd0,184
36
+ hcpdiff/diffusion/sampler/sigma_scheduler/base.py,sha256=9-JI-jwf7xZoQUtrU0qfbjkhNZT8a_tmapLtwVbFUx0,381
37
+ hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py,sha256=2PMIpg2K6CVoxew1y1pIqvCHbdggC-m3amrOYk15OdQ,8107
38
+ hcpdiff/diffusion/sampler/sigma_scheduler/edm.py,sha256=fOPB3lgnS9uVo4oW26Fur_nc8X_wQ6mmUcbkKhnoQjs,1900
39
+ hcpdiff/easy/__init__.py,sha256=-emoyCOZlLCu3KNMI8L4qapUEtEYFSoiGU6-rKv1at4,149
40
+ hcpdiff/easy/sampler.py,sha256=dQSBkeGh71O0DAmZLhTHTbk1bY7XzyUCeW1oJO14A4I,1250
41
+ hcpdiff/easy/cfg/__init__.py,sha256=aVDEDPxHdX5n-aFkP_4ic8ZhQfSeKu8lZOkgW_4m398,221
42
+ hcpdiff/easy/cfg/sd15_train.py,sha256=LRCJLHNU0JEd1m3MC_NFWUCw5LmwztiLiJlV7u_DeKM,6493
43
+ hcpdiff/easy/cfg/sdxl_train.py,sha256=R0wolSVOrRlI9A-vAfz592SzSnwuDd4ku1oc5yRKrfU,4038
44
+ hcpdiff/easy/cfg/t2i.py,sha256=6Pyy4werXNalwoBBHVMBLBg67kMS85Heb7R3t26GJqQ,6871
45
+ hcpdiff/easy/model/__init__.py,sha256=CA-7r3R2Jgweekk1XNByFYttLolbWyUV2bCnXygcD8w,133
46
+ hcpdiff/easy/model/cnet.py,sha256=m0NTH9V1kLzb5GybwBrSNT0KvTcRpPfGkzUeMz9jZZQ,1084
47
+ hcpdiff/easy/model/loader.py,sha256=Tdx-lhQEYf2NYjVM1A5B8x6ZZpJKcXUkFIPIbr7h7XM,3456
48
+ hcpdiff/evaluate/__init__.py,sha256=CtNzi8xdUWZBDBcP5TZTDMcRyykaOJhBIxTJgVuMabo,35
49
+ hcpdiff/evaluate/previewer.py,sha256=QiYYiEJBKP06uL3wKLhnpIGUZuAkr9BuHxTE29obpXI,2432
50
+ hcpdiff/loss/__init__.py,sha256=2dwPczSiv3rB5fzOeYbl5ZHpMU-qXOQlXeOiXdxcxwM,173
51
+ hcpdiff/loss/base.py,sha256=3bvgMbwyPOEA9iSkv0hRHw4VnKjkUCZAENNnDMFilYM,1780
52
+ hcpdiff/loss/gw.py,sha256=0yi1kozuII3xZA6FnjOhINtvScWt1MyBZLBtMKmgojM,1224
53
+ hcpdiff/loss/ssim.py,sha256=YofadvBkc6sklxBUx1p3ADw5OHOZPK3kaHz8FH5a6m4,1281
54
+ hcpdiff/loss/vlb.py,sha256=s78iBnXUiDWfGf7mYmhUnHqxqea5gSByKOoqBrX6bzU,3222
55
+ hcpdiff/loss/weighting.py,sha256=UR2PyZ1JTNOydXMw4e1Fh52XmtwKaviPvddcmVKCTlI,2242
56
+ hcpdiff/models/__init__.py,sha256=eQS7DPiGLiE1MFRkZj_17IY3IsfDUVcYpcOmhHb5B9o,472
57
+ hcpdiff/models/cfg_context.py,sha256=e2B3K1KwJhzbD6xdJUOyNtl_XgQ0296XI3FHw3gvZF4,1502
58
+ hcpdiff/models/container.py,sha256=z3p5TmQhxdzXSIfofz55_bmEhSsgUJsy1o9EcDs8Oeo,696
59
+ hcpdiff/models/controlnet.py,sha256=VIkUzJCVpCqqQOtRSLQPfbcDy9CsXutxLeZB6PdZfA0,7809
60
+ hcpdiff/models/lora_base.py,sha256=LGwBD9KP6qf4pgTx24i5-JLo4rDBQ6jFfterQKBjTbE,6758
61
+ hcpdiff/models/lora_base_patch.py,sha256=WW3CULnROTxKXyynJiqirhHYCKN5JtxLhVpT5b7AUQg,6532
62
+ hcpdiff/models/lora_layers.py,sha256=O9W_Ue71lHj7Y_GbpioF4Hc3h2-z_zOqck93VYUra6s,7777
63
+ hcpdiff/models/lora_layers_patch.py,sha256=GYFYsJD2VSLZfdnLma9CmQEHz09HROFJcc4wc_gs9f0,8198
64
+ hcpdiff/models/text_emb_ex.py,sha256=a5QImxzvj0zWR12qXOPP9kmpESl8J9VLabA0W9D_i_c,7867
65
+ hcpdiff/models/textencoder_ex.py,sha256=JrTQ30Avx8tPbdr-Q6K5BvEWCEdsu8Z7eSOzMqpUuzg,8270
66
+ hcpdiff/models/tokenizer_ex.py,sha256=zKUn4BY7b3yXwK9PWkZtQKJPyKYwUc07E-hwB9NQybs,2446
67
+ hcpdiff/models/compose/__init__.py,sha256=lTNFTGg5csqvUuys22RqgjmWlk_7Okw6ZTsnTi1pqCg,217
68
+ hcpdiff/models/compose/compose_hook.py,sha256=FfDSfn5FuLFGM80HMUwiUopy1P4xDbvKSBDuA6QK2So,6112
69
+ hcpdiff/models/compose/compose_textencoder.py,sha256=tiFoStKOIEH9YzsZQrLki4gra18kMy3wSzSUrVQG1sk,6607
70
+ hcpdiff/models/compose/compose_tokenizer.py,sha256=g3l0pOFv6p7Iigxm6Pqt_iTUXBlO1_SWAQOt0m54IoE,3033
71
+ hcpdiff/models/compose/sdxl_composer.py,sha256=NtMGaFGZTfKsPJSVi2yT-UM6K1WKWtk99XxVmTcKlk8,2164
72
+ hcpdiff/models/wrapper/__init__.py,sha256=HbGQmFnfccr-dtvZKjEv-pmR4cCnF4fwGLKS3tuG_OY,135
73
+ hcpdiff/models/wrapper/pixart.py,sha256=nRUvHSHn4TYg_smC0xpeW-GtUgXss-MuaVPTHpMozDE,1147
74
+ hcpdiff/models/wrapper/sd.py,sha256=D7VDI4OmbLTk9mzYta-C4LJjWfZmuBiDub4t8v1-M9o,11711
75
+ hcpdiff/models/wrapper/utils.py,sha256=NyebMoAPnrgcTHbiIocSD-eGdGdD-V1G_TQuWsRWufw,665
76
+ hcpdiff/parser/__init__.py,sha256=-2dDZ2Ii4zoGQqDTme94q4PpJbBiV6HS5BsDASz4Xbo,33
77
+ hcpdiff/parser/embpt.py,sha256=LgwZ0f0tLn3DrTo5ZpSCsZcA5330UpiW_sK96yEPmOM,1307
78
+ hcpdiff/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
79
+ hcpdiff/tools/convert_caption_txt2json.py,sha256=tbBgIphJWvXUoXjtwsnLX2w9IZEY3jTgxbTvUMgukbM,945
80
+ hcpdiff/tools/convert_old_lora.py,sha256=yIP9RGcyQbwT2NNAZtTLgBXs6XJOHRvoHQep0SdqDho,453
81
+ hcpdiff/tools/create_embedding.py,sha256=qbBx3mFPeVcVSeMWn5_-3Dq7EqWn0UTOQg6OurIwDeo,4577
82
+ hcpdiff/tools/dataset_generator.py,sha256=CG28IP1w0XAlEG1xPB0jQi40hVaFGxLe76tnV6hya-E,4639
83
+ hcpdiff/tools/diffusers2sd.py,sha256=MdZirlyztaaO4UAYU_AJa25eTFKfJa08nbHbdcqOdhs,13300
84
+ hcpdiff/tools/download_hf_model.py,sha256=3XZE-GYotE0W5ZQNk78ljZgolYhjbNKxfU-tB9I0yDY,947
85
+ hcpdiff/tools/embedding_convert.py,sha256=JN56dVFcMng8kq9YGS02_Umco71UjDYvwQGeqbhsJXU,1498
86
+ hcpdiff/tools/gen_from_ptlist.py,sha256=j4_wNd9JmsWUOcgkP4q4P1gB1SVnzy0VxY4pJqXM07Q,3105
87
+ hcpdiff/tools/init_proj.py,sha256=XrXxxhIaItywG7HsrloJo-x8w9suZiY35daelzZvjrg,194
88
+ hcpdiff/tools/lora_convert.py,sha256=So14WvSVIm6rU4m1XCajFXDnhq7abpZS95SLbaoyBFU,10058
89
+ hcpdiff/tools/save_model.py,sha256=gbfYi_EfEBZEUcDjle6MDHA19sQWY0zA8_y_LMzHQ7M,428
90
+ hcpdiff/tools/sd2diffusers.py,sha256=vB6OnBLw60sJkdpVZcYEPtKAZW1h8ErbSGSRq0uAiIk,16855
91
+ hcpdiff/utils/__init__.py,sha256=VOLhdNq2mRyqmWxrssIWSZtR_PQ8rFwo2u0uq6GbLHA,45
92
+ hcpdiff/utils/colo_utils.py,sha256=JyLUvVnISa48CnryNLrgVxMo-jxu2UhBq70eYPrkjuI,837
93
+ hcpdiff/utils/inpaint_pipe.py,sha256=CRy1MUlPmHifCAbZnKOP0qbLp2grn7ZbVeaB2qIA4ig,42862
94
+ hcpdiff/utils/net_utils.py,sha256=gdwLYDNKV2t3SP0jBIO3d0HtY6E7jRaf_rmPT8gKZZE,9762
95
+ hcpdiff/utils/pipe_hook.py,sha256=-UDX3FtZGl-bxSk13gdbPXc1OvtbCcpk_fvKxLQo3Ag,31987
96
+ hcpdiff/utils/utils.py,sha256=hZnZP1IETgVpScxES0yIuRfc34TnzvAqmgOTK_56ssw,4976
97
+ hcpdiff/workflow/__init__.py,sha256=t7Zyc0XFORdNvcwHp9AsCtEkhJ3l7Hm41ugngIL0Sag,867
98
+ hcpdiff/workflow/diffusion.py,sha256=yrl2cXE2d2FNeVzYZDRQNLjy5-QnVgOWioIHSmszk2Y,8662
99
+ hcpdiff/workflow/fast.py,sha256=kZt7bKrvpFInSn7GzbkTkpoCSM0Z6IbDjgaDvcbFYf8,1024
100
+ hcpdiff/workflow/flow.py,sha256=FFbFFOAXT4c31L5bHBEB_qeVGuBQDLYhq8kTD1chGNo,2548
101
+ hcpdiff/workflow/io.py,sha256=aTrMR3s44apVJpnSyvZIabW2Op0tslk_Z9JFJl5svm0,2635
102
+ hcpdiff/workflow/model.py,sha256=1gj5yOTefYTnGXVR6JPAfxIwuB69YwN6E-BontRcuyQ,2913
103
+ hcpdiff/workflow/text.py,sha256=FSFUm_zEeZjMeg0qRXZAPplnJkg2pR_2FA3XljpoN2w,5110
104
+ hcpdiff/workflow/utils.py,sha256=xojaMG4lHsymslc8df5uiVXmmBVWpn_Phqka8qzJEWw,2226
105
+ hcpdiff/workflow/vae.py,sha256=cingDPkIOc4qGpOwwhXJK4EQbGoIxO583pm6gGov5t8,3118
106
+ hcpdiff/workflow/daam/__init__.py,sha256=ySIDaxloN-D3qM7OuVaG1BR3D-CibDoXYpoTgw0zUhU,59
107
+ hcpdiff/workflow/daam/act.py,sha256=tHbsFWTYYU4bvcZOo1Bpi_z6ofpJatRYccl4vvf8wIA,2756
108
+ hcpdiff/workflow/daam/hook.py,sha256=z9f9mBjKW21xuUZ-iQxQ0HbWOBXtZrisFB0VNMq6d0U,4383
109
+ hcpdiff-2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
110
+ hcpdiff-2.1.dist-info/METADATA,sha256=NpBZuj23d1gTKPQhJ0TBRV8QsfICa4LCGSk6PJNniSw,9248
111
+ hcpdiff-2.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
112
+ hcpdiff-2.1.dist-info/entry_points.txt,sha256=86wPOMzsfWWflTJ-sQPLc7WG5Vtu0kGYBH9C_vR3ur8,207
113
+ hcpdiff-2.1.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
114
+ hcpdiff-2.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.42.0)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -0,0 +1,5 @@
1
+ [console_scripts]
2
+ hcp_run = rainbowneko.infer.infer_workflow:run_workflow
3
+ hcp_train = hcpdiff.trainer_ac:hcp_train
4
+ hcp_train_1gpu = hcpdiff.trainer_ac_single:hcp_train
5
+ hcpinit = hcpdiff.tools.init_proj:main
@@ -1,16 +0,0 @@
1
- from diffusers import StableDiffusionPipeline
2
- from diffusers.models.lora import LoRACompatibleLinear
3
-
4
- class CkptManagerBase:
5
- def __init__(self, **kwargs):
6
- pass
7
-
8
- def set_save_dir(self, save_dir, emb_dir=None):
9
- raise NotImplementedError()
10
-
11
- def save(self, step, unet, TE, lora_unet, lora_TE, all_plugin_unet, all_plugin_TE, embs, pipe: StableDiffusionPipeline, **kwargs):
12
- raise NotImplementedError()
13
-
14
- @classmethod
15
- def load(cls, pretrained_model, **kwargs) -> StableDiffusionPipeline:
16
- raise NotImplementedError
@@ -1,45 +0,0 @@
1
- from .base import CkptManagerBase
2
- import os
3
- from diffusers import StableDiffusionPipeline, UNet2DConditionModel
4
- from hcpdiff.models.plugin import BasePluginBlock
5
-
6
-
7
- class CkptManagerDiffusers(CkptManagerBase):
8
-
9
- def set_save_dir(self, save_dir, emb_dir=None):
10
- os.makedirs(save_dir, exist_ok=True)
11
- self.save_dir = save_dir
12
- self.emb_dir = emb_dir
13
-
14
- def save(self, step, unet, TE, lora_unet, lora_TE, all_plugin_unet, all_plugin_TE, embs, pipe: StableDiffusionPipeline, **kwargs):
15
- def state_dict_unet(*args, model=unet, **kwargs):
16
- plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
17
- model_sd = {}
18
- for k, v in model.state_dict_().items():
19
- for name in plugin_names:
20
- if k.startswith(name):
21
- break
22
- else:
23
- model_sd[k] = v
24
- return model_sd
25
- unet.state_dict_ = unet.state_dict
26
- unet.state_dict = state_dict_unet
27
-
28
- def state_dict_TE(*args, model=TE, **kwargs):
29
- plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
30
- model_sd = {}
31
- for k, v in model.state_dict_().items():
32
- for name in plugin_names:
33
- if k.startswith(name):
34
- break
35
- else:
36
- model_sd[k] = v
37
- return model_sd
38
- TE.state_dict_ = TE.state_dict
39
- TE.state_dict = state_dict_TE
40
-
41
- pipe.save_pretrained(os.path.join(self.save_dir, f"model-{step}"), **kwargs)
42
-
43
- @classmethod
44
- def load(cls, pretrained_model, **kwargs) -> StableDiffusionPipeline:
45
- return StableDiffusionPipeline.from_pretrained(pretrained_model, **kwargs)
@@ -1,138 +0,0 @@
1
- """
2
- ckpt_pkl.py
3
- ====================
4
- :Name: save model with torch
5
- :Author: Dong Ziyi
6
- :Affiliation: HCP Lab, SYSU
7
- :Created: 8/04/2023
8
- :Licence: MIT
9
- """
10
-
11
- from typing import Dict
12
- import os
13
-
14
- import torch
15
- from torch import nn
16
-
17
- from hcpdiff.models.lora_base import LoraBlock, LoraGroup, split_state
18
- from hcpdiff.models.plugin import PluginGroup, BasePluginBlock
19
- from hcpdiff.utils.net_utils import save_emb
20
- from .base import CkptManagerBase
21
-
22
- class CkptManagerPKL(CkptManagerBase):
23
- def __init__(self, plugin_from_raw=False, **kwargs):
24
- self.plugin_from_raw = plugin_from_raw
25
-
26
- def set_save_dir(self, save_dir, emb_dir=None):
27
- os.makedirs(save_dir, exist_ok=True)
28
- self.save_dir = save_dir
29
- self.emb_dir = emb_dir
30
-
31
- def exclude_state(self, state, key):
32
- if key is None:
33
- return state
34
- else:
35
- return {k: v for k, v in state.items() if key not in k}
36
-
37
- def save_model(self, model: nn.Module, name, step, model_ema=None, exclude_key=None):
38
- sd_model = {
39
- 'base': self.exclude_state(LoraBlock.extract_trainable_state_without_lora(model), exclude_key),
40
- }
41
- if model_ema is not None:
42
- sd_ema, sd_ema_lora = split_state(model_ema.state_dict())
43
- sd_model['base_ema'] = self.exclude_state(sd_ema, exclude_key)
44
- self._save_ckpt(sd_model, name, step)
45
-
46
- def save_plugins(self, host_model: nn.Module, plugins: Dict[str, PluginGroup], name:str, step:int, model_ema=None):
47
- if len(plugins)>0:
48
- sd_plugin={}
49
- for plugin_name, plugin in plugins.items():
50
- sd_plugin['plugin'] = plugin.state_dict(host_model if self.plugin_from_raw else None)
51
- if model_ema is not None:
52
- sd_plugin['plugin_ema'] = plugin.state_dict(model_ema)
53
- self._save_ckpt(sd_plugin, f'{name}-{plugin_name}', step)
54
-
55
- def save_model_with_lora(self, model: nn.Module, lora_blocks: LoraGroup, name:str, step:int, model_ema=None,
56
- exclude_key=None):
57
- sd_model = {
58
- 'base': self.exclude_state(BasePluginBlock.extract_state_without_plugin(model, trainable=True), exclude_key),
59
- } if model is not None else {}
60
- if (lora_blocks is not None) and (not lora_blocks.empty()):
61
- sd_model['lora'] = lora_blocks.state_dict(model if self.plugin_from_raw else None)
62
-
63
- if model_ema is not None:
64
- ema_state = model_ema.state_dict()
65
- if model is not None:
66
- sd_ema = {k:ema_state[k] for k in sd_model['base'].keys()}
67
- sd_model['base_ema'] = self.exclude_state(sd_ema, exclude_key)
68
- if (lora_blocks is not None) and (not lora_blocks.empty()):
69
- sd_model['lora_ema'] = lora_blocks.state_dict(model_ema)
70
-
71
- self._save_ckpt(sd_model, name, step)
72
-
73
- def _save_ckpt(self, sd_model, name=None, step=None, save_path=None):
74
- if save_path is None:
75
- save_path = os.path.join(self.save_dir, f"{name}-{step}.ckpt")
76
- torch.save(sd_model, save_path)
77
-
78
- def load_ckpt(self, ckpt_path, map_location='cpu'):
79
- return torch.load(ckpt_path, map_location=map_location)
80
-
81
- def load_ckpt_to_model(self, model: nn.Module, ckpt_path, model_ema=None):
82
- sd = self.load_ckpt(ckpt_path)
83
- if 'base' in sd:
84
- model.load_state_dict(sd['base'], strict=False)
85
- if 'lora' in sd:
86
- model.load_state_dict(sd['lora'], strict=False)
87
- if 'plugin' in sd:
88
- model.load_state_dict(sd['plugin'], strict=False)
89
-
90
- if model_ema is not None:
91
- if 'base' in sd:
92
- model_ema.load_state_dict(sd['base_ema'])
93
- if 'lora' in sd:
94
- model_ema.load_state_dict(sd['lora_ema'])
95
- if 'plugin' in sd:
96
- model_ema.load_state_dict(sd['plugin_ema'])
97
-
98
- def save_embedding(self, train_pts, step, replace):
99
- for k, v in train_pts.items():
100
- save_path = os.path.join(self.save_dir, f"{k}-{step}.pt")
101
- save_emb(save_path, v.data, replace=True)
102
- if replace:
103
- save_emb(f'{k}.pt', v.data, replace=True)
104
-
105
- def save(self, step, unet, TE, lora_unet, lora_TE, all_plugin_unet, all_plugin_TE, embs, pipe):
106
- '''
107
-
108
- :param step:
109
- :param unet:
110
- :param TE:
111
- :param lora_unet: [pos, neg]
112
- :param lora_TE: [pos, neg]
113
- :param all_plugin_unet:
114
- :param all_plugin_TE:
115
- :param emb:
116
- :param pipe:
117
- :return:
118
- '''
119
- self.save_model_with_lora(unet, lora_unet[0], model_ema=getattr(self, 'ema_unet', None), name='unet', step=step)
120
- self.save_plugins(unet, all_plugin_unet, name='unet', step=step, model_ema=getattr(self, 'ema_unet', None))
121
-
122
- if TE is not None:
123
- # exclude_key: embeddings should not save with text-encoder
124
- self.save_model_with_lora(TE, lora_TE[0], model_ema=getattr(self, 'ema_text_encoder', None),
125
- name='text_encoder', step=step, exclude_key='emb_ex.')
126
- self.save_plugins(TE, all_plugin_TE, name='text_encoder', step=step,
127
- model_ema=getattr(self, 'ema_text_encoder', None))
128
-
129
- if lora_unet[1] is not None:
130
- self.save_model_with_lora(None, lora_unet[1], name='unet-neg', step=step)
131
- if lora_TE[1] is not None:
132
- self.save_model_with_lora(None, lora_TE[1], name='text_encoder-neg', step=step)
133
-
134
- self.save_embedding(embs, step, False)
135
-
136
- @classmethod
137
- def load(cls, pretrained_model):
138
- raise NotImplementedError(f'{cls} dose not support load()')
@@ -1,64 +0,0 @@
1
- """
2
- ckpt_safetensors.py
3
- ====================
4
- :Name: save model with safetensors
5
- :Author: Dong Ziyi
6
- :Affiliation: HCP Lab, SYSU
7
- :Created: 8/04/2023
8
- :Licence: MIT
9
- """
10
-
11
- import os
12
- import torch
13
- from safetensors import safe_open
14
- from safetensors.torch import save_file
15
-
16
- from .ckpt_pkl import CkptManagerPKL
17
-
18
- class CkptManagerSafe(CkptManagerPKL):
19
-
20
- def _save_ckpt(self, sd_model, name=None, step=None, save_path=None):
21
- if save_path is None:
22
- save_path = os.path.join(self.save_dir, f"{name}-{step}.safetensors")
23
- sd_unfold = self.unfold_dict(sd_model)
24
- for k, v in sd_unfold.items():
25
- if not v.is_contiguous():
26
- sd_unfold[k] = v.contiguous()
27
- save_file(sd_unfold, save_path)
28
-
29
- def load_ckpt(self, ckpt_path, map_location='cpu'):
30
- with safe_open(ckpt_path, framework="pt", device=map_location) as f:
31
- sd_fold = self.fold_dict(f)
32
- return sd_fold
33
-
34
- @staticmethod
35
- def unfold_dict(data, split_key=':'):
36
- dict_unfold={}
37
-
38
- def unfold(prefix, dict_fold):
39
- for k,v in dict_fold.items():
40
- k_new = k if prefix=='' else f'{prefix}{split_key}{k}'
41
- if isinstance(v, dict):
42
- unfold(k_new, v)
43
- elif isinstance(v, list) or isinstance(v, tuple):
44
- unfold(k_new, {i:d for i,d in enumerate(v)})
45
- else:
46
- dict_unfold[k_new]=v
47
-
48
- unfold('', data)
49
- return dict_unfold
50
-
51
- @staticmethod
52
- def fold_dict(safe_f, split_key=':'):
53
- dict_fold = {}
54
-
55
- for k in safe_f.keys():
56
- k_list = k.split(split_key)
57
- dict_last = dict_fold
58
- for item in k_list[:-1]:
59
- if item not in dict_last:
60
- dict_last[item] = {}
61
- dict_last = dict_last[item]
62
- dict_last[k_list[-1]]=safe_f.get_tensor(k)
63
-
64
- return dict_fold