hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/ckpt_manager/ckpt_safetensor.py DELETED
@@ -1,64 +0,0 @@
- """
- ckpt_safetensors.py
- ====================
- :Name: save model with safetensors
- :Author: Dong Ziyi
- :Affiliation: HCP Lab, SYSU
- :Created: 8/04/2023
- :Licence: MIT
- """
-
- import os
- import torch
- from safetensors import safe_open
- from safetensors.torch import save_file
-
- from .ckpt_pkl import CkptManagerPKL
-
- class CkptManagerSafe(CkptManagerPKL):
-
-     def _save_ckpt(self, sd_model, name=None, step=None, save_path=None):
-         if save_path is None:
-             save_path = os.path.join(self.save_dir, f"{name}-{step}.safetensors")
-         sd_unfold = self.unfold_dict(sd_model)
-         for k, v in sd_unfold.items():
-             if not v.is_contiguous():
-                 sd_unfold[k] = v.contiguous()
-         save_file(sd_unfold, save_path)
-
-     def load_ckpt(self, ckpt_path, map_location='cpu'):
-         with safe_open(ckpt_path, framework="pt", device=map_location) as f:
-             sd_fold = self.fold_dict(f)
-         return sd_fold
-
-     @staticmethod
-     def unfold_dict(data, split_key=':'):
-         dict_unfold={}
-
-         def unfold(prefix, dict_fold):
-             for k,v in dict_fold.items():
-                 k_new = k if prefix=='' else f'{prefix}{split_key}{k}'
-                 if isinstance(v, dict):
-                     unfold(k_new, v)
-                 elif isinstance(v, list) or isinstance(v, tuple):
-                     unfold(k_new, {i:d for i,d in enumerate(v)})
-                 else:
-                     dict_unfold[k_new]=v
-
-         unfold('', data)
-         return dict_unfold
-
-     @staticmethod
-     def fold_dict(safe_f, split_key=':'):
-         dict_fold = {}
-
-         for k in safe_f.keys():
-             k_list = k.split(split_key)
-             dict_last = dict_fold
-             for item in k_list[:-1]:
-                 if item not in dict_last:
-                     dict_last[item] = {}
-                 dict_last = dict_last[item]
-             dict_last[k_list[-1]]=safe_f.get_tensor(k)
-
-         return dict_fold
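For context, the removed CkptManagerSafe works around safetensors' flat-dict requirement by flattening nested state dicts into colon-joined keys on save and folding them back on load. Below is a minimal, standalone sketch of that flatten/fold round trip using plain dicts instead of a safetensors file handle; it is illustrative only and not part of the hcpdiff API.

# Sketch of the key-flattening scheme used above (plain values stand in for tensors).
def unfold_dict(data, split_key=':'):
    flat = {}
    def unfold(prefix, node):
        for k, v in node.items():
            key = k if prefix == '' else f'{prefix}{split_key}{k}'
            if isinstance(v, dict):
                unfold(key, v)
            elif isinstance(v, (list, tuple)):
                unfold(key, {i: d for i, d in enumerate(v)})  # lists become index-keyed dicts
            else:
                flat[key] = v
    unfold('', data)
    return flat

def fold_dict(flat, split_key=':'):
    nested = {}
    for k, v in flat.items():
        parts = k.split(split_key)
        node = nested
        for p in parts[:-1]:
            node = node.setdefault(p, {})
        node[parts[-1]] = v
    return nested

state = {'unet': {'lora': [{'w': 1}, {'w': 2}]}}
print(unfold_dict(state))               # {'unet:lora:0:w': 1, 'unet:lora:1:w': 2}
print(fold_dict(unfold_dict(state)))    # note: list indices come back as string dict keys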
hcpdiff/ckpt_manager/ckpt_webui.py DELETED
@@ -1,54 +0,0 @@
- from .base import CkptManagerBase
- import os
- from diffusers import StableDiffusionPipeline, UNet2DConditionModel
- from hcpdiff.models.plugin import BasePluginBlock
- from hcpdiff.tools.sd2diffusers import load_sd_ckpt, patch_method
-
- class CkptManagerWebui(CkptManagerBase):
-
-     def set_save_dir(self, save_dir, emb_dir=None):
-         os.makedirs(save_dir, exist_ok=True)
-         self.save_dir = save_dir
-         self.emb_dir = emb_dir
-
-     def save(self, step, unet, TE, lora_unet, lora_TE, all_plugin_unet, all_plugin_TE, embs, pipe: StableDiffusionPipeline, **kwargs):
-         def state_dict_unet(*args, model=unet, **kwargs):
-             plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
-             model_sd = {}
-             for k, v in model.state_dict_().items():
-                 for name in plugin_names:
-                     if k.startswith(name):
-                         break
-                 else:
-                     model_sd[k] = v
-             return model_sd
-         unet.state_dict_ = unet.state_dict
-         unet.state_dict = state_dict_unet
-
-         def state_dict_TE(*args, model=TE, **kwargs):
-             plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
-             model_sd = {}
-             for k, v in model.state_dict_().items():
-                 for name in plugin_names:
-                     if k.startswith(name):
-                         break
-                 else:
-                     model_sd[k] = v
-             return model_sd
-         TE.state_dict_ = TE.state_dict
-         TE.state_dict = state_dict_TE
-
-         pipe.save_pretrained(os.path.join(self.save_dir, f"model-{step}"), **kwargs)
-
-     @classmethod
-     def load(cls, pretrained_model, original_config_file, from_safetensors=False, device='cpu', ema=True, **kwargs) -> StableDiffusionPipeline:
-         patch_method()
-         pipe = load_sd_ckpt(
-             checkpoint_path=pretrained_model,
-             original_config_file=original_config_file,
-             extract_ema=ema,
-             scheduler_type='pndm',
-             from_safetensors=from_safetensors,
-             device=device,
-         )
-         return pipe
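The save() method above relies on temporarily monkey-patching state_dict so that plugin parameters are not written into the serialized pipeline. A minimal sketch of that filtering idea follows, using a stand-in plugin class rather than hcpdiff's BasePluginBlock; it is illustrative only.

import torch.nn as nn

class Plugin(nn.Linear):  # stand-in for BasePluginBlock
    pass

model = nn.Sequential(nn.Linear(4, 4), Plugin(4, 4))
plugin_names = {k for k, v in model.named_modules() if isinstance(v, Plugin)}

model.state_dict_ = model.state_dict  # keep a handle on the original method
def state_dict_filtered(*args, **kwargs):
    # drop every key that lives under a plugin submodule
    return {k: v for k, v in model.state_dict_().items()
            if not any(k.startswith(name) for name in plugin_names)}
model.state_dict = state_dict_filtered  # a later save_pretrained-style call now skips plugin weights

print(sorted(model.state_dict().keys()))  # only the non-plugin ('0.*') entries remain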
hcpdiff/data/bucket.py DELETED
@@ -1,358 +0,0 @@
- """
- bucket.py
- ====================
- :Name: aspect ratio bucket with k-means
- :Author: Dong Ziyi
- :Affiliation: HCP Lab, SYSU
- :Created: 10/03/2023
- :Licence: Apache-2.0
- """
-
- import math
- import os.path
- import pickle
- from typing import List, Tuple, Union, Any
-
- import cv2
- import numpy as np
- from hcpdiff.utils.img_size_tool import types_support, get_image_size
- from hcpdiff.utils.utils import get_file_ext
- from .source import DataSource
- from loguru import logger
- from sklearn.cluster import KMeans
- from tqdm import tqdm
- from concurrent.futures import ThreadPoolExecutor
-
- from .utils import resize_crop_fix, pad_crop_fix
-
- class BaseBucket:
-     def __getitem__(self, idx):
-         '''
-         :return: (file name of image), (target image size)
-         '''
-         raise NotImplementedError()
-
-     def __len__(self):
-         raise NotImplementedError()
-
-     def build(self, bs: int, img_root_list: List[str]):
-         raise NotImplementedError()
-
-     def rest(self, epoch):
-         pass
-
-     def crop_resize(self, image, size, mask_interp=cv2.INTER_CUBIC) -> Tuple[Any, Tuple]:
-         return image, (*size, 0, 0, *size)
-
- class FixedBucket(BaseBucket):
-     def __init__(self, target_size: Union[Tuple[int, int], int] = 512, **kwargs):
-         self.target_size = (target_size, target_size) if isinstance(target_size, int) else target_size
-
-     def build(self, bs: int, file_names: List[Tuple[str, DataSource]]):
-         self.file_names = file_names
-
-     def crop_resize(self, image, size, mask_interp=cv2.INTER_CUBIC):
-         return resize_crop_fix(image, size, mask_interp=mask_interp)
-
-     def __getitem__(self, idx) -> Tuple[Tuple[str, DataSource], Tuple[int, int]]:
-         return self.file_names[idx], self.target_size
-
-     def __len__(self):
-         return len(self.file_names)
-
- class RatioBucket(BaseBucket):
-     def __init__(self, target_area: int = 640*640, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None):
-         self.target_area = target_area
-         self.step_size = step_size
-         self.num_bucket = num_bucket
-         self.pre_build_bucket = pre_build_bucket
-
-     def load_bucket(self, path):
-         with open(path, 'rb') as f:
-             data = pickle.load(f)
-         self.buckets = data['buckets']
-         self.size_buckets = data['size_buckets']
-         self.idx_bucket_map = data['idx_bucket_map']
-         self.data_len = data['data_len']
-
-     def save_bucket(self, path):
-         with open(path, 'wb') as f:
-             pickle.dump({
-                 'buckets':self.buckets,
-                 'size_buckets':self.size_buckets,
-                 'idx_bucket_map':self.idx_bucket_map,
-                 'data_len':self.data_len,
-             }, f)
-
-     def build_buckets_from_ratios(self):
-         logger.info('build buckets from ratios')
-         size_low = int(math.sqrt(self.target_area/self.ratio_max))
-         size_high = int(self.ratio_max*size_low)
-
-         # SD needs side lengths that are multiples of 8
-         size_low = (size_low//self.step_size)*self.step_size
-         size_high = (size_high//self.step_size)*self.step_size
-
-         data = []
-         for w in range(size_low, size_high+1, self.step_size):
-             for h in range(size_low, size_high+1, self.step_size):
-                 data.append([w*h, np.log2(w/h), w, h])  # take the log of the ratio: closer to human perception, and swapped w/h distribute symmetrically
-         data = np.array(data)
-
-         error_area = np.abs(data[:, 0]-self.target_area)
-         data_use = data[np.argsort(error_area)[:self.num_bucket*3], :]  # keep the num_bucket*3 sizes closest to the target area
-
-         # cluster to pick the specified number of buckets
-         kmeans = KMeans(n_clusters=self.num_bucket, random_state=3407).fit(data_use[:, 1].reshape(-1, 1))
-         labels = kmeans.labels_
-         self.buckets = []  # [bucket_id:[file_idx,...]]
-         ratios_log = []
-         self.size_buckets = []
-         for i in range(self.num_bucket):
-             map_idx = np.where(labels == i)[0]
-             m_idx = map_idx[np.argmin(np.abs(data_use[labels == i, 1]-np.median(data_use[labels == i, 1])))]
-             # self.buckets[wh_hash(*data_use[m_idx, 2:])]=[]
-             self.buckets.append([])
-             ratios_log.append(data_use[m_idx, 1])
-             self.size_buckets.append(data_use[m_idx, 2:].astype(int))
-         ratios_log = np.array(ratios_log)
-         self.size_buckets = np.array(self.size_buckets)
-
-         # fill buckets with images w,h
-         self.idx_bucket_map = np.empty(len(self.file_names), dtype=int)
-         for i, (file, source) in enumerate(self.file_names):
-             w, h = get_image_size(file)
-             bucket_id = np.abs(ratios_log-np.log2(w/h)).argmin()
-             self.buckets[bucket_id].append(i)
-             self.idx_bucket_map[i] = bucket_id
-         logger.info('buckets info: '+', '.join(f'size:{self.size_buckets[i]}, num:{len(b)}' for i, b in enumerate(self.buckets)))
-
-     def build_buckets_from_images(self):
-         logger.info('build buckets from images')
-
-         def get_ratio(data):
-             file, source = data
-             w, h = get_image_size(file)
-             ratio = np.log2(w/h)
-             return ratio
-
-         ratio_list = []
-         with ThreadPoolExecutor() as executor:
-             for ratio in tqdm(executor.map(get_ratio, self.file_names), desc='get image info', total=len(self.file_names)):
-                 ratio_list.append(ratio)
-         ratio_list = np.array(ratio_list)
-
-         # cluster to pick the specified number of buckets
-         kmeans = KMeans(n_clusters=self.num_bucket, random_state=3407, verbose=True, tol=1e-3).fit(ratio_list.reshape(-1, 1))
-         labels = kmeans.labels_
-         ratios = 2**kmeans.cluster_centers_.reshape(-1)
-
-         h_all = np.sqrt(self.target_area/ratios)
-         w_all = h_all*ratios
-
-         # SD needs side lengths that are multiples of 8
-         h_all = (np.round(h_all/self.step_size)*self.step_size).astype(int)
-         w_all = (np.round(w_all/self.step_size)*self.step_size).astype(int)
-         self.size_buckets = list(zip(w_all, h_all))
-         self.size_buckets = np.array(self.size_buckets)
-
-         self.buckets = []  # [bucket_id:[file_idx,...]]
-         self.idx_bucket_map = np.empty(len(self.file_names), dtype=int)
-         for bidx in range(self.num_bucket):
-             bnow = labels == bidx
-             self.buckets.append(np.where(bnow)[0].tolist())
-             self.idx_bucket_map[bnow] = bidx
-         logger.info('buckets info: '+', '.join(f'size:{self.size_buckets[i]}, num:{len(b)}' for i, b in enumerate(self.buckets)))
-
-     def build(self, bs: int, file_names: List[Tuple[str, DataSource]]):
-         '''
-         :param bs: batch_size * n_gpus * accumulation_step
-         :param img_root_list:
-         '''
-         self.file_names = file_names
-         self.bs = bs
-         if self.pre_build_bucket and os.path.exists(self.pre_build_bucket):
-             self.load_bucket(self.pre_build_bucket)
-             return
-
-         self._build()
-
-         rs = np.random.RandomState(42)
-         # make len(bucket)%bs==0
-         self.data_len = 0
-         for bidx, bucket in enumerate(self.buckets):
-             rest = len(bucket)%bs
-             if rest>0:
-                 bucket.extend(rs.choice(bucket, bs-rest))
-             self.data_len += len(bucket)
-             self.buckets[bidx] = np.array(bucket)
-
-         if self.pre_build_bucket:
-             self.save_bucket(self.pre_build_bucket)
-
-     def rest(self, epoch):
-         rs = np.random.RandomState(42+epoch)
-         bucket_list = [x.copy() for x in self.buckets]
-         # shuffle inter bucket
-         for x in bucket_list:
-             rs.shuffle(x)
-
-         # shuffle of batches
-         bucket_list = np.hstack(bucket_list).reshape(-1, self.bs).astype(int)
-         rs.shuffle(bucket_list)
-
-         self.idx_bucket = bucket_list.reshape(-1)
-
-     def crop_resize(self, image, size, mask_interp=cv2.INTER_CUBIC):
-         return resize_crop_fix(image, size, mask_interp=mask_interp)
-
-     def __getitem__(self, idx):
-         file_idx = self.idx_bucket[idx]
-         bucket_idx = self.idx_bucket_map[file_idx]
-         return self.file_names[file_idx], self.size_buckets[bucket_idx]
-
-     def __len__(self):
-         return self.data_len
-
-     @classmethod
-     def from_ratios(cls, target_area: int = 640*640, step_size: int = 8, num_bucket: int = 10, ratio_max: float = 4,
-                     pre_build_bucket: str = None, **kwargs):
-         arb = cls(target_area, step_size, num_bucket, pre_build_bucket=pre_build_bucket)
-         arb.ratio_max = ratio_max
-         arb._build = arb.build_buckets_from_ratios
-         return arb
-
-     @classmethod
-     def from_files(cls, target_area: int = 640*640, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None, **kwargs):
-         arb = cls(target_area, step_size, num_bucket, pre_build_bucket=pre_build_bucket)
-         arb._build = arb.build_buckets_from_images
-         return arb
-
- class SizeBucket(RatioBucket):
-     def __init__(self, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None):
-         super().__init__(step_size=step_size, num_bucket=num_bucket, pre_build_bucket=pre_build_bucket)
-
-     def build_buckets_from_images(self):
-         '''
-         Cluster by image size; images are not resized, only cropped or padded.
-         '''
-         logger.info('build buckets from images size')
-         size_list = []
-         for i, (file, source) in enumerate(self.file_names):
-             w, h = get_image_size(file)
-             size_list.append([w, h])
-         size_list = np.array(size_list)
-
-         # cluster to pick the specified number of buckets
-         kmeans = KMeans(n_clusters=self.num_bucket, random_state=3407).fit(size_list)
-         labels = kmeans.labels_
-         size_buckets = kmeans.cluster_centers_
-
-         # SD needs side lengths that are multiples of 8
-         self.size_buckets = (np.round(size_buckets/self.step_size)*self.step_size).astype(int)
-
-         self.buckets = []  # [bucket_id:[file_idx,...]]
-         self.idx_bucket_map = np.empty(len(self.file_names), dtype=int)
-         for bidx in range(self.num_bucket):
-             bnow = labels == bidx
-             self.buckets.append(np.where(bnow)[0].tolist())
-             self.idx_bucket_map[bnow] = bidx
-         logger.info('buckets info: '+', '.join(f'size:{self.size_buckets[i]}, num:{len(b)}' for i, b in enumerate(self.buckets)))
-
-     def crop_resize(self, image, size):
-         return pad_crop_fix(image, size)
-
-     @classmethod
-     def from_files(cls, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None, **kwargs):
-         arb = cls(step_size, num_bucket, pre_build_bucket=pre_build_bucket)
-         arb._build = arb.build_buckets_from_images
-         return arb
-
- class RatioSizeBucket(RatioBucket):
-     def __init__(self, step_size: int = 8, num_bucket: int = 10, max_area:int=640*640, pre_build_bucket: str = None):
-         super().__init__(step_size=step_size, num_bucket=num_bucket, pre_build_bucket=pre_build_bucket)
-         self.max_area = max_area
-
-     def build_buckets_from_images(self):
-         '''
-         Cluster by image size; images are not resized, only cropped or padded.
-         '''
-         logger.info('build buckets from images')
-         ratio_list = []
-         for i, (file, source) in enumerate(self.file_names):
-             w, h = get_image_size(file)
-             ratio = np.log2(w/h)
-             log_area = np.log2(min(w*h, self.max_area))
-             ratio_list.append([ratio, log_area])
-         ratio_list = np.array(ratio_list)
-
-         # cluster to pick the specified number of buckets
-         kmeans = KMeans(n_clusters=self.num_bucket, random_state=3407).fit(ratio_list)
-         labels = kmeans.labels_
-         ratios = 2**kmeans.cluster_centers_[:, 0]
-         sizes = 2**kmeans.cluster_centers_[:, 1]
-
-         h_all = np.sqrt(sizes/ratios)
-         w_all = h_all*ratios
-
-         # SD needs side lengths that are multiples of 8
-         h_all = (np.round(h_all/self.step_size)*self.step_size).astype(int)
-         w_all = (np.round(w_all/self.step_size)*self.step_size).astype(int)
-         self.size_buckets = list(zip(w_all, h_all))
-         self.size_buckets = np.array(self.size_buckets)
-
-         self.buckets = []  # [bucket_id:[file_idx,...]]
-         self.idx_bucket_map = np.empty(len(self.file_names), dtype=int)
-         for bidx in range(self.num_bucket):
-             bnow = labels == bidx
-             self.buckets.append(np.where(bnow)[0].tolist())
-             self.idx_bucket_map[bnow] = bidx
-         logger.info('buckets info: '+', '.join(f'size:{self.size_buckets[i]}, num:{len(b)}' for i, b in enumerate(self.buckets)))
-
-     @classmethod
-     def from_files(cls, step_size: int = 8, num_bucket: int = 10, max_area:int=640*640, pre_build_bucket: str = None, **kwargs):
-         arb = cls(step_size, num_bucket, max_area=max_area, pre_build_bucket=pre_build_bucket)
-         arb._build = arb.build_buckets_from_images
-         return arb
-
- class LongEdgeBucket(RatioBucket):
-     def __init__(self, target_edge=640, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None):
-         super().__init__(step_size=step_size, num_bucket=num_bucket, pre_build_bucket=pre_build_bucket)
-         self.target_edge = target_edge
-
-     def build_buckets_from_images(self):
-         '''
-         Cluster by image size; images are not resized, only cropped or padded.
-         '''
-         logger.info('build buckets from images size')
-         size_list = []
-         for i, (file, source) in enumerate(self.file_names):
-             w, h = get_image_size(file)
-             scale = self.target_edge/max(w, h)
-             size_list.append([round(w*scale), round(h*scale)])
-         size_list = np.array(size_list)
-
-         # cluster to pick the specified number of buckets
-         kmeans = KMeans(n_clusters=self.num_bucket, random_state=3407, verbose=True).fit(size_list)
-         labels = kmeans.labels_
-         size_buckets = kmeans.cluster_centers_
-
-         # SD needs side lengths that are multiples of 8
-         self.size_buckets = (np.round(size_buckets/self.step_size)*self.step_size).astype(int)
-
-         self.buckets = []  # [bucket_id:[file_idx,...]]
-         self.idx_bucket_map = np.empty(len(self.file_names), dtype=int)
-         for bidx in range(self.num_bucket):
-             bnow = labels == bidx
-             self.buckets.append(np.where(bnow)[0].tolist())
-             self.idx_bucket_map[bnow] = bidx
-         logger.info('buckets info: '+', '.join(f'size:{self.size_buckets[i]}, num:{len(b)}' for i, b in enumerate(self.buckets)))
-
-     def crop_resize(self, image, size):
-         return resize_crop_fix(image, size)
-
-     @classmethod
-     def from_files(cls, target_edge, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None, **kwargs):
-         arb = cls(target_edge, step_size, num_bucket, pre_build_bucket=pre_build_bucket)
-         arb._build = arb.build_buckets_from_images
-         return arb
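The core idea of the removed RatioBucket is to k-means-cluster images by log2 aspect ratio and give each cluster a resolution of roughly constant area whose sides are multiples of step_size. A minimal sketch of that bucketing step follows; the function name and sample values are made up for illustration and are not hcpdiff API.

import numpy as np
from sklearn.cluster import KMeans

def bucket_sizes(wh_list, target_area=640*640, step=8, num_buckets=4):
    ratios = np.log2([w/h for w, h in wh_list]).reshape(-1, 1)  # log ratio: symmetric for w/h swaps
    km = KMeans(n_clusters=num_buckets, random_state=3407).fit(ratios)
    centers = 2**km.cluster_centers_.reshape(-1)                # back from log space
    h = np.sqrt(target_area/centers)
    w = h*centers
    w = (np.round(w/step)*step).astype(int)                     # snap sides to multiples of `step`
    h = (np.round(h/step)*step).astype(int)
    return list(zip(w, h)), km.labels_                          # per-cluster sizes + image-to-bucket map

sizes, labels = bucket_sizes([(512, 768), (768, 512), (640, 640), (1024, 576), (576, 1024)])
print(sizes, labels)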
hcpdiff/data/caption_loader.py DELETED
@@ -1,80 +0,0 @@
- import json
- import os
- import glob
- import yaml
- from typing import Dict
-
- from loguru import logger
- from hcpdiff.utils.img_size_tool import types_support
- import os
-
- class BaseCaptionLoader:
-     def __init__(self, path):
-         self.path = path
-
-     def _load(self):
-         raise NotImplementedError
-
-     def load(self):
-         retval = self._load()
-         logger.info(f'{len(retval)} record(s) loaded with {self.__class__.__name__}, from path {self.path!r}')
-         return retval
-
-     @staticmethod
-     def clean_ext(captions:Dict[str, str]):
-         def rm_ext(path):
-             name, ext = os.path.splitext(path)
-             if len(ext)>0 and ext[1:] in types_support:
-                 return name
-             return path
-         return {rm_ext(k):v for k,v in captions.items()}
-
- class JsonCaptionLoader(BaseCaptionLoader):
-     def _load(self):
-         with open(self.path, 'r', encoding='utf-8') as f:
-             return self.clean_ext(json.loads(f.read()))
-
- class YamlCaptionLoader(BaseCaptionLoader):
-     def _load(self):
-         with open(self.path, 'r', encoding='utf-8') as f:
-             return self.clean_ext(yaml.load(f.read(), Loader=yaml.FullLoader))
-
- class TXTCaptionLoader(BaseCaptionLoader):
-     def _load(self):
-         txt_files = glob.glob(os.path.join(self.path, '*.txt'))
-         captions = {}
-         for file in txt_files:
-             with open(file, 'r', encoding='utf-8') as f:
-                 captions[os.path.basename(file).split('.')[0]] = f.read().strip()
-         return captions
-
- def auto_caption_loader(path):
-     if os.path.isdir(path):
-         json_files = glob.glob(os.path.join(path, '*.json'))
-         if json_files:
-             return JsonCaptionLoader(json_files[0])
-
-         yaml_files = [
-             *glob.glob(os.path.join(path, '*.yaml')),
-             *glob.glob(os.path.join(path, '*.yml')),
-         ]
-         if yaml_files:
-             return YamlCaptionLoader(yaml_files[0])
-
-         txt_files = glob.glob(os.path.join(path, '*.txt'))
-         if txt_files:
-             return TXTCaptionLoader(path)
-
-         raise FileNotFoundError(f'Caption file not found in directory {path!r}.')
-
-     elif os.path.isfile(path):
-         _, ext = os.path.splitext(path)
-         if ext == '.json':
-             return JsonCaptionLoader(path)
-         elif ext in {'.yaml', '.yml'}:
-             return YamlCaptionLoader(path)
-         else:
-             raise FileNotFoundError(f'Unknown caption file {path!r}.')
-
-     else:
-         raise FileNotFoundError(f'Unknown caption file type {path!r}.')
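For reference, these removed loaders were typically driven through auto_caption_loader, which prefers a JSON file, then YAML, then per-image TXT files inside a dataset directory. A hypothetical 0.9.x-era usage follows; the directory path is made up for the example.

from hcpdiff.data.caption_loader import auto_caption_loader  # 0.9.x import path, removed in 2.x

captions = auto_caption_loader('data/train_images').load()   # hypothetical dataset directory
print(list(captions.items())[:3])                            # e.g. [('img_0001', 'a photo of ...'), ...]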
hcpdiff/data/cond_dataset.py DELETED
@@ -1,40 +0,0 @@
- """
- pair_dataset.py
- ====================
- :Name: text-image pair dataset
- :Author: Dong Ziyi
- :Affiliation: HCP Lab, SYSU
- :Created: 10/03/2023
- :Licence: Apache-2.0
- """
-
- import cv2
- import torch
-
- from .pair_dataset import TextImagePairDataset
-
- class TextImageCondPairDataset(TextImagePairDataset):
-     """
-     A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
-     It pre-processes the images and tokenizes the prompts.
-     """
-
-     def load_data(self, path, data_source, size):
-         image_dict = data_source.load_image(path)
-         image = image_dict['image']
-         att_mask = image_dict.get('att_mask', None)
-         img_cond = image_dict.get('cond', None)
-         if img_cond is None:
-             raise FileNotFoundError(f'{self.__class__} need the condition images!')
-
-         if att_mask is None:
-             data, crop_coord = self.bucket.crop_resize({"img":image, "cond":img_cond}, size)
-             image = data_source.procees_image(data['img'])  # resize to bucket size
-             img_cond = data_source.cond_transform(data['cond'])
-             att_mask = torch.ones((size[1]//8, size[0]//8))
-         else:
-             data, crop_coord = self.bucket.crop_resize({"img":image, "mask":att_mask, "cond":img_cond}, size)
-             image = data_source.procees_image(data['img'])
-             img_cond = data_source.cond_transform(data['cond'])
-             att_mask = torch.tensor(cv2.resize(data['mask'], (size[0]//8, size[1]//8), interpolation=cv2.INTER_LINEAR))
-         return {'img':image, 'mask':att_mask, 'plugin_input':{"cond":img_cond}}
hcpdiff/data/crop_info_dataset.py DELETED
@@ -1,40 +0,0 @@
- """
- pair_dataset.py
- ====================
- :Name: text-image pair dataset
- :Author: Dong Ziyi
- :Affiliation: HCP Lab, SYSU
- :Created: 10/03/2023
- :Licence: Apache-2.0
- """
-
- from typing import Callable, Iterable, Dict
- from .bucket import BaseBucket
- import os.path
-
- import torch
- import cv2
- from .pair_dataset import TextImagePairDataset
- from hcpdiff.utils.utils import get_file_name
- from torchvision import transforms
-
- class CropInfoPairDataset(TextImagePairDataset):
-     """
-     A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
-     It pre-processes the images and tokenizes the prompts.
-     """
-
-     def load_data(self, path, data_source, size):
-         image_dict = data_source.load_image(path)
-         image = image_dict['image']
-         att_mask = image_dict.get('att_mask', None)
-         if att_mask is None:
-             data, crop_coord = self.bucket.crop_resize({"img":image}, size)
-             image = data_source.procees_image(data['img'])  # resize to bucket size
-             att_mask = torch.ones((size[1]//8, size[0]//8))
-         else:
-             data, crop_coord = self.bucket.crop_resize({"img":image, "mask":att_mask}, size)
-             image = data_source.procees_image(data['img'])
-             att_mask = torch.tensor(cv2.resize(data['mask'], (size[0]//8, size[1]//8), interpolation=cv2.INTER_LINEAR))
-         crop_info = torch.tensor(crop_coord, dtype=torch.float)  # for sdxl
-         return {'img':image, 'mask':att_mask, 'crop_info':crop_info}
hcpdiff/data/data_processor.py DELETED
@@ -1,33 +0,0 @@
- import numpy as np
- import torch
- from PIL import Image
- from diffusers.utils import PIL_INTERPOLATION
-
- class ControlNetProcessor:
-     def __init__(self, image):
-         self.image_path = image
-
-     def prepare_cond_image(self, image, width, height, batch_size, device):
-         if not isinstance(image, torch.Tensor):
-             if isinstance(image, Image.Image):
-                 image = [image]
-
-             if isinstance(image[0], Image.Image):
-                 image = [
-                     np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image
-                 ]
-                 image = np.concatenate(image, axis=0)
-                 image = np.array(image).astype(np.float32)/255.0
-                 image = image.transpose(0, 3, 1, 2)
-                 image = torch.from_numpy(image)
-             elif isinstance(image[0], torch.Tensor):
-                 image = torch.cat(image, dim=0)
-
-         image = image.repeat_interleave(batch_size, dim=0)
-         image = image.to(device=device)
-
-         return image
-
-     def __call__(self, width, height, batch_size, device, dtype):
-         img = Image.open(self.image_path).convert('RGB')
-         return self.prepare_cond_image(img, width, height, batch_size, 'cuda').to(dtype=dtype)
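For reference, the removed ControlNetProcessor turned an image path into a (batch, 3, H, W) float tensor in [0, 1] for the ControlNet condition input; note that its __call__ passes 'cuda' internally regardless of the device argument. A hypothetical 0.9.x-era usage follows (the file path is made up, and a CUDA device is assumed).

import torch
from hcpdiff.data.data_processor import ControlNetProcessor  # 0.9.x import path, removed in 2.x

proc = ControlNetProcessor('cond/pose.png')                  # hypothetical condition image
cond = proc(width=512, height=512, batch_size=2, device='cuda', dtype=torch.float16)
print(cond.shape)                                            # torch.Size([2, 3, 512, 512])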