hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +114 -125
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +78 -46
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
@@ -1,54 +0,0 @@
|
|
1
|
-
from .base import CkptManagerBase
|
2
|
-
import os
|
3
|
-
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
|
4
|
-
from hcpdiff.models.plugin import BasePluginBlock
|
5
|
-
from hcpdiff.tools.sd2diffusers import load_sd_ckpt, patch_method
|
6
|
-
|
7
|
-
class CkptManagerWebui(CkptManagerBase):
|
8
|
-
|
9
|
-
def set_save_dir(self, save_dir, emb_dir=None):
|
10
|
-
os.makedirs(save_dir, exist_ok=True)
|
11
|
-
self.save_dir = save_dir
|
12
|
-
self.emb_dir = emb_dir
|
13
|
-
|
14
|
-
def save(self, step, unet, TE, lora_unet, lora_TE, all_plugin_unet, all_plugin_TE, embs, pipe: StableDiffusionPipeline, **kwargs):
|
15
|
-
def state_dict_unet(*args, model=unet, **kwargs):
|
16
|
-
plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
|
17
|
-
model_sd = {}
|
18
|
-
for k, v in model.state_dict_().items():
|
19
|
-
for name in plugin_names:
|
20
|
-
if k.startswith(name):
|
21
|
-
break
|
22
|
-
else:
|
23
|
-
model_sd[k] = v
|
24
|
-
return model_sd
|
25
|
-
unet.state_dict_ = unet.state_dict
|
26
|
-
unet.state_dict = state_dict_unet
|
27
|
-
|
28
|
-
def state_dict_TE(*args, model=TE, **kwargs):
|
29
|
-
plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
|
30
|
-
model_sd = {}
|
31
|
-
for k, v in model.state_dict_().items():
|
32
|
-
for name in plugin_names:
|
33
|
-
if k.startswith(name):
|
34
|
-
break
|
35
|
-
else:
|
36
|
-
model_sd[k] = v
|
37
|
-
return model_sd
|
38
|
-
TE.state_dict_ = TE.state_dict
|
39
|
-
TE.state_dict = state_dict_TE
|
40
|
-
|
41
|
-
pipe.save_pretrained(os.path.join(self.save_dir, f"model-{step}"), **kwargs)
|
42
|
-
|
43
|
-
@classmethod
|
44
|
-
def load(cls, pretrained_model, original_config_file, from_safetensors=False, device='cpu', ema=True, **kwargs) -> StableDiffusionPipeline:
|
45
|
-
patch_method()
|
46
|
-
pipe = load_sd_ckpt(
|
47
|
-
checkpoint_path=pretrained_model,
|
48
|
-
original_config_file=original_config_file,
|
49
|
-
extract_ema=ema,
|
50
|
-
scheduler_type='pndm',
|
51
|
-
from_safetensors=from_safetensors,
|
52
|
-
device=device,
|
53
|
-
)
|
54
|
-
return pipe
|
hcpdiff/data/bucket.py
DELETED
@@ -1,358 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
bucket.py
|
3
|
-
====================
|
4
|
-
:Name: aspect ratio bucket with k-means
|
5
|
-
:Author: Dong Ziyi
|
6
|
-
:Affiliation: HCP Lab, SYSU
|
7
|
-
:Created: 10/03/2023
|
8
|
-
:Licence: Apache-2.0
|
9
|
-
"""
|
10
|
-
|
11
|
-
import math
|
12
|
-
import os.path
|
13
|
-
import pickle
|
14
|
-
from typing import List, Tuple, Union, Any
|
15
|
-
|
16
|
-
import cv2
|
17
|
-
import numpy as np
|
18
|
-
from hcpdiff.utils.img_size_tool import types_support, get_image_size
|
19
|
-
from hcpdiff.utils.utils import get_file_ext
|
20
|
-
from .source import DataSource
|
21
|
-
from loguru import logger
|
22
|
-
from sklearn.cluster import KMeans
|
23
|
-
from tqdm import tqdm
|
24
|
-
from concurrent.futures import ThreadPoolExecutor
|
25
|
-
|
26
|
-
from .utils import resize_crop_fix, pad_crop_fix
|
27
|
-
|
28
|
-
class BaseBucket:
|
29
|
-
def __getitem__(self, idx):
|
30
|
-
'''
|
31
|
-
:return: (file name of image), (target image size)
|
32
|
-
'''
|
33
|
-
raise NotImplementedError()
|
34
|
-
|
35
|
-
def __len__(self):
|
36
|
-
raise NotImplementedError()
|
37
|
-
|
38
|
-
def build(self, bs: int, img_root_list: List[str]):
|
39
|
-
raise NotImplementedError()
|
40
|
-
|
41
|
-
def rest(self, epoch):
|
42
|
-
pass
|
43
|
-
|
44
|
-
def crop_resize(self, image, size, mask_interp=cv2.INTER_CUBIC) -> Tuple[Any, Tuple]:
|
45
|
-
return image, (*size, 0, 0, *size)
|
46
|
-
|
47
|
-
class FixedBucket(BaseBucket):
|
48
|
-
def __init__(self, target_size: Union[Tuple[int, int], int] = 512, **kwargs):
|
49
|
-
self.target_size = (target_size, target_size) if isinstance(target_size, int) else target_size
|
50
|
-
|
51
|
-
def build(self, bs: int, file_names: List[Tuple[str, DataSource]]):
|
52
|
-
self.file_names = file_names
|
53
|
-
|
54
|
-
def crop_resize(self, image, size, mask_interp=cv2.INTER_CUBIC):
|
55
|
-
return resize_crop_fix(image, size, mask_interp=mask_interp)
|
56
|
-
|
57
|
-
def __getitem__(self, idx) -> Tuple[Tuple[str, DataSource], Tuple[int, int]]:
|
58
|
-
return self.file_names[idx], self.target_size
|
59
|
-
|
60
|
-
def __len__(self):
|
61
|
-
return len(self.file_names)
|
62
|
-
|
63
|
-
class RatioBucket(BaseBucket):
|
64
|
-
def __init__(self, target_area: int = 640*640, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None):
|
65
|
-
self.target_area = target_area
|
66
|
-
self.step_size = step_size
|
67
|
-
self.num_bucket = num_bucket
|
68
|
-
self.pre_build_bucket = pre_build_bucket
|
69
|
-
|
70
|
-
def load_bucket(self, path):
|
71
|
-
with open(path, 'rb') as f:
|
72
|
-
data = pickle.load(f)
|
73
|
-
self.buckets = data['buckets']
|
74
|
-
self.size_buckets = data['size_buckets']
|
75
|
-
self.idx_bucket_map = data['idx_bucket_map']
|
76
|
-
self.data_len = data['data_len']
|
77
|
-
|
78
|
-
def save_bucket(self, path):
|
79
|
-
with open(path, 'wb') as f:
|
80
|
-
pickle.dump({
|
81
|
-
'buckets':self.buckets,
|
82
|
-
'size_buckets':self.size_buckets,
|
83
|
-
'idx_bucket_map':self.idx_bucket_map,
|
84
|
-
'data_len':self.data_len,
|
85
|
-
}, f)
|
86
|
-
|
87
|
-
def build_buckets_from_ratios(self):
|
88
|
-
logger.info('build buckets from ratios')
|
89
|
-
size_low = int(math.sqrt(self.target_area/self.ratio_max))
|
90
|
-
size_high = int(self.ratio_max*size_low)
|
91
|
-
|
92
|
-
# SD需要边长是8的倍数
|
93
|
-
size_low = (size_low//self.step_size)*self.step_size
|
94
|
-
size_high = (size_high//self.step_size)*self.step_size
|
95
|
-
|
96
|
-
data = []
|
97
|
-
for w in range(size_low, size_high+1, self.step_size):
|
98
|
-
for h in range(size_low, size_high+1, self.step_size):
|
99
|
-
data.append([w*h, np.log2(w/h), w, h]) # 对比例取对数,更符合人感知,宽高相反的可以对称分布。
|
100
|
-
data = np.array(data)
|
101
|
-
|
102
|
-
error_area = np.abs(data[:, 0]-self.target_area)
|
103
|
-
data_use = data[np.argsort(error_area)[:self.num_bucket*3], :] # 取最小的num_bucket*3个
|
104
|
-
|
105
|
-
# 聚类,选出指定个数的bucket
|
106
|
-
kmeans = KMeans(n_clusters=self.num_bucket, random_state=3407).fit(data_use[:, 1].reshape(-1, 1))
|
107
|
-
labels = kmeans.labels_
|
108
|
-
self.buckets = [] # [bucket_id:[file_idx,...]]
|
109
|
-
ratios_log = []
|
110
|
-
self.size_buckets = []
|
111
|
-
for i in range(self.num_bucket):
|
112
|
-
map_idx = np.where(labels == i)[0]
|
113
|
-
m_idx = map_idx[np.argmin(np.abs(data_use[labels == i, 1]-np.median(data_use[labels == i, 1])))]
|
114
|
-
# self.buckets[wh_hash(*data_use[m_idx, 2:])]=[]
|
115
|
-
self.buckets.append([])
|
116
|
-
ratios_log.append(data_use[m_idx, 1])
|
117
|
-
self.size_buckets.append(data_use[m_idx, 2:].astype(int))
|
118
|
-
ratios_log = np.array(ratios_log)
|
119
|
-
self.size_buckets = np.array(self.size_buckets)
|
120
|
-
|
121
|
-
# fill buckets with images w,h
|
122
|
-
self.idx_bucket_map = np.empty(len(self.file_names), dtype=int)
|
123
|
-
for i, (file, source) in enumerate(self.file_names):
|
124
|
-
w, h = get_image_size(file)
|
125
|
-
bucket_id = np.abs(ratios_log-np.log2(w/h)).argmin()
|
126
|
-
self.buckets[bucket_id].append(i)
|
127
|
-
self.idx_bucket_map[i] = bucket_id
|
128
|
-
logger.info('buckets info: '+', '.join(f'size:{self.size_buckets[i]}, num:{len(b)}' for i, b in enumerate(self.buckets)))
|
129
|
-
|
130
|
-
def build_buckets_from_images(self):
|
131
|
-
logger.info('build buckets from images')
|
132
|
-
|
133
|
-
def get_ratio(data):
|
134
|
-
file, source = data
|
135
|
-
w, h = get_image_size(file)
|
136
|
-
ratio = np.log2(w/h)
|
137
|
-
return ratio
|
138
|
-
|
139
|
-
ratio_list = []
|
140
|
-
with ThreadPoolExecutor() as executor:
|
141
|
-
for ratio in tqdm(executor.map(get_ratio, self.file_names), desc='get image info', total=len(self.file_names)):
|
142
|
-
ratio_list.append(ratio)
|
143
|
-
ratio_list = np.array(ratio_list)
|
144
|
-
|
145
|
-
# 聚类,选出指定个数的bucket
|
146
|
-
kmeans = KMeans(n_clusters=self.num_bucket, random_state=3407, verbose=True, tol=1e-3).fit(ratio_list.reshape(-1, 1))
|
147
|
-
labels = kmeans.labels_
|
148
|
-
ratios = 2**kmeans.cluster_centers_.reshape(-1)
|
149
|
-
|
150
|
-
h_all = np.sqrt(self.target_area/ratios)
|
151
|
-
w_all = h_all*ratios
|
152
|
-
|
153
|
-
# SD需要边长是8的倍数
|
154
|
-
h_all = (np.round(h_all/self.step_size)*self.step_size).astype(int)
|
155
|
-
w_all = (np.round(w_all/self.step_size)*self.step_size).astype(int)
|
156
|
-
self.size_buckets = list(zip(w_all, h_all))
|
157
|
-
self.size_buckets = np.array(self.size_buckets)
|
158
|
-
|
159
|
-
self.buckets = [] # [bucket_id:[file_idx,...]]
|
160
|
-
self.idx_bucket_map = np.empty(len(self.file_names), dtype=int)
|
161
|
-
for bidx in range(self.num_bucket):
|
162
|
-
bnow = labels == bidx
|
163
|
-
self.buckets.append(np.where(bnow)[0].tolist())
|
164
|
-
self.idx_bucket_map[bnow] = bidx
|
165
|
-
logger.info('buckets info: '+', '.join(f'size:{self.size_buckets[i]}, num:{len(b)}' for i, b in enumerate(self.buckets)))
|
166
|
-
|
167
|
-
def build(self, bs: int, file_names: List[Tuple[str, DataSource]]):
|
168
|
-
'''
|
169
|
-
:param bs: batch_size * n_gpus * accumulation_step
|
170
|
-
:param img_root_list:
|
171
|
-
'''
|
172
|
-
self.file_names = file_names
|
173
|
-
self.bs = bs
|
174
|
-
if self.pre_build_bucket and os.path.exists(self.pre_build_bucket):
|
175
|
-
self.load_bucket(self.pre_build_bucket)
|
176
|
-
return
|
177
|
-
|
178
|
-
self._build()
|
179
|
-
|
180
|
-
rs = np.random.RandomState(42)
|
181
|
-
# make len(bucket)%bs==0
|
182
|
-
self.data_len = 0
|
183
|
-
for bidx, bucket in enumerate(self.buckets):
|
184
|
-
rest = len(bucket)%bs
|
185
|
-
if rest>0:
|
186
|
-
bucket.extend(rs.choice(bucket, bs-rest))
|
187
|
-
self.data_len += len(bucket)
|
188
|
-
self.buckets[bidx] = np.array(bucket)
|
189
|
-
|
190
|
-
if self.pre_build_bucket:
|
191
|
-
self.save_bucket(self.pre_build_bucket)
|
192
|
-
|
193
|
-
def rest(self, epoch):
|
194
|
-
rs = np.random.RandomState(42+epoch)
|
195
|
-
bucket_list = [x.copy() for x in self.buckets]
|
196
|
-
# shuffle inter bucket
|
197
|
-
for x in bucket_list:
|
198
|
-
rs.shuffle(x)
|
199
|
-
|
200
|
-
# shuffle of batches
|
201
|
-
bucket_list = np.hstack(bucket_list).reshape(-1, self.bs).astype(int)
|
202
|
-
rs.shuffle(bucket_list)
|
203
|
-
|
204
|
-
self.idx_bucket = bucket_list.reshape(-1)
|
205
|
-
|
206
|
-
def crop_resize(self, image, size, mask_interp=cv2.INTER_CUBIC):
|
207
|
-
return resize_crop_fix(image, size, mask_interp=mask_interp)
|
208
|
-
|
209
|
-
def __getitem__(self, idx):
|
210
|
-
file_idx = self.idx_bucket[idx]
|
211
|
-
bucket_idx = self.idx_bucket_map[file_idx]
|
212
|
-
return self.file_names[file_idx], self.size_buckets[bucket_idx]
|
213
|
-
|
214
|
-
def __len__(self):
|
215
|
-
return self.data_len
|
216
|
-
|
217
|
-
@classmethod
|
218
|
-
def from_ratios(cls, target_area: int = 640*640, step_size: int = 8, num_bucket: int = 10, ratio_max: float = 4,
|
219
|
-
pre_build_bucket: str = None, **kwargs):
|
220
|
-
arb = cls(target_area, step_size, num_bucket, pre_build_bucket=pre_build_bucket)
|
221
|
-
arb.ratio_max = ratio_max
|
222
|
-
arb._build = arb.build_buckets_from_ratios
|
223
|
-
return arb
|
224
|
-
|
225
|
-
@classmethod
|
226
|
-
def from_files(cls, target_area: int = 640*640, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None, **kwargs):
|
227
|
-
arb = cls(target_area, step_size, num_bucket, pre_build_bucket=pre_build_bucket)
|
228
|
-
arb._build = arb.build_buckets_from_images
|
229
|
-
return arb
|
230
|
-
|
231
|
-
class SizeBucket(RatioBucket):
|
232
|
-
def __init__(self, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None):
|
233
|
-
super().__init__(step_size=step_size, num_bucket=num_bucket, pre_build_bucket=pre_build_bucket)
|
234
|
-
|
235
|
-
def build_buckets_from_images(self):
|
236
|
-
'''
|
237
|
-
根据图像尺寸聚类,不会resize图像,只有剪裁和填充操作。
|
238
|
-
'''
|
239
|
-
logger.info('build buckets from images size')
|
240
|
-
size_list = []
|
241
|
-
for i, (file, source) in enumerate(self.file_names):
|
242
|
-
w, h = get_image_size(file)
|
243
|
-
size_list.append([w, h])
|
244
|
-
size_list = np.array(size_list)
|
245
|
-
|
246
|
-
# 聚类,选出指定个数的bucket
|
247
|
-
kmeans = KMeans(n_clusters=self.num_bucket, random_state=3407).fit(size_list)
|
248
|
-
labels = kmeans.labels_
|
249
|
-
size_buckets = kmeans.cluster_centers_
|
250
|
-
|
251
|
-
# SD需要边长是8的倍数
|
252
|
-
self.size_buckets = (np.round(size_buckets/self.step_size)*self.step_size).astype(int)
|
253
|
-
|
254
|
-
self.buckets = [] # [bucket_id:[file_idx,...]]
|
255
|
-
self.idx_bucket_map = np.empty(len(self.file_names), dtype=int)
|
256
|
-
for bidx in range(self.num_bucket):
|
257
|
-
bnow = labels == bidx
|
258
|
-
self.buckets.append(np.where(bnow)[0].tolist())
|
259
|
-
self.idx_bucket_map[bnow] = bidx
|
260
|
-
logger.info('buckets info: '+', '.join(f'size:{self.size_buckets[i]}, num:{len(b)}' for i, b in enumerate(self.buckets)))
|
261
|
-
|
262
|
-
def crop_resize(self, image, size):
|
263
|
-
return pad_crop_fix(image, size)
|
264
|
-
|
265
|
-
@classmethod
|
266
|
-
def from_files(cls, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None, **kwargs):
|
267
|
-
arb = cls(step_size, num_bucket, pre_build_bucket=pre_build_bucket)
|
268
|
-
arb._build = arb.build_buckets_from_images
|
269
|
-
return arb
|
270
|
-
|
271
|
-
class RatioSizeBucket(RatioBucket):
|
272
|
-
def __init__(self, step_size: int = 8, num_bucket: int = 10, max_area:int=640*640, pre_build_bucket: str = None):
|
273
|
-
super().__init__(step_size=step_size, num_bucket=num_bucket, pre_build_bucket=pre_build_bucket)
|
274
|
-
self.max_area = max_area
|
275
|
-
|
276
|
-
def build_buckets_from_images(self):
|
277
|
-
'''
|
278
|
-
根据图像尺寸聚类,不会resize图像,只有剪裁和填充操作。
|
279
|
-
'''
|
280
|
-
logger.info('build buckets from images')
|
281
|
-
ratio_list = []
|
282
|
-
for i, (file, source) in enumerate(self.file_names):
|
283
|
-
w, h = get_image_size(file)
|
284
|
-
ratio = np.log2(w/h)
|
285
|
-
log_area = np.log2(min(w*h, self.max_area))
|
286
|
-
ratio_list.append([ratio, log_area])
|
287
|
-
ratio_list = np.array(ratio_list)
|
288
|
-
|
289
|
-
# 聚类,选出指定个数的bucket
|
290
|
-
kmeans = KMeans(n_clusters=self.num_bucket, random_state=3407).fit(ratio_list)
|
291
|
-
labels = kmeans.labels_
|
292
|
-
ratios = 2**kmeans.cluster_centers_[:, 0]
|
293
|
-
sizes = 2**kmeans.cluster_centers_[:, 1]
|
294
|
-
|
295
|
-
h_all = np.sqrt(sizes/ratios)
|
296
|
-
w_all = h_all*ratios
|
297
|
-
|
298
|
-
# SD需要边长是8的倍数
|
299
|
-
h_all = (np.round(h_all/self.step_size)*self.step_size).astype(int)
|
300
|
-
w_all = (np.round(w_all/self.step_size)*self.step_size).astype(int)
|
301
|
-
self.size_buckets = list(zip(w_all, h_all))
|
302
|
-
self.size_buckets = np.array(self.size_buckets)
|
303
|
-
|
304
|
-
self.buckets = [] # [bucket_id:[file_idx,...]]
|
305
|
-
self.idx_bucket_map = np.empty(len(self.file_names), dtype=int)
|
306
|
-
for bidx in range(self.num_bucket):
|
307
|
-
bnow = labels == bidx
|
308
|
-
self.buckets.append(np.where(bnow)[0].tolist())
|
309
|
-
self.idx_bucket_map[bnow] = bidx
|
310
|
-
logger.info('buckets info: '+', '.join(f'size:{self.size_buckets[i]}, num:{len(b)}' for i, b in enumerate(self.buckets)))
|
311
|
-
|
312
|
-
@classmethod
|
313
|
-
def from_files(cls, step_size: int = 8, num_bucket: int = 10, max_area:int=640*640, pre_build_bucket: str = None, **kwargs):
|
314
|
-
arb = cls(step_size, num_bucket, max_area=max_area, pre_build_bucket=pre_build_bucket)
|
315
|
-
arb._build = arb.build_buckets_from_images
|
316
|
-
return arb
|
317
|
-
|
318
|
-
class LongEdgeBucket(RatioBucket):
|
319
|
-
def __init__(self, target_edge=640, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None):
|
320
|
-
super().__init__(step_size=step_size, num_bucket=num_bucket, pre_build_bucket=pre_build_bucket)
|
321
|
-
self.target_edge = target_edge
|
322
|
-
|
323
|
-
def build_buckets_from_images(self):
|
324
|
-
'''
|
325
|
-
根据图像尺寸聚类,不会resize图像,只有剪裁和填充操作。
|
326
|
-
'''
|
327
|
-
logger.info('build buckets from images size')
|
328
|
-
size_list = []
|
329
|
-
for i, (file, source) in enumerate(self.file_names):
|
330
|
-
w, h = get_image_size(file)
|
331
|
-
scale = self.target_edge/max(w, h)
|
332
|
-
size_list.append([round(w*scale), round(h*scale)])
|
333
|
-
size_list = np.array(size_list)
|
334
|
-
|
335
|
-
# 聚类,选出指定个数的bucket
|
336
|
-
kmeans = KMeans(n_clusters=self.num_bucket, random_state=3407, verbose=True).fit(size_list)
|
337
|
-
labels = kmeans.labels_
|
338
|
-
size_buckets = kmeans.cluster_centers_
|
339
|
-
|
340
|
-
# SD需要边长是8的倍数
|
341
|
-
self.size_buckets = (np.round(size_buckets/self.step_size)*self.step_size).astype(int)
|
342
|
-
|
343
|
-
self.buckets = [] # [bucket_id:[file_idx,...]]
|
344
|
-
self.idx_bucket_map = np.empty(len(self.file_names), dtype=int)
|
345
|
-
for bidx in range(self.num_bucket):
|
346
|
-
bnow = labels == bidx
|
347
|
-
self.buckets.append(np.where(bnow)[0].tolist())
|
348
|
-
self.idx_bucket_map[bnow] = bidx
|
349
|
-
logger.info('buckets info: '+', '.join(f'size:{self.size_buckets[i]}, num:{len(b)}' for i, b in enumerate(self.buckets)))
|
350
|
-
|
351
|
-
def crop_resize(self, image, size):
|
352
|
-
return resize_crop_fix(image, size)
|
353
|
-
|
354
|
-
@classmethod
|
355
|
-
def from_files(cls, target_edge, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None, **kwargs):
|
356
|
-
arb = cls(target_edge, step_size, num_bucket, pre_build_bucket=pre_build_bucket)
|
357
|
-
arb._build = arb.build_buckets_from_images
|
358
|
-
return arb
|
hcpdiff/data/caption_loader.py
DELETED
@@ -1,80 +0,0 @@
|
|
1
|
-
import json
|
2
|
-
import os
|
3
|
-
import glob
|
4
|
-
import yaml
|
5
|
-
from typing import Dict
|
6
|
-
|
7
|
-
from loguru import logger
|
8
|
-
from hcpdiff.utils.img_size_tool import types_support
|
9
|
-
import os
|
10
|
-
|
11
|
-
class BaseCaptionLoader:
|
12
|
-
def __init__(self, path):
|
13
|
-
self.path = path
|
14
|
-
|
15
|
-
def _load(self):
|
16
|
-
raise NotImplementedError
|
17
|
-
|
18
|
-
def load(self):
|
19
|
-
retval = self._load()
|
20
|
-
logger.info(f'{len(retval)} record(s) loaded with {self.__class__.__name__}, from path {self.path!r}')
|
21
|
-
return retval
|
22
|
-
|
23
|
-
@staticmethod
|
24
|
-
def clean_ext(captions:Dict[str, str]):
|
25
|
-
def rm_ext(path):
|
26
|
-
name, ext = os.path.splitext(path)
|
27
|
-
if len(ext)>0 and ext[1:] in types_support:
|
28
|
-
return name
|
29
|
-
return path
|
30
|
-
return {rm_ext(k):v for k,v in captions.items()}
|
31
|
-
|
32
|
-
class JsonCaptionLoader(BaseCaptionLoader):
|
33
|
-
def _load(self):
|
34
|
-
with open(self.path, 'r', encoding='utf-8') as f:
|
35
|
-
return self.clean_ext(json.loads(f.read()))
|
36
|
-
|
37
|
-
class YamlCaptionLoader(BaseCaptionLoader):
|
38
|
-
def _load(self):
|
39
|
-
with open(self.path, 'r', encoding='utf-8') as f:
|
40
|
-
return self.clean_ext(yaml.load(f.read(), Loader=yaml.FullLoader))
|
41
|
-
|
42
|
-
class TXTCaptionLoader(BaseCaptionLoader):
|
43
|
-
def _load(self):
|
44
|
-
txt_files = glob.glob(os.path.join(self.path, '*.txt'))
|
45
|
-
captions = {}
|
46
|
-
for file in txt_files:
|
47
|
-
with open(file, 'r', encoding='utf-8') as f:
|
48
|
-
captions[os.path.basename(file).split('.')[0]] = f.read().strip()
|
49
|
-
return captions
|
50
|
-
|
51
|
-
def auto_caption_loader(path):
|
52
|
-
if os.path.isdir(path):
|
53
|
-
json_files = glob.glob(os.path.join(path, '*.json'))
|
54
|
-
if json_files:
|
55
|
-
return JsonCaptionLoader(json_files[0])
|
56
|
-
|
57
|
-
yaml_files = [
|
58
|
-
*glob.glob(os.path.join(path, '*.yaml')),
|
59
|
-
*glob.glob(os.path.join(path, '*.yml')),
|
60
|
-
]
|
61
|
-
if yaml_files:
|
62
|
-
return YamlCaptionLoader(yaml_files[0])
|
63
|
-
|
64
|
-
txt_files = glob.glob(os.path.join(path, '*.txt'))
|
65
|
-
if txt_files:
|
66
|
-
return TXTCaptionLoader(path)
|
67
|
-
|
68
|
-
raise FileNotFoundError(f'Caption file not found in directory {path!r}.')
|
69
|
-
|
70
|
-
elif os.path.isfile(path):
|
71
|
-
_, ext = os.path.splitext(path)
|
72
|
-
if ext == '.json':
|
73
|
-
return JsonCaptionLoader(path)
|
74
|
-
elif ext in {'.yaml', '.yml'}:
|
75
|
-
return YamlCaptionLoader(path)
|
76
|
-
else:
|
77
|
-
raise FileNotFoundError(f'Unknown caption file {path!r}.')
|
78
|
-
|
79
|
-
else:
|
80
|
-
raise FileNotFoundError(f'Unknown caption file type {path!r}.')
|
hcpdiff/data/cond_dataset.py
DELETED
@@ -1,40 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
pair_dataset.py
|
3
|
-
====================
|
4
|
-
:Name: text-image pair dataset
|
5
|
-
:Author: Dong Ziyi
|
6
|
-
:Affiliation: HCP Lab, SYSU
|
7
|
-
:Created: 10/03/2023
|
8
|
-
:Licence: Apache-2.0
|
9
|
-
"""
|
10
|
-
|
11
|
-
import cv2
|
12
|
-
import torch
|
13
|
-
|
14
|
-
from .pair_dataset import TextImagePairDataset
|
15
|
-
|
16
|
-
class TextImageCondPairDataset(TextImagePairDataset):
    """
    Text-image pair dataset whose samples also carry a condition image.
    The condition image is cropped together with the main image and returned under ``plugin_input``.
    """

    def load_data(self, path, data_source, size):
        """Load one sample and crop/resize it to the bucket ``size``.

        ``size`` is indexed as ``size[0]`` / ``size[1]``; presumably (width, height) — confirm against the bucket.

        Raises:
            FileNotFoundError: if the data source provides no ``'cond'`` image.
        """
        image_dict = data_source.load_image(path)
        image = image_dict['image']
        att_mask = image_dict.get('att_mask', None)
        img_cond = image_dict.get('cond', None)
        if img_cond is None:
            raise FileNotFoundError(f'{self.__class__} need the condition images!')

        # Only pass the mask to the bucket when the source actually supplied one.
        has_mask = att_mask is not None
        crop_inputs = {"img":image}
        if has_mask:
            crop_inputs["mask"] = att_mask
        crop_inputs["cond"] = img_cond

        data, crop_coord = self.bucket.crop_resize(crop_inputs, size)
        image = data_source.procees_image(data['img'])  # resize to bucket size
        img_cond = data_source.cond_transform(data['cond'])

        mask_w, mask_h = size[0]//8, size[1]//8  # mask lives at 1/8 of the bucket size
        if has_mask:
            att_mask = torch.tensor(cv2.resize(data['mask'], (mask_w, mask_h), interpolation=cv2.INTER_LINEAR))
        else:
            att_mask = torch.ones((mask_h, mask_w))
        return {'img':image, 'mask':att_mask, 'plugin_input':{"cond":img_cond}}
