hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,299 @@
1
+ Metadata-Version: 2.4
2
+ Name: hcpdiff
3
+ Version: 2.2
4
+ Summary: A universal Diffusion toolbox
5
+ Home-page: https://github.com/IrisRainbowNeko/HCP-Diffusion
6
+ Author: Ziyi Dong
7
+ Author-email: rainbow-neko@outlook.com
8
+ Classifier: License :: OSI Approved :: Apache Software License
9
+ Classifier: Operating System :: OS Independent
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.8
12
+ Classifier: Programming Language :: Python :: 3.9
13
+ Classifier: Programming Language :: Python :: 3.10
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Requires-Python: >=3.8
18
+ Description-Content-Type: text/markdown
19
+ License-File: LICENSE
20
+ Requires-Dist: rainbowneko
21
+ Requires-Dist: diffusers
22
+ Requires-Dist: matplotlib
23
+ Requires-Dist: pyarrow
24
+ Requires-Dist: transformers>=4.25.1
25
+ Requires-Dist: pytorch-msssim
26
+ Requires-Dist: lmdb
27
+ Dynamic: author
28
+ Dynamic: author-email
29
+ Dynamic: classifier
30
+ Dynamic: description
31
+ Dynamic: description-content-type
32
+ Dynamic: home-page
33
+ Dynamic: license-file
34
+ Dynamic: requires-dist
35
+ Dynamic: requires-python
36
+ Dynamic: summary
37
+
38
+ # HCP-Diffusion V2
39
+
40
+ [![PyPI](https://img.shields.io/pypi/v/hcpdiff)](https://pypi.org/project/hcpdiff/)
41
+ [![GitHub stars](https://img.shields.io/github/stars/7eu7d7/HCP-Diffusion)](https://github.com/7eu7d7/HCP-Diffusion/stargazers)
42
+ [![GitHub license](https://img.shields.io/github/license/7eu7d7/HCP-Diffusion)](https://github.com/7eu7d7/HCP-Diffusion/blob/master/LICENSE)
43
+ [![codecov](https://codecov.io/gh/7eu7d7/HCP-Diffusion/branch/main/graph/badge.svg)](https://codecov.io/gh/7eu7d7/HCP-Diffusion)
44
+ [![open issues](https://isitmaintained.com/badge/open/7eu7d7/HCP-Diffusion.svg)](https://github.com/7eu7d7/HCP-Diffusion/issues)
45
+
46
+ [📘Chinese README](./README_cn.md)
47
+
48
+ [📘English document](https://hcpdiff.readthedocs.io/en/latest/)
49
+ [📘Chinese document](https://hcpdiff.readthedocs.io/zh_CN/latest/)
50
+
51
+ The old HCP-Diffusion V1 remains available on the [main branch](https://github.com/IrisRainbowNeko/HCP-Diffusion/tree/main)
52
+
53
+ ## Introduction
54
+
55
+ **HCP-Diffusion** is a Diffusion model toolbox built on top of the [🐱 RainbowNeko Engine](https://github.com/IrisRainbowNeko/RainbowNekoEngine).
56
+ It features a clean code structure and a flexible **Python-based configuration file**, making it easier to conduct and manage complex experiments. It includes a wide variety of training components, and compared to existing frameworks, it's more extensible, flexible, and user-friendly.
57
+
58
+ HCP-Diffusion allows you to use a single `.py` config file to unify training workflows across popular methods and model architectures, including Prompt-tuning (Textual Inversion), DreamArtist, Fine-tuning, DreamBooth, LoRA, ControlNet, and more.
59
+ Different techniques can also be freely combined.
60
+
61
+ This framework also implements **DreamArtist++**, an upgraded version of DreamArtist based on LoRA. It enables high generalization and controllability with just a single image for training.
62
+ Compared to the original DreamArtist, it offers better stability, image quality, and controllability, as well as faster training.
63
+
64
+ ---
65
+
66
+ ## Installation
67
+
68
+ Install [pytorch](https://pytorch.org/)
69
+
70
+ Install via pip:
71
+
72
+ ```bash
73
+ pip install hcpdiff
74
+ # Initialize configuration
75
+ hcpinit
76
+ ```
77
+
78
+ Install from source:
79
+
80
+ ```bash
81
+ git clone https://github.com/7eu7d7/HCP-Diffusion.git
82
+ cd HCP-Diffusion
83
+ pip install -e .
84
+ # Initialize configuration
85
+ hcpinit
86
+ ```
87
+
88
+ Use xFormers to reduce memory usage and accelerate training:
89
+
90
+ ```bash
91
+ # Choose the appropriate xformers version for your PyTorch version
92
+ pip install xformers==?
93
+ ```
94
+
95
+ ## 🚀 Python Configuration Files
96
+ RainbowNeko Engine supports configuration files written in a Python-like syntax. This allows users to call functions and classes directly within the configuration file, with function parameters inheritable from parent configuration files. The framework automatically handles the formatting of these configuration files.
97
+
98
+ For example, consider the following configuration file:
99
+ ```python
100
+ dict(
101
+     layer=Linear(in_features=4, out_features=4)
102
+ )
103
+ ```
104
+ During parsing, this will be automatically compiled into:
105
+ ```python
106
+ dict(
107
+     layer=dict(_target_=Linear, in_features=4, out_features=4)
108
+ )
109
+ ```
110
+ After parsing, the framework will instantiate the components accordingly. This means users can write configuration files using familiar Python syntax.
111
+
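+ As a rough illustration of this instantiation step, a `_target_` dict can be resolved into a live object by treating `_target_` as the callable and the remaining keys as its arguments. The helper below is a minimal, hypothetical sketch of such a hydra-style resolver, not the engine's actual loader:
+
+ ```python
+ import torch.nn as nn
+
+ def instantiate(cfg):
+     # Hypothetical helper: call `_target_` with the remaining keys as
+     # keyword arguments, recursing into nested dicts first.
+     if isinstance(cfg, dict) and "_target_" in cfg:
+         kwargs = {k: instantiate(v) for k, v in cfg.items() if k != "_target_"}
+         return cfg["_target_"](**kwargs)
+     return cfg
+
+ compiled = dict(layer=dict(_target_=nn.Linear, in_features=4, out_features=4))
+ parts = {k: instantiate(v) for k, v in compiled.items()}
+ # parts["layer"] is now an nn.Linear(in_features=4, out_features=4) instance
+ ```
+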
112
+ ---
113
+
114
+ ## ✨ Features
115
+
116
+ <details>
117
+ <summary>Features</summary>
118
+
119
+ ### 📦 Model Support
120
+
121
+ | Model Name | Status |
122
+ |--------------------------|-------------|
123
+ | Stable Diffusion 1.5 | ✅ Supported |
124
+ | Stable Diffusion XL (SDXL)| ✅ Supported |
125
+ | PixArt | ✅ Supported |
126
+ | FLUX | 🚧 In Development |
127
+ | Stable Diffusion 3 (SD3) | 🚧 In Development |
128
+
129
+ ---
130
+
131
+ ### 🧠 Fine-Tuning Capabilities
132
+
133
+ | Feature | Description/Support |
134
+ |----------------------------------|---------------------|
135
+ | LoRA Layer-wise Configuration | ✅ Supported (including Conv2d) |
136
+ | Layer-wise Fine-Tuning | ✅ Supported |
137
+ | Multi-token Prompt-Tuning | ✅ Supported |
138
+ | Layer-wise Model Merging | ✅ Supported |
139
+ | Custom Optimizers | ✅ Supported (Lion, DAdaptation, pytorch-optimizer, etc.) |
140
+ | Custom LR Schedulers | ✅ Supported |
141
+
142
+ ---
143
+
144
+ ### 🧩 Extension Method Support
145
+
146
+ | Method | Status |
147
+ |--------------------------------|-------------|
148
+ | ControlNet (including training)| ✅ Supported |
149
+ | DreamArtist / DreamArtist++ | ✅ Supported |
150
+ | Token Attention Adjustment | ✅ Supported |
151
+ | Max Sentence Length Extension | ✅ Supported |
152
+ | Textual Inversion (Custom Tokens)| ✅ Supported |
153
+ | CLIP Skip | ✅ Supported |
154
+
155
+ ---
156
+
157
+ ### 🚀 Training Acceleration
158
+
159
+ | Tool/Library | Supported Modules |
160
+ |---------------------------------------------------|---------------------------|
161
+ | [🤗 Accelerate](https://github.com/huggingface/accelerate) | ✅ Supported |
162
+ | [Colossal-AI](https://github.com/hpcaitech/ColossalAI) | ✅ Supported |
163
+ | [xFormers](https://github.com/facebookresearch/xformers) | ✅ Supported (UNet and text encoder) |
164
+
165
+ ---
166
+
167
+ ### 🗂 Dataset Support
168
+
169
+ | Feature | Description |
170
+ |----------------------------------|-------------|
171
+ | Aspect Ratio Bucket (ARB) | ✅ Auto-clustering supported |
172
+ | Multi-source / Multi-dataset | ✅ Supported |
173
+ | LMDB | ✅ Supported |
174
+ | webdataset | 🚧 In Development |
175
+ | Local Attention Enhancement | ✅ Supported |
176
+ | Tag Shuffling & Dropout | ✅ Multiple tag editing strategies |
177
+
178
+ ---
179
+
180
+ ### 📉 Supported Loss Functions
181
+
182
+ | Loss Type | Description |
183
+ |------------|-------------|
184
+ | Min-SNR | ✅ Supported |
185
+ | SSIM | ✅ Supported |
186
+ | GWLoss | ✅ Supported |
187
+
188
+ ---
189
+
190
+ ### 🌫 Supported Diffusion Strategies
191
+
192
+ | Strategy Type | Status |
193
+ |------------------|--------------|
194
+ | DDPM | ✅ Supported |
195
+ | EDM | ✅ Supported |
196
+ | Flow Matching | ✅ Supported |
197
+
198
+ ---
199
+
200
+ ### 🧠 Automatic Evaluation (Step Selection Assistant)
201
+
202
+ | Feature | Description/Status |
203
+ |------------------|------------------------------------------|
204
+ | Image Preview | ✅ Supported (workflow preview) |
205
+ | FID | 🚧 In Development |
206
+ | CLIP Score | 🚧 In Development |
207
+ | CCIP Score | 🚧 In Development |
208
+ | Corrupt Score | 🚧 In Development |
209
+
210
+ ---
211
+
212
+ ### ⚡️ Image Generation
213
+
214
+ | Feature | Description/Support |
215
+ |------------------------------|------------------------------------|
216
+ | Batch Generation | ✅ Supported |
217
+ | Generate from Prompt Dataset | ✅ Supported |
218
+ | Image to Image | ✅ Supported |
219
+ | Inpaint | ✅ Supported |
220
+ | Token Weight | ✅ Supported |
221
+
222
+ </details>
223
+
224
+ ---
225
+
226
+ ## Getting Started
227
+
228
+ ### Training
229
+
230
+ HCP-Diffusion provides training scripts based on 🤗 Accelerate.
231
+
232
+ ```bash
233
+ # Multi-GPU training, configure GPUs in cfgs/launcher/multi.yaml
234
+ hcp_train --cfg cfgs/train/py/your_config.py
235
+
236
+ # Single-GPU training, configure GPU in cfgs/launcher/single.yaml
237
+ hcp_train_1gpu --cfg cfgs/train/py/your_config.py
238
+ ```
239
+
240
+ You can also override config items via command line:
241
+
242
+ ```bash
243
+ # Override base model path
244
+ hcp_train --cfg cfgs/train/py/your_config.py model.wrapper.models.ckpt_path=pretrained_model_path
245
+ ```
246
+
247
+ ### Image Generation
248
+
249
+ Use the workflow defined in the Python config to generate images:
250
+
251
+ ```bash
252
+ hcp_run --cfg cfgs/workflow/text2img.py
253
+ ```
254
+
255
+ Or override parameters via command line:
256
+
257
+ ```bash
258
+ hcp_run --cfg cfgs/workflow/text2img_cli.py \
259
+ pretrained_model=pretrained_model_path \
260
+ prompt='positive_prompt' \
261
+ negative_prompt='negative_prompt' \
262
+ seed=42
263
+ ```
264
+
265
+ ### Tutorials
266
+
267
+ 🚧 In Development
268
+
269
+ ---
270
+
271
+ ## Contributing
272
+
273
+ We welcome contributions to support more models and features.
274
+
275
+ ---
276
+
277
+ ## Team
278
+
279
+ Maintained by [HCP-Lab at Sun Yat-sen University](https://www.sysu-hcp.net/).
280
+
281
+ ---
282
+
283
+ ## Citation
284
+
285
+ ```bibtex
286
+ @article{DBLP:journals/corr/abs-2211-11337,
287
+ author = {Ziyi Dong and
288
+ Pengxu Wei and
289
+ Liang Lin},
290
+ title = {DreamArtist: Towards Controllable One-Shot Text-to-Image Generation
291
+ via Positive-Negative Prompt-Tuning},
292
+ journal = {CoRR},
293
+ volume = {abs/2211.11337},
294
+ year = {2022},
295
+ doi = {10.48550/arXiv.2211.11337},
296
+ eprinttype = {arXiv},
297
+ eprint = {2211.11337},
298
+ }
299
+ ```
@@ -0,0 +1,115 @@
1
+ hcpdiff/__init__.py,sha256=dwNwrEgvG4g60fGMG6b50K3q3AWD1XCfzlIgbxkSUpE,177
2
+ hcpdiff/train_colo.py,sha256=EsuNSzLBvGTZWU_LEk0JpP-F5eNW0lwkawIRAX38jmE,9250
3
+ hcpdiff/train_deepspeed.py,sha256=PwyNukWi0of6TXy_VRDgBQSMLCZBhipO5g3Lq0nCYNk,2988
4
+ hcpdiff/trainer_ac.py,sha256=6KAzo54in7ZRHud_rHjJdwRRZ4uWtc0B4SxVCxgcrmM,2990
5
+ hcpdiff/trainer_ac_single.py,sha256=0PIC5EScqcxp49EaeIWq4KS5K_09OZfKajqbFu-hUb8,1108
6
+ hcpdiff/ckpt_manager/__init__.py,sha256=Mn_5KOC4xbf2GcN6OXg_XdbF5wO9zWeER_1ZO_prKAI,256
7
+ hcpdiff/ckpt_manager/ckpt.py,sha256=Pa3uXQbCi2T99mpV5fYddQ-OGHcpk8r1ll-0lmP_WXk,965
8
+ hcpdiff/ckpt_manager/loader.py,sha256=Ch1xsZmseq4nyPhpox9-nebN-dZB4k0rqBEHos-ZLso,3245
9
+ hcpdiff/ckpt_manager/format/__init__.py,sha256=a3cdKkOTDgdVbDQwSC4mlxOigjX2hBvRb5_X7E3TQWs,237
10
+ hcpdiff/ckpt_manager/format/diffusers.py,sha256=T81WN95Nj1il9DfQp9iioVn0uqFEWOlmdIYs2beNOFU,3769
11
+ hcpdiff/ckpt_manager/format/emb.py,sha256=FrqfTfJ8H7f0Zw17NTWCP2AJtpsJI5oXR5IAd4NekhU,680
12
+ hcpdiff/ckpt_manager/format/lora_webui.py,sha256=4y_T9RdmFTxWzsXd8guNjCiukmyILa5j4MPrhVIL4Qk,10017
13
+ hcpdiff/ckpt_manager/format/sd_single.py,sha256=LpCAL_7nAVooCHTFznVVsNMku1G3C77NBORxxr8GDtQ,2328
14
+ hcpdiff/data/__init__.py,sha256=ZFKtanOoMo3G3eKUJPhysnHXnr8BNARERkcMB6B897U,292
15
+ hcpdiff/data/dataset.py,sha256=1k4GldW13eVyqK_9hrQniqr3_XYAapnWF7iXl_1GXGg,877
16
+ hcpdiff/data/cache/__init__.py,sha256=ToCmokYH6DghlSwm7HJFirPRIWJ0LkgzqVOYlgoAkQw,25
17
+ hcpdiff/data/cache/vae.py,sha256=gB89zs4CdNlvukDXhVYU9QZrY6VTFUWfzjeF2psNQ50,4070
18
+ hcpdiff/data/handler/__init__.py,sha256=G8ZTQF91ilkTRmUoWdmAissTSZ7fvNUpm_hBYmXKTtk,258
19
+ hcpdiff/data/handler/controlnet.py,sha256=bRDMD9BP8-VaG5VrxzvcFKfkqeTbChNfrJSZ3vXbQgY,658
20
+ hcpdiff/data/handler/diffusion.py,sha256=S-_7o5Z1tm6LmRZVZs21rbJC7iUoq0tHOsSjKK6geVk,4156
21
+ hcpdiff/data/handler/text.py,sha256=gOzqB2oEkEUbiuy0kZWduo0c-w4Buu60KI6q6Nyl3aM,4208
22
+ hcpdiff/data/source/__init__.py,sha256=265M8qfWNUE4SKX0pdXhLYjCnCuae5YE4bfZpO-ydXc,187
23
+ hcpdiff/data/source/folder_class.py,sha256=bs4qPMTzwcnT6ZFlT3tpi9sclsRF9a2MBA1pQD-9EYs,961
24
+ hcpdiff/data/source/text.py,sha256=VgI5Ouq986Yy1jwD2fZ9iBlsRciPCeARZmOPEZIcaQY,1468
25
+ hcpdiff/data/source/text2img.py,sha256=acYdolQhZUEpkd7tUAdNkCTVnPc1SMJOVTmGqFt9ZpE,1813
26
+ hcpdiff/data/source/text2img_cond.py,sha256=yj1KpARA2rkjENutnnzC4uDkcU2Rye21FL2VdC25Hac,585
27
+ hcpdiff/diffusion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
+ hcpdiff/diffusion/noise/__init__.py,sha256=seBpOtd0YsU53PqMn7Nyl_RtwoC-ONEIOX7v2XLGpZQ,93
29
+ hcpdiff/diffusion/noise/pyramid_noise.py,sha256=KbpyMT1BHNIaAa7g5eECDkTttOMoMWVFmbP-ekBsuEY,1693
30
+ hcpdiff/diffusion/noise/zero_terminal.py,sha256=EfVOaqrTCfw11AolDBl0LIOey3uQT1bDw2XKr2Bm434,1532
31
+ hcpdiff/diffusion/sampler/__init__.py,sha256=pSHsKpLjscY5yLbdzHeBUeK9nFDuVeMIIeA_k6FQFdY,158
32
+ hcpdiff/diffusion/sampler/base.py,sha256=2AuPVT2ZSXYt2etZmHMyNKuGlT5zn6KIkoMz4m5PGcs,2577
33
+ hcpdiff/diffusion/sampler/ddpm.py,sha256=raqSuKsEPN1AEqRVCuBdMAOnKDoeJTRO17wtLBNJCf4,523
34
+ hcpdiff/diffusion/sampler/diffusers.py,sha256=XIu-oIlT4LnAYI2-yyIoNeIMeSQe5YpeEvn-FkGVFnE,2684
35
+ hcpdiff/diffusion/sampler/edm.py,sha256=5W4pv8hxQsPpJGiFBgElZxR3C9S8kWAhzGKejEVwq3I,753
36
+ hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py,sha256=kmIoWgsWqij6b7KYon3UOCSC07sRJCo-DPR6qkJwUd0,184
37
+ hcpdiff/diffusion/sampler/sigma_scheduler/base.py,sha256=9-JI-jwf7xZoQUtrU0qfbjkhNZT8a_tmapLtwVbFUx0,381
38
+ hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py,sha256=2PMIpg2K6CVoxew1y1pIqvCHbdggC-m3amrOYk15OdQ,8107
39
+ hcpdiff/diffusion/sampler/sigma_scheduler/edm.py,sha256=fOPB3lgnS9uVo4oW26Fur_nc8X_wQ6mmUcbkKhnoQjs,1900
40
+ hcpdiff/easy/__init__.py,sha256=-emoyCOZlLCu3KNMI8L4qapUEtEYFSoiGU6-rKv1at4,149
41
+ hcpdiff/easy/sampler.py,sha256=dQSBkeGh71O0DAmZLhTHTbk1bY7XzyUCeW1oJO14A4I,1250
42
+ hcpdiff/easy/cfg/__init__.py,sha256=SxHMWG6T2CXhX3dP0xizSMd9vFWPaZQDc4Gj4CF__yQ,253
43
+ hcpdiff/easy/cfg/sd15_train.py,sha256=KtplqN-OhzdZjsX2s60J3XR6o7tRJ-QDx7Eqza_eDkM,6704
44
+ hcpdiff/easy/cfg/sdxl_train.py,sha256=ZKfJ19IvR2dZqDNXULmhZEmqjE7qV4QYxSTvEhI7efQ,4269
45
+ hcpdiff/easy/cfg/t2i.py,sha256=SnjFjZAKd9orjJr3RW5_N2_EIlW2Ree7JMvdNUAR9gc,9507
46
+ hcpdiff/easy/model/__init__.py,sha256=CA-7r3R2Jgweekk1XNByFYttLolbWyUV2bCnXygcD8w,133
47
+ hcpdiff/easy/model/cnet.py,sha256=m0NTH9V1kLzb5GybwBrSNT0KvTcRpPfGkzUeMz9jZZQ,1084
48
+ hcpdiff/easy/model/loader.py,sha256=Tdx-lhQEYf2NYjVM1A5B8x6ZZpJKcXUkFIPIbr7h7XM,3456
49
+ hcpdiff/evaluate/__init__.py,sha256=CtNzi8xdUWZBDBcP5TZTDMcRyykaOJhBIxTJgVuMabo,35
50
+ hcpdiff/evaluate/previewer.py,sha256=QiYYiEJBKP06uL3wKLhnpIGUZuAkr9BuHxTE29obpXI,2432
51
+ hcpdiff/loss/__init__.py,sha256=2dwPczSiv3rB5fzOeYbl5ZHpMU-qXOQlXeOiXdxcxwM,173
52
+ hcpdiff/loss/base.py,sha256=3bvgMbwyPOEA9iSkv0hRHw4VnKjkUCZAENNnDMFilYM,1780
53
+ hcpdiff/loss/gw.py,sha256=0yi1kozuII3xZA6FnjOhINtvScWt1MyBZLBtMKmgojM,1224
54
+ hcpdiff/loss/ssim.py,sha256=YofadvBkc6sklxBUx1p3ADw5OHOZPK3kaHz8FH5a6m4,1281
55
+ hcpdiff/loss/vlb.py,sha256=s78iBnXUiDWfGf7mYmhUnHqxqea5gSByKOoqBrX6bzU,3222
56
+ hcpdiff/loss/weighting.py,sha256=UR2PyZ1JTNOydXMw4e1Fh52XmtwKaviPvddcmVKCTlI,2242
57
+ hcpdiff/models/__init__.py,sha256=eQS7DPiGLiE1MFRkZj_17IY3IsfDUVcYpcOmhHb5B9o,472
58
+ hcpdiff/models/cfg_context.py,sha256=e2B3K1KwJhzbD6xdJUOyNtl_XgQ0296XI3FHw3gvZF4,1502
59
+ hcpdiff/models/container.py,sha256=z3p5TmQhxdzXSIfofz55_bmEhSsgUJsy1o9EcDs8Oeo,696
60
+ hcpdiff/models/controlnet.py,sha256=VIkUzJCVpCqqQOtRSLQPfbcDy9CsXutxLeZB6PdZfA0,7809
61
+ hcpdiff/models/lora_base.py,sha256=LGwBD9KP6qf4pgTx24i5-JLo4rDBQ6jFfterQKBjTbE,6758
62
+ hcpdiff/models/lora_base_patch.py,sha256=WW3CULnROTxKXyynJiqirhHYCKN5JtxLhVpT5b7AUQg,6532
63
+ hcpdiff/models/lora_layers.py,sha256=O9W_Ue71lHj7Y_GbpioF4Hc3h2-z_zOqck93VYUra6s,7777
64
+ hcpdiff/models/lora_layers_patch.py,sha256=GYFYsJD2VSLZfdnLma9CmQEHz09HROFJcc4wc_gs9f0,8198
65
+ hcpdiff/models/text_emb_ex.py,sha256=a5QImxzvj0zWR12qXOPP9kmpESl8J9VLabA0W9D_i_c,7867
66
+ hcpdiff/models/textencoder_ex.py,sha256=JrTQ30Avx8tPbdr-Q6K5BvEWCEdsu8Z7eSOzMqpUuzg,8270
67
+ hcpdiff/models/tokenizer_ex.py,sha256=zKUn4BY7b3yXwK9PWkZtQKJPyKYwUc07E-hwB9NQybs,2446
68
+ hcpdiff/models/compose/__init__.py,sha256=lTNFTGg5csqvUuys22RqgjmWlk_7Okw6ZTsnTi1pqCg,217
69
+ hcpdiff/models/compose/compose_hook.py,sha256=FfDSfn5FuLFGM80HMUwiUopy1P4xDbvKSBDuA6QK2So,6112
70
+ hcpdiff/models/compose/compose_textencoder.py,sha256=tiFoStKOIEH9YzsZQrLki4gra18kMy3wSzSUrVQG1sk,6607
71
+ hcpdiff/models/compose/compose_tokenizer.py,sha256=g3l0pOFv6p7Iigxm6Pqt_iTUXBlO1_SWAQOt0m54IoE,3033
72
+ hcpdiff/models/compose/sdxl_composer.py,sha256=NtMGaFGZTfKsPJSVi2yT-UM6K1WKWtk99XxVmTcKlk8,2164
73
+ hcpdiff/models/wrapper/__init__.py,sha256=HbGQmFnfccr-dtvZKjEv-pmR4cCnF4fwGLKS3tuG_OY,135
74
+ hcpdiff/models/wrapper/pixart.py,sha256=nRUvHSHn4TYg_smC0xpeW-GtUgXss-MuaVPTHpMozDE,1147
75
+ hcpdiff/models/wrapper/sd.py,sha256=D7VDI4OmbLTk9mzYta-C4LJjWfZmuBiDub4t8v1-M9o,11711
76
+ hcpdiff/models/wrapper/utils.py,sha256=NyebMoAPnrgcTHbiIocSD-eGdGdD-V1G_TQuWsRWufw,665
77
+ hcpdiff/parser/__init__.py,sha256=-2dDZ2Ii4zoGQqDTme94q4PpJbBiV6HS5BsDASz4Xbo,33
78
+ hcpdiff/parser/embpt.py,sha256=LgwZ0f0tLn3DrTo5ZpSCsZcA5330UpiW_sK96yEPmOM,1307
79
+ hcpdiff/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
80
+ hcpdiff/tools/convert_caption_txt2json.py,sha256=tbBgIphJWvXUoXjtwsnLX2w9IZEY3jTgxbTvUMgukbM,945
81
+ hcpdiff/tools/convert_old_lora.py,sha256=yIP9RGcyQbwT2NNAZtTLgBXs6XJOHRvoHQep0SdqDho,453
82
+ hcpdiff/tools/create_embedding.py,sha256=qbBx3mFPeVcVSeMWn5_-3Dq7EqWn0UTOQg6OurIwDeo,4577
83
+ hcpdiff/tools/dataset_generator.py,sha256=CG28IP1w0XAlEG1xPB0jQi40hVaFGxLe76tnV6hya-E,4639
84
+ hcpdiff/tools/diffusers2sd.py,sha256=MdZirlyztaaO4UAYU_AJa25eTFKfJa08nbHbdcqOdhs,13300
85
+ hcpdiff/tools/download_hf_model.py,sha256=3XZE-GYotE0W5ZQNk78ljZgolYhjbNKxfU-tB9I0yDY,947
86
+ hcpdiff/tools/embedding_convert.py,sha256=JN56dVFcMng8kq9YGS02_Umco71UjDYvwQGeqbhsJXU,1498
87
+ hcpdiff/tools/gen_from_ptlist.py,sha256=j4_wNd9JmsWUOcgkP4q4P1gB1SVnzy0VxY4pJqXM07Q,3105
88
+ hcpdiff/tools/init_proj.py,sha256=XrXxxhIaItywG7HsrloJo-x8w9suZiY35daelzZvjrg,194
89
+ hcpdiff/tools/lora_convert.py,sha256=So14WvSVIm6rU4m1XCajFXDnhq7abpZS95SLbaoyBFU,10058
90
+ hcpdiff/tools/save_model.py,sha256=gbfYi_EfEBZEUcDjle6MDHA19sQWY0zA8_y_LMzHQ7M,428
91
+ hcpdiff/tools/sd2diffusers.py,sha256=vB6OnBLw60sJkdpVZcYEPtKAZW1h8ErbSGSRq0uAiIk,16855
92
+ hcpdiff/utils/__init__.py,sha256=VOLhdNq2mRyqmWxrssIWSZtR_PQ8rFwo2u0uq6GbLHA,45
93
+ hcpdiff/utils/colo_utils.py,sha256=JyLUvVnISa48CnryNLrgVxMo-jxu2UhBq70eYPrkjuI,837
94
+ hcpdiff/utils/inpaint_pipe.py,sha256=CRy1MUlPmHifCAbZnKOP0qbLp2grn7ZbVeaB2qIA4ig,42862
95
+ hcpdiff/utils/net_utils.py,sha256=gdwLYDNKV2t3SP0jBIO3d0HtY6E7jRaf_rmPT8gKZZE,9762
96
+ hcpdiff/utils/pipe_hook.py,sha256=-UDX3FtZGl-bxSk13gdbPXc1OvtbCcpk_fvKxLQo3Ag,31987
97
+ hcpdiff/utils/utils.py,sha256=hZnZP1IETgVpScxES0yIuRfc34TnzvAqmgOTK_56ssw,4976
98
+ hcpdiff/workflow/__init__.py,sha256=t7Zyc0XFORdNvcwHp9AsCtEkhJ3l7Hm41ugngIL0Sag,867
99
+ hcpdiff/workflow/diffusion.py,sha256=yzhqKA3019OPu1RKggrLoytMgm919qf6j9S85PYOwjQ,8644
100
+ hcpdiff/workflow/fast.py,sha256=kZt7bKrvpFInSn7GzbkTkpoCSM0Z6IbDjgaDvcbFYf8,1024
101
+ hcpdiff/workflow/flow.py,sha256=FFbFFOAXT4c31L5bHBEB_qeVGuBQDLYhq8kTD1chGNo,2548
102
+ hcpdiff/workflow/io.py,sha256=aTrMR3s44apVJpnSyvZIabW2Op0tslk_Z9JFJl5svm0,2635
103
+ hcpdiff/workflow/model.py,sha256=1gj5yOTefYTnGXVR6JPAfxIwuB69YwN6E-BontRcuyQ,2913
104
+ hcpdiff/workflow/text.py,sha256=Z__SJHZyuaKyzkYJ6rbiAzOGRiYcCjwCGeqfpP1Jo7o,4336
105
+ hcpdiff/workflow/utils.py,sha256=xojaMG4lHsymslc8df5uiVXmmBVWpn_Phqka8qzJEWw,2226
106
+ hcpdiff/workflow/vae.py,sha256=cingDPkIOc4qGpOwwhXJK4EQbGoIxO583pm6gGov5t8,3118
107
+ hcpdiff/workflow/daam/__init__.py,sha256=ySIDaxloN-D3qM7OuVaG1BR3D-CibDoXYpoTgw0zUhU,59
108
+ hcpdiff/workflow/daam/act.py,sha256=tHbsFWTYYU4bvcZOo1Bpi_z6ofpJatRYccl4vvf8wIA,2756
109
+ hcpdiff/workflow/daam/hook.py,sha256=z9f9mBjKW21xuUZ-iQxQ0HbWOBXtZrisFB0VNMq6d0U,4383
110
+ hcpdiff-2.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
111
+ hcpdiff-2.2.dist-info/METADATA,sha256=u52mZtA0hI2P_fObmJZRUkZZfnKFYg5c24f4p0trH0o,9833
112
+ hcpdiff-2.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
113
+ hcpdiff-2.2.dist-info/entry_points.txt,sha256=86wPOMzsfWWflTJ-sQPLc7WG5Vtu0kGYBH9C_vR3ur8,207
114
+ hcpdiff-2.2.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
115
+ hcpdiff-2.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.42.0)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -0,0 +1,5 @@
1
+ [console_scripts]
2
+ hcp_run = rainbowneko.infer.infer_workflow:run_workflow
3
+ hcp_train = hcpdiff.trainer_ac:hcp_train
4
+ hcp_train_1gpu = hcpdiff.trainer_ac_single:hcp_train
5
+ hcpinit = hcpdiff.tools.init_proj:main
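Each `console_scripts` entry above maps a command name to a `module:function` target that setuptools wraps in an executable launcher, which is how the `hcp_train`, `hcp_train_1gpu`, `hcp_run`, and `hcpinit` commands used in the README resolve to importable callables. A minimal sketch of the equivalent direct call for the `hcpinit` entry:

```python
# Roughly what the generated `hcpinit` launcher does after `pip install hcpdiff`:
# import the target module and call the named function with no arguments.
from hcpdiff.tools.init_proj import main

main()
```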
@@ -1,16 +0,0 @@
1
- from diffusers import StableDiffusionPipeline
2
- from diffusers.models.lora import LoRACompatibleLinear
3
-
4
- class CkptManagerBase:
5
-     def __init__(self, **kwargs):
6
-         pass
7
-
8
-     def set_save_dir(self, save_dir, emb_dir=None):
9
-         raise NotImplementedError()
10
-
11
-     def save(self, step, unet, TE, lora_unet, lora_TE, all_plugin_unet, all_plugin_TE, embs, pipe: StableDiffusionPipeline, **kwargs):
12
-         raise NotImplementedError()
13
-
14
-     @classmethod
15
-     def load(cls, pretrained_model, **kwargs) -> StableDiffusionPipeline:
16
-         raise NotImplementedError
@@ -1,45 +0,0 @@
1
- from .base import CkptManagerBase
2
- import os
3
- from diffusers import StableDiffusionPipeline, UNet2DConditionModel
4
- from hcpdiff.models.plugin import BasePluginBlock
5
-
6
-
7
- class CkptManagerDiffusers(CkptManagerBase):
8
-
9
-     def set_save_dir(self, save_dir, emb_dir=None):
10
-         os.makedirs(save_dir, exist_ok=True)
11
-         self.save_dir = save_dir
12
-         self.emb_dir = emb_dir
13
-
14
-     def save(self, step, unet, TE, lora_unet, lora_TE, all_plugin_unet, all_plugin_TE, embs, pipe: StableDiffusionPipeline, **kwargs):
15
-         def state_dict_unet(*args, model=unet, **kwargs):
16
-             plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
17
-             model_sd = {}
18
-             for k, v in model.state_dict_().items():
19
-                 for name in plugin_names:
20
-                     if k.startswith(name):
21
-                         break
22
-                 else:
23
-                     model_sd[k] = v
24
-             return model_sd
25
-         unet.state_dict_ = unet.state_dict
26
-         unet.state_dict = state_dict_unet
27
-
28
-         def state_dict_TE(*args, model=TE, **kwargs):
29
-             plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
30
-             model_sd = {}
31
-             for k, v in model.state_dict_().items():
32
-                 for name in plugin_names:
33
-                     if k.startswith(name):
34
-                         break
35
-                 else:
36
-                     model_sd[k] = v
37
-             return model_sd
38
-         TE.state_dict_ = TE.state_dict
39
-         TE.state_dict = state_dict_TE
40
-
41
-         pipe.save_pretrained(os.path.join(self.save_dir, f"model-{step}"), **kwargs)
42
-
43
-     @classmethod
44
-     def load(cls, pretrained_model, **kwargs) -> StableDiffusionPipeline:
45
-         return StableDiffusionPipeline.from_pretrained(pretrained_model, **kwargs)
@@ -1,138 +0,0 @@
1
- """
2
- ckpt_pkl.py
3
- ====================
4
- :Name: save model with torch
5
- :Author: Dong Ziyi
6
- :Affiliation: HCP Lab, SYSU
7
- :Created: 8/04/2023
8
- :Licence: MIT
9
- """
10
-
11
- from typing import Dict
12
- import os
13
-
14
- import torch
15
- from torch import nn
16
-
17
- from hcpdiff.models.lora_base import LoraBlock, LoraGroup, split_state
18
- from hcpdiff.models.plugin import PluginGroup, BasePluginBlock
19
- from hcpdiff.utils.net_utils import save_emb
20
- from .base import CkptManagerBase
21
-
22
- class CkptManagerPKL(CkptManagerBase):
23
- def __init__(self, plugin_from_raw=False, **kwargs):
24
- self.plugin_from_raw = plugin_from_raw
25
-
26
- def set_save_dir(self, save_dir, emb_dir=None):
27
- os.makedirs(save_dir, exist_ok=True)
28
- self.save_dir = save_dir
29
- self.emb_dir = emb_dir
30
-
31
- def exclude_state(self, state, key):
32
- if key is None:
33
- return state
34
- else:
35
- return {k: v for k, v in state.items() if key not in k}
36
-
37
- def save_model(self, model: nn.Module, name, step, model_ema=None, exclude_key=None):
38
- sd_model = {
39
- 'base': self.exclude_state(LoraBlock.extract_trainable_state_without_lora(model), exclude_key),
40
- }
41
- if model_ema is not None:
42
- sd_ema, sd_ema_lora = split_state(model_ema.state_dict())
43
- sd_model['base_ema'] = self.exclude_state(sd_ema, exclude_key)
44
- self._save_ckpt(sd_model, name, step)
45
-
46
- def save_plugins(self, host_model: nn.Module, plugins: Dict[str, PluginGroup], name:str, step:int, model_ema=None):
47
- if len(plugins)>0:
48
- sd_plugin={}
49
- for plugin_name, plugin in plugins.items():
50
- sd_plugin['plugin'] = plugin.state_dict(host_model if self.plugin_from_raw else None)
51
- if model_ema is not None:
52
- sd_plugin['plugin_ema'] = plugin.state_dict(model_ema)
53
- self._save_ckpt(sd_plugin, f'{name}-{plugin_name}', step)
54
-
55
- def save_model_with_lora(self, model: nn.Module, lora_blocks: LoraGroup, name:str, step:int, model_ema=None,
56
- exclude_key=None):
57
- sd_model = {
58
- 'base': self.exclude_state(BasePluginBlock.extract_state_without_plugin(model, trainable=True), exclude_key),
59
- } if model is not None else {}
60
- if (lora_blocks is not None) and (not lora_blocks.empty()):
61
- sd_model['lora'] = lora_blocks.state_dict(model if self.plugin_from_raw else None)
62
-
63
- if model_ema is not None:
64
- ema_state = model_ema.state_dict()
65
- if model is not None:
66
- sd_ema = {k:ema_state[k] for k in sd_model['base'].keys()}
67
- sd_model['base_ema'] = self.exclude_state(sd_ema, exclude_key)
68
- if (lora_blocks is not None) and (not lora_blocks.empty()):
69
- sd_model['lora_ema'] = lora_blocks.state_dict(model_ema)
70
-
71
- self._save_ckpt(sd_model, name, step)
72
-
73
- def _save_ckpt(self, sd_model, name=None, step=None, save_path=None):
74
- if save_path is None:
75
- save_path = os.path.join(self.save_dir, f"{name}-{step}.ckpt")
76
- torch.save(sd_model, save_path)
77
-
78
- def load_ckpt(self, ckpt_path, map_location='cpu'):
79
- return torch.load(ckpt_path, map_location=map_location)
80
-
81
- def load_ckpt_to_model(self, model: nn.Module, ckpt_path, model_ema=None):
82
- sd = self.load_ckpt(ckpt_path)
83
- if 'base' in sd:
84
- model.load_state_dict(sd['base'], strict=False)
85
- if 'lora' in sd:
86
- model.load_state_dict(sd['lora'], strict=False)
87
- if 'plugin' in sd:
88
- model.load_state_dict(sd['plugin'], strict=False)
89
-
90
- if model_ema is not None:
91
- if 'base' in sd:
92
- model_ema.load_state_dict(sd['base_ema'])
93
- if 'lora' in sd:
94
- model_ema.load_state_dict(sd['lora_ema'])
95
- if 'plugin' in sd:
96
- model_ema.load_state_dict(sd['plugin_ema'])
97
-
98
- def save_embedding(self, train_pts, step, replace):
99
- for k, v in train_pts.items():
100
- save_path = os.path.join(self.save_dir, f"{k}-{step}.pt")
101
- save_emb(save_path, v.data, replace=True)
102
- if replace:
103
- save_emb(f'{k}.pt', v.data, replace=True)
104
-
105
- def save(self, step, unet, TE, lora_unet, lora_TE, all_plugin_unet, all_plugin_TE, embs, pipe):
106
- '''
107
-
108
- :param step:
109
- :param unet:
110
- :param TE:
111
- :param lora_unet: [pos, neg]
112
- :param lora_TE: [pos, neg]
113
- :param all_plugin_unet:
114
- :param all_plugin_TE:
115
- :param emb:
116
- :param pipe:
117
- :return:
118
- '''
119
- self.save_model_with_lora(unet, lora_unet[0], model_ema=getattr(self, 'ema_unet', None), name='unet', step=step)
120
- self.save_plugins(unet, all_plugin_unet, name='unet', step=step, model_ema=getattr(self, 'ema_unet', None))
121
-
122
- if TE is not None:
123
- # exclude_key: embeddings should not save with text-encoder
124
- self.save_model_with_lora(TE, lora_TE[0], model_ema=getattr(self, 'ema_text_encoder', None),
125
- name='text_encoder', step=step, exclude_key='emb_ex.')
126
- self.save_plugins(TE, all_plugin_TE, name='text_encoder', step=step,
127
- model_ema=getattr(self, 'ema_text_encoder', None))
128
-
129
- if lora_unet[1] is not None:
130
- self.save_model_with_lora(None, lora_unet[1], name='unet-neg', step=step)
131
- if lora_TE[1] is not None:
132
- self.save_model_with_lora(None, lora_TE[1], name='text_encoder-neg', step=step)
133
-
134
- self.save_embedding(embs, step, False)
135
-
136
- @classmethod
137
- def load(cls, pretrained_model):
138
- raise NotImplementedError(f'{cls} dose not support load()')