hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211) hide show
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,18 @@
1
+ import torchvision.transforms as T
2
+ from PIL import Image
3
+ from rainbowneko.data import DataHandler, HandlerChain, LoadImageHandler, ImageHandler
4
+
5
+ class ControlNetHandler(DataHandler):
6
+ def __init__(self, key_map_in=('cond -> image',), key_map_out=('image -> cond',), bucket=None):
7
+ super().__init__(key_map_in, key_map_out)
8
+
9
+ self.handlers = HandlerChain(
10
+ load=LoadImageHandler(),
11
+ bucket=bucket.handler if bucket else DataHandler(),
12
+ image=ImageHandler(
13
+ transform=T.ToTensor(),
14
+ )
15
+ )
16
+
17
+ def handle(self, image:Image.Image):
18
+ return self.handlers(dict(image=image))
@@ -0,0 +1,90 @@
1
+ from typing import Union, Dict, Any
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchvision.transforms as T
6
+ from PIL import Image
7
+ from rainbowneko.data import DataHandler, HandlerChain, LoadImageHandler, ImageHandler
8
+
9
+ from .text import TemplateFillHandler, TagDropoutHandler, TagEraseHandler, TagShuffleHandler, TokenizeHandler
10
+
11
+ class LossMapHandler(DataHandler):
12
+ def __init__(self, bucket, vae_scale=8, key_map_in=('loss_map -> image', 'image_size -> image_size'),
13
+ key_map_out=('image -> loss_map', 'coord -> coord')):
14
+ super().__init__(key_map_in, key_map_out)
15
+ self.vae_scale = vae_scale
16
+
17
+ self.handlers = HandlerChain(
18
+ load=LoadImageHandler(mode='L'),
19
+ bucket=bucket.handler,
20
+ image=ImageHandler(transform=T.Compose([
21
+ lambda x:x.resize((x.size[0]//self.vae_scale, x.size[1]//self.vae_scale), Image.BILINEAR),
22
+ T.ToTensor()
23
+ ]), )
24
+ )
25
+
26
+ def handle(self, image: Union[Image.Image, str], image_size: np.ndarray[int]):
27
+ data = self.handlers(dict(image=image, image_size=image_size))
28
+ image = data['image']
29
+ image[image<=0.5] *= 2
30
+ image[image>0.5] = (image[image>0.5]-0.5)*4+1
31
+ return self.handlers(dict(**data, image=image))
32
+
33
+ class DiffusionImageHandler(DataHandler):
34
+ def __init__(self, bucket, key_map_in=('image -> image', 'image_size -> image_size'), key_map_out=('image -> image', 'coord -> coord')):
35
+ super().__init__(key_map_in, key_map_out)
36
+
37
+ self.handlers = HandlerChain(
38
+ load=LoadImageHandler(),
39
+ bucket=bucket.handler,
40
+ image=ImageHandler(transform=T.Compose([
41
+ T.ToTensor(),
42
+ T.Normalize([0.5], [0.5])
43
+ ]), )
44
+ )
45
+
46
+ def handle(self, image: Image.Image, image_size: np.ndarray[int]):
47
+ if isinstance(image, torch.Tensor): # cached latents
48
+ return dict(image=image, image_size=image_size)
49
+ else:
50
+ return self.handlers(dict(image=image, image_size=image_size))
51
+
52
+ class DiffusionTextHandler(DataHandler):
53
+ def __init__(self, encoder_attention_mask=False, erase=0.0, dropout=0.0, shuffle=0.0, word_names={}, tokenize=True,
54
+ key_map_in=('prompt -> prompt', ), key_map_out=('prompt -> prompt', )):
55
+ super().__init__(key_map_in, key_map_out)
56
+
57
+ text_handlers = {}
58
+ if dropout>0:
59
+ text_handlers['dropout'] = TagDropoutHandler(p=dropout)
60
+ if erase>0:
61
+ text_handlers['erase'] = TagEraseHandler(p=erase)
62
+ if shuffle>0:
63
+ text_handlers['shuffle'] = TagShuffleHandler()
64
+ text_handlers['fill'] = TemplateFillHandler(word_names)
65
+ if tokenize:
66
+ text_handlers['tokenize'] = TokenizeHandler(encoder_attention_mask)
67
+ self.handlers = HandlerChain(**text_handlers)
68
+
69
+ def handle(self, prompt: Union[str, Dict[str, str]]):
70
+ return self.handlers(dict(prompt=prompt))
71
+
72
+ class StableDiffusionHandler(DataHandler):
73
+ def __init__(self, bucket, encoder_attention_mask=False, key_map_in=('image -> image', 'image_size -> image_size', 'prompt -> prompt'),
74
+ key_map_out=('image -> image', 'coord -> coord', 'prompt -> prompt'),
75
+ erase=0.0, dropout=0.0, shuffle=0.0, word_names={}, tokenize=True):
76
+ super().__init__(key_map_in, key_map_out)
77
+
78
+ self.image_handlers = DiffusionImageHandler(bucket)
79
+ self.text_handlers = DiffusionTextHandler(encoder_attention_mask=encoder_attention_mask, erase=erase, dropout=dropout, shuffle=shuffle,
80
+ word_names=word_names, tokenize=tokenize)
81
+
82
+ def handle(self, image: Image.Image, image_size: np.ndarray[int], prompt: str):
83
+ return dict(**self.image_handlers(dict(image=image, image_size=image_size)), **self.text_handlers(dict(prompt=prompt)))
84
+
85
+ def __call__(self, data) -> Dict[str, Any]:
86
+ data_proc = self.handle(**self.key_mapper_in.map_data(data)[1])
87
+ out_data = self.key_mapper_out.map_data(data_proc)[1]
88
+ data = dict(**data)
89
+ data.update(out_data)
90
+ return data
@@ -0,0 +1,111 @@
1
+ import random
2
+ from typing import Dict, Union, List
3
+
4
+ import numpy as np
5
+ from string import Formatter
6
+ from rainbowneko.data import DataHandler
7
+ from rainbowneko._share import register_model_callback
8
+
9
+ class TagShuffleHandler(DataHandler):
10
+ def __init__(self, key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
11
+ super().__init__(key_map_in, key_map_out)
12
+
13
+ def handle(self, prompt: Union[Dict[str, str], str]):
14
+ if isinstance(prompt, str):
15
+ tags = prompt.split(',')
16
+ random.shuffle(tags)
17
+ prompt = ','.join(tags)
18
+ else:
19
+ tags = prompt['caption'].split(',')
20
+ random.shuffle(tags)
21
+ prompt['caption'] = ','.join(tags)
22
+ return {'prompt':prompt}
23
+
24
+ def __repr__(self):
25
+ return 'TagShuffleHandler()'
26
+
27
+ class TagDropoutHandler(DataHandler):
28
+ def __init__(self, p=0.1, key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
29
+ super().__init__(key_map_in, key_map_out)
30
+ self.p = p
31
+
32
+ def handle(self, prompt: Union[Dict[str, str], str]):
33
+ if isinstance(prompt, str):
34
+ tags = np.array(prompt.split(','))
35
+ prompt = ','.join(tags[np.random.random(len(tags))>self.p])
36
+ else:
37
+ tags = prompt['caption'].split(',')
38
+ prompt['caption'] = ','.join(tags[np.random.random(len(tags))>self.p])
39
+ return {'prompt':prompt}
40
+
41
+ def __repr__(self):
42
+ return f'TagDropoutHandler(p={self.p})'
43
+
44
+ class TagEraseHandler(DataHandler):
45
+ def __init__(self, p=0.1, key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
46
+ super().__init__(key_map_in, key_map_out)
47
+ self.p = p
48
+
49
+ def handle(self, prompt):
50
+ if isinstance(prompt, str):
51
+ if random.random()<self.p:
52
+ prompt = ''
53
+ else:
54
+ if random.random()<self.p:
55
+ prompt['caption'] = ''
56
+ return {'prompt':prompt}
57
+
58
+ def __repr__(self):
59
+ return f'TagEraseHandler(p={self.p})'
60
+
61
+
62
+ class TemplateFillHandler(DataHandler):
63
+ def __init__(self, word_names: Dict[str, str], key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
64
+ super().__init__(key_map_in, key_map_out)
65
+ self.word_names = word_names
66
+
67
+ def handle(self, prompt):
68
+ template, caption = prompt['template'], prompt['caption']
69
+
70
+ keys_need = {i[1] for i in Formatter().parse(template) if i[1] is not None}
71
+ fill_dict = {k: v for k, v in self.word_names.items() if k in keys_need}
72
+
73
+ if (caption is not None) and ('caption' in keys_need):
74
+ fill_dict.update(caption=fill_dict.get('caption', None) or caption)
75
+
76
+ # skip keys that not provide
77
+ for k in keys_need:
78
+ if k not in fill_dict:
79
+ fill_dict[k] = ''
80
+
81
+ # replace None value with ''
82
+ fill_dict = {k:(v or '') for k, v in fill_dict.items()}
83
+ return {'prompt':template.format(**fill_dict)}
84
+
85
+ def __repr__(self):
86
+ return f'TemplateFill(\nword_names={self.word_names}\n)'
87
+
88
+ class TokenizeHandler(DataHandler):
89
+ def __init__(self, encoder_attention_mask=False, key_map_in=('prompt -> prompt',), key_map_out=None):
90
+ super().__init__(key_map_in, key_map_out)
91
+ self.encoder_attention_mask = encoder_attention_mask
92
+
93
+ register_model_callback(self.acquire_tokenizer)
94
+
95
+ def acquire_tokenizer(self, model_wrapper):
96
+ self.tokenizer = model_wrapper.tokenizer
97
+
98
+ def handle(self, prompt):
99
+ token_info = self.tokenizer(prompt, truncation=True, padding="max_length", return_tensors="pt",
100
+ max_length=self.tokenizer.model_max_length*self.tokenizer.N_repeats)
101
+ tokens = token_info.input_ids.squeeze()
102
+ data = {'prompt':tokens}
103
+ if self.encoder_attention_mask and 'attention_mask' in token_info:
104
+ data['attn_mask'] = token_info.attention_mask.squeeze()
105
+ if 'position_ids' in token_info:
106
+ data['position_ids'] = token_info.position_ids.squeeze()
107
+
108
+ return data
109
+
110
+ def __repr__(self):
111
+ return f'TokenizeHandler(\nencoder_attention_mask={self.encoder_attention_mask}, tokenizer={self.tokenizer}\n)'
@@ -1,4 +1,4 @@
1
- from .base import DataSource, ComposeDataSource
2
- from .text2img import Text2ImageSource, Text2ImageAttMapSource
1
+ from .text2img import Text2ImageSource, Text2ImageLossMapSource
3
2
  from .text2img_cond import Text2ImageCondSource
4
- from .folder_class import T2IFolderClassSource
3
+ from .folder_class import T2IFolderClassSource
4
+ from .text import TextSource
@@ -1,40 +1,23 @@
1
- import os
2
- from typing import List, Tuple, Union
3
- from hcpdiff.utils.utils import get_file_name, get_file_ext
4
- from hcpdiff.utils.img_size_tool import types_support
5
- from .text2img import Text2ImageAttMapSource
6
- from hcpdiff.data.caption_loader import BaseCaptionLoader, auto_caption_loader
7
1
  from copy import copy
2
+ from typing import Union
8
3
 
9
- class T2IFolderClassSource(Text2ImageAttMapSource):
4
+ from rainbowneko.data.label_loader import BaseLabelLoader, auto_label_loader
10
5
 
11
- def get_image_list(self) -> List[Tuple[str, "T2IFolderClassSource"]]:
12
- sub_folders = [os.path.join(self.img_root, x) for x in os.listdir(self.img_root)]
13
- class_imgs = []
14
- for class_folder in sub_folders:
15
- class_name = os.path.basename(class_folder)
16
- imgs = [(os.path.join(class_folder, x), self) for x in os.listdir(class_folder) if get_file_ext(x) in types_support]
17
- class_imgs.extend(imgs*self.repeat[class_name])
18
- return class_imgs
6
+ from .text2img import Text2ImageLossMapSource
19
7
 
20
- def load_captions(self, caption_file: Union[str, BaseCaptionLoader]):
21
- if caption_file is None:
8
+ class T2IFolderClassSource(Text2ImageLossMapSource):
9
+ def _load_label_data(self, label_file: Union[str, BaseLabelLoader]):
10
+ ''' {class_name/image.ext: label} '''
11
+ if label_file is None:
22
12
  return {}
23
- elif isinstance(caption_file, str):
13
+ elif isinstance(label_file, str):
24
14
  captions = {}
25
- caption_loader = auto_caption_loader(caption_file)
26
- for class_name in os.listdir(caption_loader.path):
27
- class_folder = os.path.join(caption_loader.path, class_name)
15
+ caption_loader = auto_label_loader(label_file)
16
+ for class_folder in caption_loader.path.iterdir():
28
17
  caption_loader_class = copy(caption_loader)
29
18
  caption_loader_class.path = class_folder
30
- captions_class = {f'{class_name}/{name}':caption for name, caption in caption_loader_class.load().item()}
19
+ captions_class = {f'{class_folder.name}/{name}':caption for name, caption in caption_loader_class.load().item()}
31
20
  captions.update(captions_class)
32
21
  return captions
33
22
  else:
34
- return caption_file.load()
35
-
36
- def get_image_name(self, path: str) -> str:
37
- img_root, img_name = os.path.split(path)
38
- img_name = img_name.rsplit('.')[0]
39
- img_root, class_name = os.path.split(img_root)
40
- return f'{class_name}/{img_name}'
23
+ return label_file.load()
@@ -0,0 +1,40 @@
1
+ from rainbowneko.data import UnLabelSource, DataSource
2
+ from rainbowneko.data.label_loader import BaseLabelLoader, auto_label_loader
3
+ from typing import Union, Dict, Any
4
+ import random
5
+
6
+ class TextSource(DataSource):
7
+ def __init__(self, label_file, prompt_template=None, repeat=1, **kwargs):
8
+ super().__init__(repeat=repeat)
9
+ self.label_file = label_file
10
+ self.label_dict = self._load_label_data(label_file)
11
+ self.img_ids = self._load_img_ids(self.label_dict)
12
+ self.prompt_template = self.load_template(prompt_template)
13
+
14
+ def _load_img_ids(self, label_dict):
15
+ return list(label_dict.keys()) * self.repeat
16
+
17
+ def _load_label_data(self, label_file: Union[str, BaseLabelLoader]):
18
+ if label_file is None:
19
+ return {}
20
+ elif isinstance(label_file, str):
21
+ return auto_label_loader(label_file).load()
22
+ else:
23
+ return label_file.load()
24
+
25
+ def load_template(self, template_file):
26
+ if template_file is None:
27
+ return ['{caption}']
28
+ else:
29
+ with open(template_file, 'r', encoding='utf-8') as f:
30
+ return f.read().strip().split('\n')
31
+
32
+ def __getitem__(self, index) -> Dict[str, Any]:
33
+ img_name = self.img_ids[index]
34
+ return {
35
+ 'id':img_name,
36
+ 'prompt':{
37
+ 'template':random.choice(self.prompt_template),
38
+ 'caption':self.label_dict[img_name],
39
+ }
40
+ }
@@ -1,13 +1,11 @@
1
- from .base import DataSource
2
- from hcpdiff.data.caption_loader import BaseCaptionLoader, auto_caption_loader
3
- from typing import Union, Any
4
1
  import os
5
- from hcpdiff.utils.utils import get_file_name, get_file_ext
6
- from hcpdiff.utils.img_size_tool import types_support
7
- from typing import Dict, List, Tuple
8
- from PIL import Image
9
- import numpy as np
10
2
  import random
3
+ from pathlib import Path
4
+ from typing import Any
5
+ from typing import Dict
6
+
7
+ from rainbowneko.data import ImageLabelSource
8
+ from rainbowneko.utils.utils import is_image_file
11
9
  from torchvision.transforms import transforms
12
10
 
13
11
  default_image_transforms = transforms.Compose([
@@ -15,77 +13,41 @@ default_image_transforms = transforms.Compose([
15
13
  transforms.Normalize([0.5], [0.5])
16
14
  ])
17
15
 
18
- class Text2ImageSource(DataSource):
19
- def __init__(self, img_root, caption_file, prompt_template, text_transforms, image_transforms=default_image_transforms,
20
- bg_color=(255,255,255), repeat=1, **kwargs):
21
- super(Text2ImageSource, self).__init__(img_root, repeat=repeat)
16
+ class Text2ImageSource(ImageLabelSource):
17
+ def __init__(self, img_root, label_file, prompt_template, repeat=1, **kwargs):
18
+ super().__init__(img_root, label_file, repeat=repeat)
22
19
 
23
- self.caption_dict = self.load_captions(caption_file)
24
20
  self.prompt_template = self.load_template(prompt_template)
25
- self.image_transforms = image_transforms
26
- self.text_transforms = text_transforms
27
- self.bg_color = tuple(bg_color)
28
-
29
- def load_captions(self, caption_file: Union[str, BaseCaptionLoader]):
30
- if caption_file is None:
31
- return {}
32
- elif isinstance(caption_file, str):
33
- return auto_caption_loader(caption_file).load()
34
- else:
35
- return caption_file.load()
36
21
 
37
22
  def load_template(self, template_file):
38
23
  with open(template_file, 'r', encoding='utf-8') as f:
39
24
  return f.read().strip().split('\n')
40
25
 
41
- def get_image_list(self) -> List[Tuple[str, DataSource]]:
42
- imgs = [(os.path.join(self.img_root, x), self) for x in os.listdir(self.img_root) if get_file_ext(x) in types_support]
43
- return imgs*self.repeat
44
-
45
- def procees_image(self, image):
46
- return self.image_transforms(image)
47
-
48
- def process_text(self, text_dict):
49
- return self.text_transforms(text_dict)
50
-
51
- def load_image(self, path) -> Dict[str, Any]:
52
- image = Image.open(path)
53
- if image.mode == 'RGBA':
54
- x, y = image.size
55
- canvas = Image.new('RGBA', image.size, self.bg_color)
56
- canvas.paste(image, (0, 0, x, y), image)
57
- image = canvas
58
- return {'image': image.convert("RGB")}
59
-
60
- def load_caption(self, img_name) -> str:
61
- caption_ist = self.caption_dict.get(img_name, None)
62
- prompt_template = random.choice(self.prompt_template)
63
- prompt_ist = self.process_text({'prompt':prompt_template, 'caption':caption_ist})['prompt']
64
- return prompt_ist
65
-
66
- class Text2ImageAttMapSource(Text2ImageSource):
67
- def __init__(self, img_root, caption_file, prompt_template, text_transforms, image_transforms=default_image_transforms, att_mask=None,
68
- bg_color=(255, 255, 255), repeat=1, **kwargs):
69
- super().__init__(img_root, caption_file, prompt_template, image_transforms=image_transforms, text_transforms=text_transforms,
70
- bg_color=bg_color, repeat=repeat)
71
-
72
- if att_mask is None:
73
- self.att_mask = {}
26
+ def __getitem__(self, index) -> Dict[str, Any]:
27
+ img_name = self.img_ids[index]
28
+ path = self.img_root / img_name
29
+
30
+ return {
31
+ 'id':img_name,
32
+ 'image':path,
33
+ 'prompt':{
34
+ 'template':random.choice(self.prompt_template),
35
+ 'caption':self.label_dict.get(img_name, None),
36
+ }
37
+ }
38
+
39
+ class Text2ImageLossMapSource(Text2ImageSource):
40
+ def __init__(self, img_root, caption_file, prompt_template, loss_map=None, repeat=1, **kwargs):
41
+ super().__init__(img_root, caption_file, prompt_template, repeat=repeat)
42
+
43
+ if loss_map is None:
44
+ self.loss_map = {}
74
45
  else:
75
- self.att_mask = {get_file_name(file):os.path.join(att_mask, file)
76
- for file in os.listdir(att_mask) if get_file_ext(file) in types_support}
77
-
78
- def get_att_mask(self, img_name):
79
- if img_name not in self.att_mask:
80
- return None
81
- att_mask = Image.open(self.att_mask[img_name]).convert("L")
82
- np_mask = np.array(att_mask).astype(float)
83
- np_mask[np_mask<=127+0.1] = (np_mask[np_mask<=127+0.1]/127.)
84
- np_mask[np_mask>127] = ((np_mask[np_mask>127]-127)/128.)*4+1
85
- return np_mask
86
-
87
- def load_image(self, path) -> Dict[str, Any]:
88
- img_root, img_name = os.path.split(path)
89
- image_dict = super().load_image(path)
90
- image_dict['att_mask'] = self.get_att_mask(get_file_name(img_name))
91
- return image_dict
46
+ loss_map = Path(loss_map)
47
+ self.loss_map = {file.stem:loss_map/file for file in loss_map.iterdir() if is_image_file(file)}
48
+
49
+ def __getitem__(self, index) -> Dict[str, Any]:
50
+ data = super().__getitem__(index)
51
+ img_name = self.img_ids[index]
52
+ data['loss_map'] = self.loss_map[Path(img_name).stem]
53
+ return data
@@ -1,22 +1,16 @@
1
1
  import os
2
2
  from typing import Dict, Any
3
3
 
4
- from PIL import Image
5
- from torchvision import transforms
4
+ from .text2img import Text2ImageSource
6
5
 
7
- from .text2img import Text2ImageAttMapSource, default_image_transforms
8
-
9
- class Text2ImageCondSource(Text2ImageAttMapSource):
10
- def __init__(self, img_root, caption_file, prompt_template, text_transforms, image_transforms=default_image_transforms,
11
- bg_color=(255, 255, 255), repeat=1, cond_dir=None, **kwargs):
12
- super().__init__(img_root, caption_file, prompt_template, image_transforms=image_transforms, text_transforms=text_transforms,
13
- bg_color=bg_color, repeat=repeat)
14
- self.cond_transform = transforms.ToTensor()
6
+ class Text2ImageCondSource(Text2ImageSource):
7
+ def __init__(self, img_root, caption_file, prompt_template, repeat=1, cond_dir=None, **kwargs):
8
+ super().__init__(img_root, caption_file, prompt_template, repeat=repeat)
15
9
  self.cond_dir = cond_dir
16
10
 
17
- def load_image(self, path) -> Dict[str, Any]:
18
- img_root, img_name = os.path.split(path)
19
- image_dict = super().load_image(path)
11
+ def __getitem__(self, index) -> Dict[str, Any]:
12
+ data = super().__getitem__(index)
13
+ img_name = self.img_ids[index]
20
14
  cond_path = os.path.join(self.cond_dir, img_name)
21
- image_dict['cond'] = Image.open(cond_path).convert("RGB")
22
- return image_dict
15
+ data['cond'] = cond_path
16
+ return data
File without changes
@@ -0,0 +1,2 @@
1
+ from .pyramid_noise import PyramidNoiseSampler
2
+ from .zero_terminal import ZeroTerminalSampler
@@ -0,0 +1,42 @@
1
+ import random
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+ from hcpdiff.diffusion.sampler import BaseSampler
7
+
8
class PyramidNoiseSampler:
    """Multi-resolution ("pyramid") Gaussian noise generator.

    White noise at full resolution is summed with progressively coarser
    random fields. Each coarse field is upsampled back to full resolution
    and weighted by ``discount**i``; the sum is then rescaled to unit
    standard deviation.

    Args:
        level: Upper bound (exclusive) of the pyramid level index; levels
            ``1 .. level-1`` are added on top of the base noise.
        discount: Per-level weight base; level ``i`` is scaled by ``discount**i``.
        step_size: Minimum down-scaling ratio per level. The actual ratio is
            drawn uniformly from ``[step_size, step_size+2)`` at every level.
        resize_mode: Interpolation mode passed to ``F.interpolate``.
        generator: Optional ``torch.Generator``. When given, every random
            draw (noise fields and scaling ratios) comes from it, so results
            are reproducible. It must live on the device noise is requested on.
    """

    def __init__(self, level: int = 6, discount: float = 0.4, step_size: float = 2., resize_mode: str = 'bilinear',
                 generator: torch.Generator = None):
        self.level = level
        self.step_size = step_size
        self.resize_mode = resize_mode
        self.discount = discount
        self.generator = generator

    def _rand_unit(self) -> float:
        """Draw a float in ``[0, 1)`` from the generator if set, else from ``random``."""
        if self.generator is None:
            return random.random()
        return torch.rand(1, generator=self.generator, device=self.generator.device).item()

    def make_nosie(self, shape, device='cuda', dtype=torch.float32):
        """Create pyramid noise of the given 4-D ``shape`` (unpacked as ``b, c, h, w``).

        The method name (including its spelling) matches
        ``BaseSampler.make_nosie`` so that it can replace it via ``patch``.

        Returns:
            A tensor of ``shape`` on ``device`` with ``dtype``, normalised so
            that its overall standard deviation is 1.
        """
        gen = self.generator
        noise = torch.randn(shape, generator=gen, device=device, dtype=dtype)
        with torch.no_grad():
            b, c, h, w = noise.shape
            for i in range(1, self.level):
                r = self._rand_unit()*2+self.step_size
                wn, hn = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
                # Sample the coarse field directly on the target device/dtype
                # instead of creating it on CPU and copying it over.
                coarse = torch.randn((b, c, hn, wn), generator=gen, device=noise.device, dtype=noise.dtype)
                noise += F.interpolate(coarse, size=(h, w), mode=self.resize_mode)*(self.discount**i)
                if wn == 1 or hn == 1:
                    # Coarsest possible field reached; further levels add nothing new.
                    break
            noise = noise/noise.std()
        return noise

    @classmethod
    def patch(cls, base_sampler: 'BaseSampler', level: int = 6, discount: float = 0.4, step_size: float = 2.,
              resize_mode: str = 'bilinear'):
        """Replace ``base_sampler.make_nosie`` with pyramid noise and return the sampler.

        The sampler's ``generator`` attribute (if any) is captured at patch
        time so that seeded sampling stays reproducible after patching.
        """
        patcher = cls(level, discount, step_size, resize_mode, generator=getattr(base_sampler, 'generator', None))
        base_sampler.make_nosie = patcher.make_nosie
        return base_sampler
33
+
34
if __name__ == '__main__':
    # Manual visual check: patch a sampler with pyramid noise, draw one
    # 3x512x512 sample on CPU and display it as an image.
    from hcpdiff.diffusion.sampler import EDM_DDPMSampler, DDPMContinuousSigmaScheduler
    from matplotlib import pyplot as plt

    base = EDM_DDPMSampler(DDPMContinuousSigmaScheduler())
    patched = PyramidNoiseSampler.patch(base)
    sample = patched.make_nosie((1,3,512,512), device='cpu')

    # Channels-last layout for imshow.
    image = sample[0].permute(1,2,0)
    plt.figure()
    plt.imshow(image)
    plt.show()
@@ -0,0 +1,39 @@
1
+ import torch
2
+ from ..sampler.sigma_scheduler import DDPMDiscreteSigmaScheduler
3
+
4
class ZeroTerminalSampler:
    """Patches a DDPM sigma scheduler so that its schedule ends at (near) zero SNR."""

    @classmethod
    def patch(cls, base_sampler):
        """Rescale ``base_sampler.sigma_scheduler`` in place and return the sampler.

        Both ``alphas_cumprod`` and ``sigmas`` of the scheduler are replaced.
        ``sigmas`` is derived from the *rescaled* ``alphas_cumprod`` so the two
        attributes stay consistent with each other.

        Raises:
            AssertionError: if the sampler's sigma scheduler is not a
                ``DDPMDiscreteSigmaScheduler``.
        """
        scheduler = base_sampler.sigma_scheduler
        assert isinstance(scheduler, DDPMDiscreteSigmaScheduler), "ZeroTerminalSampler only works with DDPM SigmaScheduler"

        rescaled = cls.rescale_zero_terminal_snr(scheduler.alphas_cumprod)
        scheduler.alphas_cumprod = rescaled
        scheduler.sigmas = ((1-rescaled)/rescaled).sqrt()
        return base_sampler

    @staticmethod
    def rescale_zero_terminal_snr(alphas_cumprod, thr=1e-4):
        """
        Rescales the cumulative alphas to have (near) zero terminal SNR.
        Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

        Args:
            alphas_cumprod (`torch.FloatTensor`): 1-D cumulative product of alphas.
                It is not modified.
            thr (`float`): value assigned to the last entry of sqrt(alphas_cumprod)
                instead of exactly 0, so the last returned entry is ``thr**2``.
        Returns:
            `torch.FloatTensor`: rescaled alphas_cumprod; the first entry keeps its
                original value and the last entry is ``thr**2``.
        """
        # .sqrt() allocates a new tensor, so the in-place ops below never touch the input.
        alphas_bar_sqrt = alphas_cumprod.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T

        # Scale so the first timestep is back to the old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0/(alphas_bar_sqrt_0-alphas_bar_sqrt_T)
        # An exact zero would make sigma = sqrt((1-a)/a) non-finite.
        alphas_bar_sqrt[-1] = thr

        # Revert the sqrt to get alphas_cumprod back.
        alphas_bar = alphas_bar_sqrt**2
        return alphas_bar
@@ -0,0 +1,5 @@
1
+ from .sigma_scheduler import *
2
+ from .base import BaseSampler
3
+ from .ddpm import DDPMSampler
4
+ from .edm import EDMSampler
5
+ from .diffusers import DiffusersSampler
@@ -0,0 +1,72 @@
1
+ from typing import Tuple
2
+ import torch
3
+ from .sigma_scheduler import SigmaScheduler
4
+ from diffusers import DDPMScheduler
5
+
6
class BaseSampler:
    """Base class tying a sigma scheduler to noise creation and prediction conversions.

    The coefficient methods ``c_in``/``c_out``/``c_skip`` all return 1 here;
    ``add_noise`` and ``eps_to_x0`` are written in terms of them.

    Args:
        sigma_scheduler: Object providing ``sigma_max`` and ``sample_sigma``
            (and optionally ``num_timesteps``).
        generator: Optional ``torch.Generator`` used by ``make_nosie``.
    """

    def __init__(self, sigma_scheduler: 'SigmaScheduler', generator: torch.Generator = None):
        self.sigma_scheduler = sigma_scheduler
        self.generator = generator

    def c_in(self, sigma):
        return 1

    def c_out(self, sigma):
        return 1

    def c_skip(self, sigma):
        return 1

    @property
    def num_timesteps(self):
        """``sigma_scheduler.num_timesteps`` if it exists, otherwise ``1000.``."""
        return getattr(self.sigma_scheduler, 'num_timesteps', 1000.)

    def get_timesteps(self, N_steps, device='cuda'):
        """Return ``N_steps`` evenly spaced values from 0 to ``num_timesteps`` (inclusive)."""
        return torch.linspace(0, self.num_timesteps, N_steps, device=device)

    def make_nosie(self, shape, device='cuda', dtype=torch.float32):
        """Draw standard normal noise of ``shape`` using ``self.generator``."""
        return torch.randn(shape, generator=self.generator, device=device, dtype=dtype)

    def init_noise(self, shape, device='cuda', dtype=torch.float32):
        """Draw noise via ``make_nosie`` and scale it by ``sigma_scheduler.sigma_max``."""
        sigma = self.sigma_scheduler.sigma_max
        return self.make_nosie(shape, device, dtype)*sigma

    def add_noise(self, x, sigma) -> Tuple[torch.Tensor, torch.Tensor]:
        """Noise ``x`` at level ``sigma``; algebraic inverse of ``eps_to_x0``.

        The arithmetic is done in float32 and both outputs are cast back to
        ``x.dtype``.

        Returns:
            ``(noisy_x, noise)``.
        """
        noise = self.make_nosie(x.shape, device=x.device)
        noisy_x = (x.to(dtype=torch.float32)-self.c_out(sigma)*noise)/self.c_skip(sigma)
        return noisy_x.to(dtype=x.dtype), noise.to(dtype=x.dtype)

    def add_noise_rand_t(self, x):
        """Noise ``x`` with one randomly sampled sigma per batch element.

        Returns:
            ``(noisy_x, noise, sigma, timesteps)`` where ``sigma`` has shape
            ``(bs, 1, 1, 1)`` and ``timesteps`` is the scheduler's sampled value
            multiplied by ``num_timesteps-1``.
        """
        bs = x.shape[0]
        # timesteps: [0, 1]
        sigma, timesteps = self.sigma_scheduler.sample_sigma(shape=(bs,))
        sigma = sigma.view(-1, 1, 1, 1).to(x.device)
        timesteps = timesteps.to(x.device)
        noisy_x, noise = self.add_noise(x, sigma)

        # Rescale the sampled timesteps to [0, num_timesteps-1].
        timesteps = timesteps*(self.num_timesteps-1)
        return noisy_x, noise, sigma, timesteps

    def denoise(self, x, sigma, eps=None, generator=None):
        """Not implemented in the base class."""
        raise NotImplementedError

    def eps_to_x0(self, eps, x_t, sigma):
        """Return ``c_skip(sigma)*x_t + c_out(sigma)*eps``."""
        return self.c_skip(sigma)*x_t+self.c_out(sigma)*eps

    @staticmethod
    def _alpha(sigma):
        """Return ``1/(sigma**2+1)`` as a tensor; Python scalars are accepted too."""
        sigma = torch.as_tensor(sigma)
        return 1/(sigma**2+1)

    def velocity_to_eps(self, v_pred, x_t, sigma):
        """Convert a velocity prediction to a noise prediction (inverse of ``eps_to_velocity``)."""
        alpha = self._alpha(sigma)
        sqrt_alpha = alpha.sqrt()
        one_sqrt_alpha = (1-alpha).sqrt()
        return sqrt_alpha*v_pred + one_sqrt_alpha*(x_t*sqrt_alpha)

    def eps_to_velocity(self, eps, x_t, sigma):
        """Convert a noise prediction to a velocity prediction (inverse of ``velocity_to_eps``)."""
        alpha = self._alpha(sigma)
        sqrt_alpha = alpha.sqrt()
        one_sqrt_alpha = (1-alpha).sqrt()
        return eps/sqrt_alpha - one_sqrt_alpha*x_t

    def velocity_to_x0(self, v_pred, x_t, sigma):
        """Return ``alpha*x_t - sqrt(1-alpha)*v_pred`` with ``alpha = 1/(sigma**2+1)``."""
        alpha = self._alpha(sigma)
        one_sqrt_alpha = (1-alpha).sqrt()
        return alpha*x_t - one_sqrt_alpha*v_pred