|
@@ -1,40 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
pair_dataset.py
|
3
|
-
====================
|
4
|
-
:Name: text-image pair dataset
|
5
|
-
:Author: Dong Ziyi
|
6
|
-
:Affiliation: HCP Lab, SYSU
|
7
|
-
:Created: 10/03/2023
|
8
|
-
:Licence: Apache-2.0
|
9
|
-
"""
|
10
|
-
|
11
|
-
from typing import Callable, Iterable, Dict
|
12
|
-
from .bucket import BaseBucket
|
13
|
-
import os.path
|
14
|
-
|
15
|
-
import torch
|
16
|
-
import cv2
|
17
|
-
from .pair_dataset import TextImagePairDataset
|
18
|
-
from hcpdiff.utils.utils import get_file_name
|
19
|
-
from torchvision import transforms
|
20
|
-
|
21
|
-
class CropInfoPairDataset(TextImagePairDataset):
    """
    Text-image pair dataset that additionally returns the crop coordinates of each sample.
    The coordinates reported by the bucket are exposed as a float tensor under ``crop_info``.
    """

    def load_data(self, path, data_source, size):
        """Load one sample, crop/resize it to the bucket ``size`` and return image, mask and crop info.

        ``size`` is indexed as ``size[0]`` / ``size[1]``; presumably (width, height) — confirm against the bucket.
        """
        image_dict = data_source.load_image(path)
        image = image_dict['image']
        att_mask = image_dict.get('att_mask', None)

        # Only pass the mask to the bucket when the source actually supplied one.
        has_mask = att_mask is not None
        crop_inputs = {"img":image}
        if has_mask:
            crop_inputs["mask"] = att_mask

        data, crop_coord = self.bucket.crop_resize(crop_inputs, size)
        image = data_source.procees_image(data['img'])  # resize to bucket size

        mask_w, mask_h = size[0]//8, size[1]//8  # mask lives at 1/8 of the bucket size
        if has_mask:
            att_mask = torch.tensor(cv2.resize(data['mask'], (mask_w, mask_h), interpolation=cv2.INTER_LINEAR))
        else:
            att_mask = torch.ones((mask_h, mask_w))

        crop_info = torch.tensor(crop_coord, dtype=torch.float)  # for sdxl
        return {'img':image, 'mask':att_mask, 'crop_info':crop_info}
|
hcpdiff/data/data_processor.py
DELETED
@@ -1,33 +0,0 @@
|
|
1
|
-
import numpy as np
|
2
|
-
import torch
|
3
|
-
from PIL import Image
|
4
|
-
from diffusers.utils import PIL_INTERPOLATION
|
5
|
-
|
6
|
-
class ControlNetProcessor:
    """Load a conditioning image from disk and turn it into a batched tensor.

    Only the image path is stored; the file is opened each time the
    processor is called.
    """

    def __init__(self, image):
        # Path (or anything ``PIL.Image.open`` accepts) of the condition image.
        self.image_path = image

    def prepare_cond_image(self, image, width, height, batch_size, device):
        """Convert ``image`` to a tensor, repeat it per batch item and move it to ``device``.

        Args:
            image: A ``torch.Tensor``, a single PIL image, or a list of PIL
                images / tensors.
            width, height: Target size used when resizing PIL images.
                Tensor inputs are not resized.
            batch_size: Number of times each image is repeated along dim 0.
            device: Device the result is moved to.

        Returns:
            The prepared tensor on ``device``. PIL inputs are scaled to
            ``[0, 1]`` float32 and transposed with ``(0, 3, 1, 2)``.
        """
        if not isinstance(image, torch.Tensor):
            if isinstance(image, Image.Image):
                image = [image]

            if isinstance(image[0], Image.Image):
                image = [
                    np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image
                ]
                image = np.concatenate(image, axis=0)
                image = np.array(image).astype(np.float32)/255.0
                image = image.transpose(0, 3, 1, 2)
                image = torch.from_numpy(image)
            elif isinstance(image[0], torch.Tensor):
                image = torch.cat(image, dim=0)

        image = image.repeat_interleave(batch_size, dim=0)
        image = image.to(device=device)

        return image

    def __call__(self, width, height, batch_size, device, dtype):
        """Open the stored image, prepare it and cast it to ``dtype``."""
        # Close the file handle as soon as the RGB copy has been made.
        with Image.open(self.image_path) as img_file:
            img = img_file.convert('RGB')
        # Honour the caller's device instead of a hard-coded 'cuda', so the
        # processor also works on CPU-only hosts and non-default GPUs.
        return self.prepare_cond_image(img, width, height, batch_size, device).to(dtype=dtype)
|