hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (210)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/controlnet.py +16 -16
  58. hcpdiff/models/lora_base_patch.py +14 -25
  59. hcpdiff/models/lora_layers.py +3 -9
  60. hcpdiff/models/lora_layers_patch.py +14 -24
  61. hcpdiff/models/text_emb_ex.py +84 -6
  62. hcpdiff/models/textencoder_ex.py +54 -18
  63. hcpdiff/models/wrapper/__init__.py +3 -0
  64. hcpdiff/models/wrapper/pixart.py +19 -0
  65. hcpdiff/models/wrapper/sd.py +218 -0
  66. hcpdiff/models/wrapper/utils.py +20 -0
  67. hcpdiff/parser/__init__.py +1 -0
  68. hcpdiff/parser/embpt.py +32 -0
  69. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  70. hcpdiff/tools/dataset_generator.py +94 -0
  71. hcpdiff/tools/download_hf_model.py +24 -0
  72. hcpdiff/tools/init_proj.py +3 -21
  73. hcpdiff/tools/lora_convert.py +18 -17
  74. hcpdiff/tools/save_model.py +12 -0
  75. hcpdiff/tools/sd2diffusers.py +1 -1
  76. hcpdiff/train_colo.py +1 -1
  77. hcpdiff/train_deepspeed.py +1 -1
  78. hcpdiff/trainer_ac.py +79 -0
  79. hcpdiff/trainer_ac_single.py +31 -0
  80. hcpdiff/utils/__init__.py +0 -2
  81. hcpdiff/utils/inpaint_pipe.py +7 -2
  82. hcpdiff/utils/net_utils.py +29 -6
  83. hcpdiff/utils/pipe_hook.py +24 -7
  84. hcpdiff/utils/utils.py +21 -4
  85. hcpdiff/workflow/__init__.py +15 -10
  86. hcpdiff/workflow/daam/__init__.py +1 -0
  87. hcpdiff/workflow/daam/act.py +66 -0
  88. hcpdiff/workflow/daam/hook.py +109 -0
  89. hcpdiff/workflow/diffusion.py +114 -125
  90. hcpdiff/workflow/fast.py +31 -0
  91. hcpdiff/workflow/flow.py +67 -0
  92. hcpdiff/workflow/io.py +36 -130
  93. hcpdiff/workflow/model.py +46 -43
  94. hcpdiff/workflow/text.py +78 -46
  95. hcpdiff/workflow/utils.py +32 -12
  96. hcpdiff/workflow/vae.py +37 -38
  97. hcpdiff-2.1.dist-info/METADATA +285 -0
  98. hcpdiff-2.1.dist-info/RECORD +114 -0
  99. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  100. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  101. hcpdiff/ckpt_manager/base.py +0 -16
  102. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  103. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  104. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  105. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  106. hcpdiff/data/bucket.py +0 -358
  107. hcpdiff/data/caption_loader.py +0 -80
  108. hcpdiff/data/cond_dataset.py +0 -40
  109. hcpdiff/data/crop_info_dataset.py +0 -40
  110. hcpdiff/data/data_processor.py +0 -33
  111. hcpdiff/data/pair_dataset.py +0 -146
  112. hcpdiff/data/sampler.py +0 -54
  113. hcpdiff/data/source/base.py +0 -30
  114. hcpdiff/data/utils.py +0 -80
  115. hcpdiff/deprecated/__init__.py +0 -1
  116. hcpdiff/deprecated/cfg_converter.py +0 -81
  117. hcpdiff/deprecated/lora_convert.py +0 -31
  118. hcpdiff/infer_workflow.py +0 -57
  119. hcpdiff/loggers/__init__.py +0 -13
  120. hcpdiff/loggers/base_logger.py +0 -76
  121. hcpdiff/loggers/cli_logger.py +0 -40
  122. hcpdiff/loggers/preview/__init__.py +0 -1
  123. hcpdiff/loggers/preview/image_previewer.py +0 -149
  124. hcpdiff/loggers/tensorboard_logger.py +0 -30
  125. hcpdiff/loggers/wandb_logger.py +0 -31
  126. hcpdiff/loggers/webui_logger.py +0 -9
  127. hcpdiff/loss/min_snr_loss.py +0 -52
  128. hcpdiff/models/layers.py +0 -81
  129. hcpdiff/models/plugin.py +0 -348
  130. hcpdiff/models/wrapper.py +0 -75
  131. hcpdiff/noise/__init__.py +0 -3
  132. hcpdiff/noise/noise_base.py +0 -16
  133. hcpdiff/noise/pyramid_noise.py +0 -50
  134. hcpdiff/noise/zero_terminal.py +0 -44
  135. hcpdiff/train_ac.py +0 -566
  136. hcpdiff/train_ac_single.py +0 -39
  137. hcpdiff/utils/caption_tools.py +0 -105
  138. hcpdiff/utils/cfg_net_tools.py +0 -321
  139. hcpdiff/utils/cfg_resolvers.py +0 -16
  140. hcpdiff/utils/ema.py +0 -52
  141. hcpdiff/utils/img_size_tool.py +0 -248
  142. hcpdiff/vis/__init__.py +0 -3
  143. hcpdiff/vis/base_interface.py +0 -12
  144. hcpdiff/vis/disk_interface.py +0 -48
  145. hcpdiff/vis/webui_interface.py +0 -17
  146. hcpdiff/viser_fast.py +0 -138
  147. hcpdiff/visualizer.py +0 -265
  148. hcpdiff/visualizer_reloadable.py +0 -237
  149. hcpdiff/workflow/base.py +0 -59
  150. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  198. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  206. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  207. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  208. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  209. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/__init__.py CHANGED
@@ -1,4 +1,4 @@
-from .train_ac import Trainer
-from .train_ac_single import TrainerSingleCard
-from .visualizer import Visualizer
-from .visualizer_reloadable import VisualizerReloadable
+#from .train_ac_old import Trainer
+#from .train_ac_single import TrainerSingleCard
+# from .visualizer import Visualizer
+# from .visualizer_reloadable import VisualizerReloadable
hcpdiff/ckpt_manager/__init__.py CHANGED
@@ -1,5 +1,4 @@
-from .ckpt_pkl import CkptManagerPKL
-from .ckpt_safetensor import CkptManagerSafe
-
-def auto_manager(ckpt_path:str):
-    return CkptManagerSafe() if ckpt_path.endswith('.safetensors') else CkptManagerPKL()
+from .format import EmbFormat, DiffusersSD15Format, DiffusersModelFormat, DiffusersSDXLFormat, DiffusersPixArtFormat, OfficialSDXLFormat, \
+    OfficialSD15Format
+from .ckpt import EmbSaver, easy_emb_saver
+from .loader import HCPLoraLoader
hcpdiff/ckpt_manager/ckpt.py ADDED
@@ -0,0 +1,24 @@
+from rainbowneko.ckpt_manager import NekoSaver, CkptFormat, LocalCkptSource, PKLFormat
+from torch import nn
+from typing import Dict, Any
+
+class EmbSaver(NekoSaver):
+    def __init__(self, format: CkptFormat, source: LocalCkptSource, target_key='embs', prefix=None):
+        super().__init__(format, source)
+        self.target_key = target_key
+        self.prefix = prefix
+
+    def save_to(self, name, model: nn.Module, plugin_groups: Dict[str, Any], model_ema=None, exclude_key=None,
+                name_template=None):
+        train_pts = plugin_groups[self.target_key]
+        for pt_name, pt in train_pts.items():
+            self.save(pt_name, (pt_name, pt), prefix=self.prefix)
+            if name_template is not None:
+                pt_name = name_template.format(pt_name)
+                self.save(pt_name, (pt_name, pt), prefix=self.prefix)
+
+def easy_emb_saver():
+    return EmbSaver(
+        format=PKLFormat(),
+        source=LocalCkptSource(),
+    )
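For orientation, a minimal usage sketch of the new saver, assuming only the signatures visible in the hunk above; the embedding tensor and its name are made up:

    import torch
    from hcpdiff.ckpt_manager import easy_emb_saver

    embs = {'my-style': torch.randn(4, 768)}  # hypothetical learned embedding
    saver = easy_emb_saver()                  # EmbSaver with PKLFormat + LocalCkptSource
    # `model` is not used by EmbSaver.save_to in this hunk, so None is passed here
    saver.save_to('emb', model=None, plugin_groups={'embs': embs})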
hcpdiff/ckpt_manager/format/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .emb import EmbFormat
+from .diffusers import DiffusersSD15Format, DiffusersModelFormat, DiffusersSDXLFormat, DiffusersPixArtFormat
+from .sd_single import OfficialSD15Format, OfficialSDXLFormat
+from .lora_webui import LoraWebuiFormat
hcpdiff/ckpt_manager/format/diffusers.py ADDED
@@ -0,0 +1,59 @@
+import torch
+from diffusers import ModelMixin, AutoencoderKL, UNet2DConditionModel, PixArtTransformer2DModel
+from rainbowneko.ckpt_manager.format import CkptFormat
+from transformers import CLIPTextModel, AutoTokenizer, T5EncoderModel
+
+from hcpdiff.diffusion.sampler import DDPMSampler, DDPMDiscreteSigmaScheduler
+from hcpdiff.models.compose import SDXLTokenizer, SDXLTextEncoder
+
+class DiffusersModelFormat(CkptFormat):
+    def __init__(self, builder: ModelMixin):
+        self.builder = builder
+
+    def save_ckpt(self, sd_model: ModelMixin, save_f: str, **kwargs):
+        sd_model.save_pretrained(save_f)
+
+    def load_ckpt(self, ckpt_f: str, map_location="cpu", **kwargs):
+        self.builder.from_pretrained(ckpt_f, **kwargs)
+
+class DiffusersSD15Format(CkptFormat):
+    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
+                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+        denoiser = denoiser or UNet2DConditionModel.from_pretrained(
+            pretrained_model, subfolder="unet", revision=revision, torch_dtype=dtype
+        )
+        vae = vae or AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", revision=revision, torch_dtype=dtype)
+        noise_sampler = noise_sampler or DDPMSampler(DDPMDiscreteSigmaScheduler())
+
+        TE = TE or CLIPTextModel.from_pretrained(pretrained_model, subfolder="text_encoder", revision=revision, torch_dtype=dtype)
+        tokenizer = tokenizer or AutoTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)
+
+        return dict(denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer)
+
+class DiffusersSDXLFormat(CkptFormat):
+    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
+                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+        denoiser = denoiser or UNet2DConditionModel.from_pretrained(
+            pretrained_model, subfolder="unet", revision=revision, torch_dtype=dtype
+        )
+        vae = vae or AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", revision=revision, torch_dtype=dtype)
+        noise_sampler = noise_sampler or DDPMSampler(DDPMDiscreteSigmaScheduler())
+
+        TE = TE or SDXLTextEncoder.from_pretrained(pretrained_model, subfolder="text_encoder", revision=revision, torch_dtype=dtype)
+        tokenizer = tokenizer or SDXLTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)
+
+        return dict(denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer)
+
+class DiffusersPixArtFormat(CkptFormat):
+    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
+                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+        denoiser = denoiser or PixArtTransformer2DModel.from_pretrained(
+            pretrained_model, subfolder="transformer", revision=revision, torch_dtype=dtype
+        )
+        vae = vae or AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", revision=revision, torch_dtype=dtype)
+        noise_sampler = noise_sampler or DDPMSampler(DDPMDiscreteSigmaScheduler())
+
+        TE = TE or T5EncoderModel.from_pretrained(pretrained_model, subfolder="text_encoder", revision=revision, torch_dtype=dtype)
+        tokenizer = tokenizer or AutoTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)
+
+        return dict(denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer)
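A sketch of how these formats appear to be used, based only on the load_ckpt signatures above; the model id is a placeholder:

    import torch
    from hcpdiff.ckpt_manager.format import DiffusersSD15Format

    # loads all SD1.5 components from a diffusers-layout repo or local folder
    parts = DiffusersSD15Format().load_ckpt('runwayml/stable-diffusion-v1-5', dtype=torch.float16)
    denoiser, TE, vae = parts['denoiser'], parts['TE'], parts['vae']
    noise_sampler, tokenizer = parts['noise_sampler'], parts['tokenizer']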
hcpdiff/ckpt_manager/format/emb.py ADDED
@@ -0,0 +1,21 @@
+from typing import Tuple
+
+import torch
+from rainbowneko.ckpt_manager.format import CkptFormat
+from torch.serialization import FILE_LIKE
+
+class EmbFormat(CkptFormat):
+    EXT = 'pt'
+
+    def save_ckpt(self, sd_model: Tuple[str, torch.Tensor], save_f: FILE_LIKE):
+        name, emb = sd_model
+        torch.save({'string_to_param':{'*':emb}, 'name':name}, save_f)
+
+    def load_ckpt(self, ckpt_f: FILE_LIKE, map_location="cpu"):
+        state = torch.load(ckpt_f, map_location=map_location)
+        if 'string_to_param' in state:
+            emb = state['string_to_param']['*']
+        else:
+            emb = state['emb_params']
+        emb.requires_grad_(False)
+        return emb
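EmbFormat writes embeddings in the webui 'string_to_param' layout and reads back either that layout or the 'emb_params' one. A round-trip sketch using only the methods above; the file name and tensor are illustrative:

    import torch
    from hcpdiff.ckpt_manager.format import EmbFormat

    fmt = EmbFormat()
    fmt.save_ckpt(('my-token', torch.randn(4, 768)), 'my-token.pt')
    emb = fmt.load_ckpt('my-token.pt')  # tensor, with requires_grad disabled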
hcpdiff/ckpt_manager/format/lora_webui.py ADDED
@@ -0,0 +1,244 @@
+import math
+import re
+from typing import List, Dict, Any
+
+from rainbowneko.ckpt_manager.format import CkptFormat
+from torch.serialization import FILE_LIKE
+
+class LoraConverter:
+    com_name_unet = ['down_blocks', 'up_blocks', 'mid_block', 'transformer_blocks', 'to_q', 'to_k', 'to_v', 'to_out', 'proj_in', 'proj_out',
+                     'input_blocks', 'middle_block', 'output_blocks']
+    com_name_TE = ['self_attn', 'q_proj', 'v_proj', 'k_proj', 'out_proj', 'text_model']
+    prefix_unet = 'lora_unet_'
+    prefix_TE = 'lora_te_'
+    prefix_TE_xl_clip_L = 'lora_te1_'
+    prefix_TE_xl_clip_bigG = 'lora_te2_'
+
+    lora_w_map = {'lora_down.weight':'W_down', 'lora_up.weight':'W_up'}
+
+    def __init__(self):
+        self.com_name_unet_tmp = [x.replace('_', '%') for x in self.com_name_unet]
+        self.com_name_TE_tmp = [x.replace('_', '%') for x in self.com_name_TE]
+
+    def convert_from_webui(self, state, auto_scale_alpha=False, sdxl=False):
+        if not sdxl:
+            sd_unet = self.convert_from_webui_(state, prefix=self.prefix_unet, com_name=self.com_name_unet, com_name_tmp=self.com_name_unet_tmp)
+            sd_TE = self.convert_from_webui_(state, prefix=self.prefix_TE, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
+        else:
+            sd_unet = self.convert_from_webui_xl_unet_(state, prefix=self.prefix_unet, com_name=self.com_name_unet,
+                                                       com_name_tmp=self.com_name_unet_tmp)
+            sd_TE = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_L, com_name=self.com_name_TE,
+                                                   com_name_tmp=self.com_name_TE_tmp)
+            sd_TE2 = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_bigG, com_name=self.com_name_TE,
+                                                    com_name_tmp=self.com_name_TE_tmp)
+            sd_TE.update(sd_TE2)
+
+        if auto_scale_alpha:
+            sd_unet = self.alpha_scale_from_webui(sd_unet)
+            sd_TE = self.alpha_scale_from_webui(sd_TE)
+        return {'plugin':sd_TE}, {'plugin':sd_unet}
+
+    def convert_to_webui(self, sd_unet, sd_TE, auto_scale_alpha=False, sdxl=False):
+        sd_unet = self.convert_to_webui_(sd_unet, prefix=self.prefix_unet)
+        if sdxl:
+            sd_TE = self.convert_to_webui_xl_(sd_TE, prefix=self.prefix_TE)
+        else:
+            sd_TE = self.convert_to_webui_(sd_TE, prefix=self.prefix_TE)
+        sd_unet.update(sd_TE)
+        if auto_scale_alpha:
+            sd_unet = self.alpha_scale_to_webui(sd_unet)
+        return sd_unet
+
+    def convert_from_webui_(self, state, prefix, com_name, com_name_tmp):
+        state = {k:v for k, v in state.items() if k.startswith(prefix)}
+        prefix_len = len(prefix)
+        sd_covert = {}
+        for k, v in state.items():
+            model_k, lora_k = k[prefix_len:].split('.', 1)
+            model_k = self.replace_all(model_k, com_name, com_name_tmp).replace('_', '.').replace('%', '_')
+            if lora_k == 'alpha':
+                sd_covert[f'{model_k}.___.{lora_k}'] = v
+            else:
+                sd_covert[f'{model_k}.___.layer.{self.lora_w_map[lora_k]}'] = v
+        return sd_covert
+
+    def convert_to_webui_(self, state, prefix):
+        sd_covert = {}
+        for k, v in state.items():
+            if k.endswith('W_down'):
+                model_k, _ = k.split('.___.', 1)
+                lora_k = 'lora_down.weight'
+            elif k.endswith('W_up'):
+                model_k, _ = k.split('.___.', 1)
+                lora_k = 'lora_up.weight'
+            else:
+                model_k, lora_k = k.split('.___.', 1)
+
+            sd_covert[f"{prefix}{model_k.replace('.', '_')}.{lora_k}"] = v
+        return sd_covert
+
+    def convert_to_webui_xl_(self, state, prefix):
+        sd_convert = {}
+        for k, v in state.items():
+            if k.endswith('W_down'):
+                model_k, _ = k.split('.___.', 1)
+                lora_k = 'lora_down.weight'
+            elif k.endswith('W_up'):
+                model_k, _ = k.split('.___.', 1)
+                lora_k = 'lora_up.weight'
+            else:
+                model_k, lora_k = k.split('.___.', 1)
+
+            new_k = f"{prefix}{model_k.replace('.', '_')}.{lora_k}"
+            if 'clip' in new_k:
+                new_k = new_k.replace('_clip_L', '1') if 'clip_L' in new_k else new_k.replace('_clip_bigG', '2')
+            sd_convert[new_k] = v
+        return sd_convert
+
+    def convert_from_webui_xl_te_(self, state, prefix, com_name, com_name_tmp):
+        state = {k:v for k, v in state.items() if k.startswith(prefix)}
+        sd_covert = {}
+        prefix_len = len(prefix)
+
+        for k, v in state.items():
+            model_k, lora_k = k[prefix_len:].split('.', 1)
+            model_k = self.replace_all(model_k, com_name, com_name_tmp).replace('_', '.').replace('%', '_')
+            if prefix == 'lora_te1_':
+                model_k = f'clip_L.{model_k}'
+            else:
+                model_k = f'clip_bigG.{model_k}'
+
+            if lora_k == 'alpha':
+                sd_covert[f'{model_k}.___.{lora_k}'] = v
+            else:
+                sd_covert[f'{model_k}.___.layer.{self.lora_w_map[lora_k]}'] = v
+        return sd_covert
+
+    def convert_from_webui_xl_unet_(self, state, prefix, com_name, com_name_tmp):
+        # Down:
+        # 4 -> 1, 0  4 = 1 + 3 * 1 + 0
+        # 5 -> 1, 1  5 = 1 + 3 * 1 + 1
+        # 7 -> 2, 0  7 = 1 + 3 * 2 + 0
+        # 8 -> 2, 1  8 = 1 + 3 * 2 + 1
+
+        # Up
+        # 0 -> 0, 0  0 = 0 * 3 + 0
+        # 1 -> 0, 1  1 = 0 * 3 + 1
+        # 2 -> 0, 2  2 = 0 * 3 + 2
+        # 3 -> 1, 0  3 = 1 * 3 + 0
+        # 4 -> 1, 1  4 = 1 * 3 + 1
+        # 5 -> 1, 2  5 = 1 * 3 + 2
+
+        down = {
+            '4':[1, 0],
+            '5':[1, 1],
+            '7':[2, 0],
+            '8':[2, 1],
+        }
+        up = {
+            '0':[0, 0],
+            '1':[0, 1],
+            '2':[0, 2],
+            '3':[1, 0],
+            '4':[1, 1],
+            '5':[1, 2],
+        }
+
+        m = []
+
+        def match(key, regex_text):
+            regex = re.compile(regex_text)
+            r = re.match(regex, key)
+            if not r:
+                return False
+
+            m.clear()
+            m.extend(r.groups())
+            return True
+
+        state = {k:v for k, v in state.items() if k.startswith(prefix)}
+        sd_covert = {}
+        prefix_len = len(prefix)
+        for k, v in state.items():
+            model_k, lora_k = k[prefix_len:].split('.', 1)
+
+            model_k = self.replace_all(model_k, com_name, com_name_tmp).replace('_', '.').replace('%', '_')
+
+            if match(model_k, r'input_blocks.(\d+).1.(.+)'):
+                new_k = f'down_blocks.{down[m[0]][0]}.attentions.{down[m[0]][1]}.{m[1]}'
+            elif match(model_k, r'middle_block.1.(.+)'):
+                new_k = f'mid_block.attentions.0.{m[0]}'
+                pass
+            elif match(model_k, r'output_blocks.(\d+).(\d+).(.+)'):
+                new_k = f'up_blocks.{up[m[0]][0]}.attentions.{up[m[0]][1]}.{m[2]}'
+            else:
+                raise NotImplementedError
+
+            if lora_k == 'alpha':
+                sd_covert[f'{new_k}.___.{lora_k}'] = v
+            else:
+                sd_covert[f'{new_k}.___.layer.{self.lora_w_map[lora_k]}'] = v
+
+        return sd_covert
+
+    @staticmethod
+    def replace_all(data: str, srcs: List[str], dsts: List[str]):
+        for src, dst in zip(srcs, dsts):
+            data = data.replace(src, dst)
+        return data
+
+    @staticmethod
+    def alpha_scale_from_webui(state):
+        # Apply to "lora_down" and "lora_up" respectively to prevent overflow
+        for k, v in state.items():
+            if 'W_up' in k:
+                state[k] = v*math.sqrt(v.shape[1])
+            elif 'W_down' in k:
+                state[k] = v*math.sqrt(v.shape[0])
+        return state
+
+    @staticmethod
+    def alpha_scale_to_webui(state):
+        for k, v in state.items():
+            if 'lora_up' in k:
+                state[k] = v*math.sqrt(v.shape[1])
+            elif 'lora_down' in k:
+                state[k] = v*math.sqrt(v.shape[0])
+        return state
+
+class LoraWebuiFormat(CkptFormat):
+    def __init__(self, format, auto_scale_alpha=False):
+        self.converter = LoraConverter()
+        self.auto_scale_alpha = auto_scale_alpha
+        self.format = format
+
+    def save_ckpt(self, sd_model: Dict[str, Any], save_f: FILE_LIKE):
+        sd_denoiser = {k.removeprefix('denoiser.'):v for k, v in sd_model['base'].items() if k.startswith('denoiser.')}
+        sd_TE = {k.removeprefix('TE.'):v for k, v in sd_model['base'].items() if k.startswith('TE.')}
+
+        if len(sd_denoiser)>0 or len(sd_TE)>0:
+            sdxl = False
+            for k in sd_TE.keys():
+                if 'clip_L' in k or 'clip_bigG' in k:
+                    sdxl = True
+                    break
+            sd_webui = self.converter.convert_to_webui(sd_denoiser, sd_TE, auto_scale_alpha=self.auto_scale_alpha, sdxl=sdxl)
+        else:
+            sd_webui = self.converter.convert_to_webui(sd_model['base'], {}, auto_scale_alpha=self.auto_scale_alpha)
+
+        self.format.save_ckpt(sd_webui, save_f)
+
+    def load_ckpt(self, ckpt_f: str, map_location="cpu", **kwargs):
+        sd_webui = self.format.load_ckpt(ckpt_f, map_location=map_location, **kwargs)
+
+        sdxl = False
+        for k in sd_webui.keys():
+            if ('lora_te1_' in k or 'lora_te2_' in k or
+                    re.match(r'input_blocks.(\d+).1.(.+)', k) or
+                    re.match(r'middle_block.1.(.+)', k) or
+                    re.match(r'output_blocks.(\d+).(\d+).(.+)', k)):
+                sdxl = True
+                break
+
+        sd_TE, sd_unet = self.converter.convert_from_webui(sd_webui, auto_scale_alpha=self.auto_scale_alpha, sdxl=sdxl)
+        return sd_TE, sd_unet
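LoraConverter maps between webui key names (lora_unet_*/lora_te_*) and HCP's `<module path>.___.layer.W_down/W_up` keys. A conversion sketch using only the methods above; the checkpoint path is a placeholder:

    from safetensors.torch import load_file
    from hcpdiff.ckpt_manager.format.lora_webui import LoraConverter

    state = load_file('webui_lora.safetensors')  # webui-style LoRA state dict
    converter = LoraConverter()
    # each returned value is wrapped as {'plugin': state_dict}
    sd_TE, sd_unet = converter.convert_from_webui(state, auto_scale_alpha=True, sdxl=False)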
hcpdiff/ckpt_manager/format/sd_single.py ADDED
@@ -0,0 +1,41 @@
+import torch
+from diffusers import AutoencoderKL, StableDiffusionPipeline, StableDiffusionXLPipeline
+from rainbowneko.ckpt_manager.format import CkptFormat
+
+from hcpdiff.diffusion.sampler import DDPMSampler, DDPMDiscreteSigmaScheduler
+from hcpdiff.models.compose import SDXLTextEncoder, SDXLTokenizer
+
+class OfficialSD15Format(CkptFormat):
+    # Single file format
+    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
+                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+        pipe_args = dict(unet=denoiser, vae=vae, text_encoder=TE, tokenizer=tokenizer)
+        pipe_args = {k:v for k,v in pipe_args.items() if v is not None}
+        pipe = StableDiffusionPipeline.from_single_file(
+            pretrained_model, revision=revision, torch_dtype=dtype, **pipe_args
+        )
+        noise_sampler = noise_sampler or DDPMSampler(DDPMDiscreteSigmaScheduler())
+        return dict(denoiser=pipe.unet, TE=pipe.text_encoder, vae=pipe.vae, noise_sampler=noise_sampler, tokenizer=pipe.tokenizer)
+
+class OfficialSDXLFormat(CkptFormat):
+    # Single file format
+    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
+                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+        pipe_args = dict(unet=denoiser, vae=vae)
+        if TE is not None:
+            pipe_args['text_encoder'] = TE.clip_L
+            pipe_args['text_encoder_2'] = TE.clip_bigG
+        if tokenizer is not None:
+            pipe_args['tokenizer'] = tokenizer.clip_L
+            pipe_args['tokenizer_2'] = tokenizer.clip_bigG
+
+        pipe_args = {k:v for k,v in pipe_args.items() if v is not None}
+        pipe = StableDiffusionXLPipeline.from_single_file(
+            pretrained_model, revision=revision, torch_dtype=dtype, **pipe_args
+        )
+
+        noise_sampler = noise_sampler or DDPMSampler(DDPMDiscreteSigmaScheduler())
+        TE = SDXLTextEncoder([('clip_L', pipe.text_encoder), ('clip_bigG', pipe.text_encoder_2)])
+        tokenizer = SDXLTokenizer([('clip_L', pipe.tokenizer), ('clip_bigG', pipe.tokenizer_2)])
+
+        return dict(denoiser=pipe.unet, TE=TE, vae=pipe.vae, noise_sampler=noise_sampler, tokenizer=tokenizer)
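Both "official" formats wrap diffusers' from_single_file, so a single .safetensors/.ckpt checkpoint yields the same component dict as the diffusers-layout formats above. A sketch; the checkpoint path is a placeholder:

    import torch
    from hcpdiff.ckpt_manager.format import OfficialSD15Format

    parts = OfficialSD15Format().load_ckpt('v1-5-pruned-emaonly.safetensors', dtype=torch.float16)
    unet, tokenizer = parts['denoiser'], parts['tokenizer']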
hcpdiff/ckpt_manager/loader.py ADDED
@@ -0,0 +1,64 @@
+from hcpdiff.models.lora_layers_patch import LoraLayer
+from torch import nn
+from hcpdiff.utils.net_utils import split_module_name
+from rainbowneko.ckpt_manager import NekoPluginLoader, LocalCkptSource, CkptFormat
+from rainbowneko.ckpt_manager.locator import get_match_layers
+from rainbowneko.models.plugin import PluginGroup
+
+def get_lora_rank_and_cls(lora_state):
+    if 'layer.W_down' in lora_state:
+        rank = lora_state['layer.W_down'].shape[0]
+        return LoraLayer, rank
+    else:
+        raise ValueError('Unknown lora format.')
+
+class HCPLoraLoader(NekoPluginLoader):
+    def __init__(self, format: CkptFormat=None, source: LocalCkptSource=None, path: str = None, layers='all', target_plugin=None,
+                 state_prefix=None, base_model_alpha=0.0, load_ema=False, module_to_load='', **plugin_kwargs):
+        super().__init__(format, source, path=path, layers=layers, target_plugin=target_plugin, state_prefix=state_prefix,
+                         base_model_alpha=base_model_alpha, load_ema=load_ema, **plugin_kwargs)
+        self.module_to_load = module_to_load
+
+    def load_to(self, name, model):
+        # get model to load plugin and its named_modules
+        model = model if self.module_to_load == '' else eval(f"model.{self.module_to_load}")
+
+        named_modules = {k:v for k, v in model.named_modules()}
+        plugin_state = self.load(self.path, map_location='cpu')['base_ema' if self.load_ema else 'base']
+
+        # filter layers to load
+        if self.layers != 'all':
+            match_blocks = get_match_layers(self.layers, named_modules)
+            plugin_state = {k: v for blk in match_blocks for k, v in plugin_state.items() if k.startswith(blk)}
+
+        if self.state_prefix:
+            state_prefix_len = len(self.state_prefix)
+            plugin_state = {k[state_prefix_len:]: v for k, v in plugin_state.items() if k.startswith(self.state_prefix)}
+
+        lora_block_state = {}
+        # get all layers in the lora_state
+        for pname, p in plugin_state.items():
+            # lora_block. is the old format
+            prefix, block_name = pname.split('.___.', 1)
+            if prefix not in lora_block_state:
+                lora_block_state[prefix] = {}
+            lora_block_state[prefix][block_name] = p
+
+        # add lora to host and load weights
+        lora_blocks = {}
+        for layer_name, lora_state in lora_block_state.items():
+            lora_layer_cls, rank = get_lora_rank_and_cls(lora_state)
+
+            if 'alpha' in lora_state:
+                lora_state['alpha'] *= self.plugin_kwargs.get('alpha', 1.0)
+
+            parent_name, host_name = split_module_name(layer_name)
+
+            lora_block = lora_layer_cls.wrap_layer(name, named_modules[layer_name], rank=rank, bias='layer.bias' in lora_state,
+                                                   parent_block=named_modules[parent_name], host_name=host_name)
+            lora_block.set_hyper_params(**self.plugin_kwargs)
+            lora_blocks[layer_name] = lora_block
+            load_info = lora_block.load_state_dict(lora_state, strict=False)
+            if len(load_info.unexpected_keys) > 0:
+                print(name, 'unexpected_keys', load_info.unexpected_keys)
+        return PluginGroup(lora_blocks)
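A hedged sketch of loading a trained HCP LoRA onto a wrapper model, assuming the constructor defaults shown above; the path, plugin name, and `alpha` hyper-parameter are illustrative, and `wrapper` stands in for an already-built model whose `denoiser` attribute hosts the LoRA layers:

    from hcpdiff.ckpt_manager import HCPLoraLoader

    loader = HCPLoraLoader(path='ckpts/my_lora.ckpt', module_to_load='denoiser', alpha=0.8)
    plugins = loader.load_to('my_lora', wrapper)  # returns a PluginGroup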
hcpdiff/data/__init__.py CHANGED
@@ -1,28 +1,4 @@
-from .pair_dataset import TextImagePairDataset
-from .cond_dataset import TextImageCondPairDataset
-from .crop_info_dataset import CropInfoPairDataset
-from .bucket import BaseBucket, FixedBucket, RatioBucket, SizeBucket, RatioSizeBucket, LongEdgeBucket
-from .utils import CycleData
-from .caption_loader import JsonCaptionLoader, TXTCaptionLoader
-from .sampler import DistributedCycleSampler, get_sampler
-
-class DataGroup:
-    def __init__(self, loader_list, loss_weights):
-        self.loader_list = loader_list
-        self.loss_weights = loss_weights
-
-    def __iter__(self):
-        self.data_iter_list = [iter(CycleData(loader)) for loader in self.loader_list]
-        return self
-
-    def __next__(self):
-        return [next(data_iter) for data_iter in self.data_iter_list]
-
-    def __len__(self):
-        return len(self.loader_list)
-
-    def get_dataset(self, idx):
-        return self.loader_list[idx].dataset
-
-    def get_loss_weights(self, idx):
-        return self.loss_weights[idx]
+from .dataset import TextImagePairDataset
+from .source import Text2ImageSource, Text2ImageLossMapSource, Text2ImageCondSource, T2IFolderClassSource
+from .handler import StableDiffusionHandler, LossMapHandler, DiffusionImageHandler
+from .cache import VaeCache
hcpdiff/data/cache/__init__.py ADDED
@@ -0,0 +1 @@
+from .vae import VaeCache
hcpdiff/data/cache/vae.py ADDED
@@ -0,0 +1,102 @@
+from io import BytesIO
+from pathlib import Path
+from typing import Dict, Any
+
+import lmdb
+import torch
+from hcpdiff.models.wrapper import SD15Wrapper
+from rainbowneko import _share
+from rainbowneko.data import DataCache, CacheableDataset
+from rainbowneko.utils import Path_Like
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm
+
+class VaeCache(DataCache):
+    def __init__(self, pre_build: Path_Like = None, lazy=False, bs=1):
+        super().__init__(pre_build)
+        self.lazy = lazy
+        self.bs = bs
+
+    def load_latent(self, id):
+        if self.lazy:
+            with self.env.begin() as txn:
+                byte_tensor = txn.get(str(id).encode())
+                return torch.load(BytesIO(byte_tensor))
+        else:
+            return self.cache[id]
+
+    def before_handler(self, index: int, data: Dict[str, Any]):
+        cached_data = self.load_latent(data['id'])
+        data['image'] = cached_data['latent']
+        data['coord'] = cached_data['coord']
+        return data
+
+    def on_finish(self, index, data):
+        return data
+
+    def load(self, path):
+        if self.lazy:
+            self.env = lmdb.open(path, subdir=False, readonly=True, lock=False, readahead=False, meminit=False)
+            return {}
+        elif len(self.cache)>0:
+            return self.cache
+        else:
+            env = lmdb.open(path, subdir=False, readonly=True, lock=False, readahead=False, meminit=False)
+            with env.begin() as txn:
+                cache = {k.decode():torch.load(BytesIO(v)) for k, v in txn.cursor()}
+            env.close()
+            return cache
+
+    def build(self, dataset: CacheableDataset, model: SD15Wrapper, all_gather):
+        if (self.pre_build and Path(self.pre_build).exists()) or len(self.cache)>0:
+            model.vae = None
+            return
+
+        vae = model.vae.to(_share.device)
+        with dataset.disable_cache():
+            dataset.bucket.rest(0)
+
+            loader = DataLoader(
+                dataset,
+                batch_size=self.bs,
+                num_workers=0,
+                sampler=DistributedSampler(dataset, num_replicas=_share.world_size, rank=_share.local_rank, shuffle=False),
+                collate_fn=dataset.collate_fn,
+                drop_last=False,
+            )
+
+            if self.pre_build:
+                Path(self.pre_build).parent.mkdir(parents=True, exist_ok=True)
+                env = lmdb.open(self.pre_build, map_size=1099511627776)
+                with env.begin(write=True) as txn:
+                    for data in tqdm(loader):
+                        image = data['image'].to(device=_share.device, dtype=vae.dtype)
+                        latents = model.vae.encode(image).latent_dist.sample()
+                        latents = (latents*vae.config.scaling_factor).cpu()
+
+                        for img_id, latent, coord in zip(data['id'], latents, data['coord']):
+                            data_cache = {'latent': latent, 'coord': coord}
+
+                            byte_stream = BytesIO()
+                            torch.save(data_cache, byte_stream)
+                            txn.put(str(img_id).encode(), byte_stream.getvalue())
+                            if not self.lazy:
+                                self.cache[img_id] = data_cache
+                env.close()
+            else:
+                for data in tqdm(loader):
+                    img_id = data['id']
+                    image = data['image'].to(device=_share.device, dtype=vae.dtype)
+                    latents = model.vae.encode(image).latent_dist.sample()
+                    latents = (latents*vae.config.scaling_factor).cpu()
+                    for img_id, latent, coord in zip(data['id'], latents, data['coord']):
+                        self.cache[img_id] = {'latent': latent, 'coord': coord}
+
+        model.vae.to('cpu')
+        #model.vae = None
+        torch.cuda.empty_cache()
+
+        cache_all = all_gather(self.cache)
+        for cache in cache_all:
+            self.cache.update(cache)
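VaeCache stores VAE latents in an LMDB database: with lazy=True, load() keeps the environment open and decodes one record per sample, while lazy=False reads the whole database into memory up front. A construction sketch; the path is illustrative:

    from hcpdiff.data import VaeCache

    cache = VaeCache(pre_build='cache/latents.lmdb', lazy=True, bs=4)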
hcpdiff/data/dataset.py ADDED
@@ -0,0 +1,20 @@
+"""
+pair_dataset.py
+====================
+:Name: text-image pair dataset
+:Author: Dong Ziyi
+:Affiliation: HCP Lab, SYSU
+:Created: 10/03/2023
+:Licence: Apache-2.0
+"""
+
+from typing import Union, Dict
+
+from rainbowneko.data import CacheableDataset, BaseDataset, BaseBucket, DataSource, DataHandler, DataCache
+
+def TextImagePairDataset(bucket: BaseBucket = None, source: Dict[str, DataSource] = None, handler: DataHandler = None,
+                         batch_handler: DataHandler = None, cache: DataCache = None, **kwargs) -> Union[CacheableDataset, BaseDataset]:
+    if cache is None:
+        return BaseDataset(bucket=bucket, source=source, handler=handler, batch_handler=batch_handler, **kwargs)
+    else:
+        return CacheableDataset(bucket=bucket, source=source, handler=handler, batch_handler=batch_handler, cache=cache, **kwargs)
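TextImagePairDataset is now a factory rather than a class: passing a cache selects CacheableDataset, otherwise BaseDataset. A sketch, with `bucket`, `sources`, and `handler` standing in for whatever rainbowneko objects a config builds:

    from hcpdiff.data import TextImagePairDataset, VaeCache

    cached = TextImagePairDataset(bucket=bucket, source=sources, handler=handler,
                                  cache=VaeCache(bs=4))  # CacheableDataset
    plain = TextImagePairDataset(bucket=bucket, source=sources, handler=handler)  # BaseDataset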
hcpdiff/data/handler/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .diffusion import StableDiffusionHandler, DiffusionImageHandler, LossMapHandler
+from .text import TokenizeHandler, TagEraseHandler, TagDropoutHandler, TagShuffleHandler, TemplateFillHandler
+from .controlnet import ControlNetHandler
hcpdiff/data/handler/controlnet.py ADDED
@@ -0,0 +1,18 @@
+import torchvision.transforms as T
+from PIL import Image
+from rainbowneko.data import DataHandler, HandlerChain, LoadImageHandler, ImageHandler
+
+class ControlNetHandler(DataHandler):
+    def __init__(self, key_map_in=('cond -> image',), key_map_out=('image -> cond',), bucket=None):
+        super().__init__(key_map_in, key_map_out)
+
+        self.handlers = HandlerChain(
+            load=LoadImageHandler(),
+            bucket=bucket.handler if bucket else DataHandler(),
+            image=ImageHandler(
+                transform=T.ToTensor(),
+            )
+        )
+
+    def handle(self, image: Image.Image):
+        return self.handlers(dict(image=image))