hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/container.py +1 -1
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/embedding_convert.py +6 -2
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +19 -15
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +790 -0
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +46 -33
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +128 -136
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -68
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +84 -52
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -565
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/visualizer.py +0 -258
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.0.dist-info/METADATA +0 -199
- hcpdiff-0.9.0.dist-info/RECORD +0 -155
- hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
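The per-file counts above summarize the change set. As a rough way to reproduce the file-level comparison locally, here is a minimal sketch; the wheel file names and download commands are assumptions (download both wheels first, e.g. `pip download hcpdiff==0.9.0 --no-deps -d .` and `pip download hcpdiff==2.1 --no-deps -d .`):

# Minimal sketch: compare the member lists of the two wheels locally.
import zipfile

old = set(zipfile.ZipFile("hcpdiff-0.9.0-py3-none-any.whl").namelist())
new = set(zipfile.ZipFile("hcpdiff-2.1-py3-none-any.whl").namelist())

print("removed:", sorted(old - new))  # e.g. hcpdiff/data/pair_dataset.py, hcpdiff/loggers/...
print("added:  ", sorted(new - old))  # e.g. hcpdiff/trainer_ac.py, hcpdiff/workflow/daam/...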
hcpdiff/data/pair_dataset.py
DELETED
@@ -1,146 +0,0 @@
-"""
-pair_dataset.py
-====================
-:Name: text-image pair dataset
-:Author: Dong Ziyi
-:Affiliation: HCP Lab, SYSU
-:Created: 10/03/2023
-:Licence: Apache-2.0
-"""
-
-import os.path
-from argparse import Namespace
-
-import cv2
-import torch
-from PIL import Image
-from torch.utils.data import Dataset
-from tqdm.auto import tqdm
-from typing import Tuple
-
-from hcpdiff.utils.caption_tools import *
-from hcpdiff.utils.utils import get_file_name, get_file_ext
-from .bucket import BaseBucket
-from .source import DataSource, ComposeDataSource
-
-class TextImagePairDataset(Dataset):
-    """
-    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
-    It pre-processes the images and the tokenizes prompts.
-    """
-
-    def __init__(self, tokenizer, tokenizer_repeats: int = 1, att_mask_encode: bool = False,
-                 bucket: BaseBucket = None, source: Dict[str, DataSource] = None, return_path: bool = False,
-                 cache_path:str=None, encoder_attention_mask=False, **kwargs):
-        self.return_path = return_path
-
-        self.tokenizer = tokenizer
-        self.tokenizer_repeats = tokenizer_repeats
-        self.bucket: BaseBucket = bucket
-        self.att_mask_encode = att_mask_encode
-        self.source = ComposeDataSource(source)
-        self.latents = None # Cache latents for faster training. Works only without image argumentations.
-        self.cache_path = cache_path
-        self.encoder_attention_mask = encoder_attention_mask
-
-    def load_data(self, path:str, data_source:DataSource, size:Tuple[int]):
-        image_dict = data_source.load_image(path)
-        image = image_dict['image']
-        att_mask = image_dict.get('att_mask', None)
-        if att_mask is None:
-            data, crop_coord = self.bucket.crop_resize({"img":image}, size)
-            image = data_source.procees_image(data['img']) # resize to bucket size
-            att_mask = torch.ones((size[1]//8, size[0]//8))
-        else:
-            data, crop_coord = self.bucket.crop_resize({"img":image, "mask":att_mask}, size)
-            image = data_source.procees_image(data['img'])
-            att_mask = torch.tensor(cv2.resize(data['mask'], (size[0]//8, size[1]//8), interpolation=cv2.INTER_LINEAR))
-        return {'img':image, 'mask':att_mask}
-
-    @torch.no_grad()
-    def cache_latents(self, vae, weight_dtype, device, show_prog=True):
-        if self.cache_path and os.path.exists(self.cache_path):
-            self.latents = torch.load(self.cache_path)
-            return
-
-        self.latents = {}
-        self.bucket.rest(0)
-
-        for (path, data_source), size in tqdm(self.bucket, disable=not show_prog):
-            img_name = data_source.get_image_name(path)
-            if img_name not in self.latents:
-                data = self.load_data(path, data_source, size)
-                image = data['img'].unsqueeze(0).to(device, dtype=weight_dtype)
-                latents = vae.encode(image).latent_dist.sample().squeeze(0)
-                data['img'] = (latents*vae.config.scaling_factor).cpu()
-                self.latents[img_name] = data
-
-        if self.cache_path:
-            torch.save(self.latents, self.cache_path)
-
-    def __len__(self):
-        return len(self.bucket)
-
-    def __getitem__(self, index):
-        (path, data_source), size = self.bucket[index]
-        img_name = data_source.get_image_name(path)
-
-        if self.latents is None:
-            data = self.load_data(path, data_source, size)
-        else:
-            data = self.latents[img_name].copy()
-
-        prompt_ist = data_source.load_caption(img_name)
-
-        # tokenize Sp or (Sn, Sp)
-        tokens = self.tokenizer(prompt_ist, truncation=True, padding="max_length", return_tensors="pt",
-                                max_length=self.tokenizer.model_max_length*self.tokenizer_repeats)
-        data['prompt'] = tokens.input_ids.squeeze()
-        if self.encoder_attention_mask and 'attention_mask' in tokens:
-            data['attn_mask'] = tokens.attention_mask.squeeze()
-        if 'position_ids' in tokens:
-            data['position_ids'] = tokens.position_ids.squeeze()
-
-        if self.return_path:
-            return data, path
-        else:
-            return data
-
-    @staticmethod
-    def collate_fn(batch):
-        '''
-        batch: [{img:tensor, prompt:str, ..., plugin_input:{...}},{}]
-        '''
-        has_plugin_input = 'plugin_input' in batch[0]
-        if has_plugin_input:
-            plugin_input = {k:[] for k in batch[0]['plugin_input'].keys()}
-
-        datas = {k:[] for k in batch[0].keys() if k != 'plugin_input' and k != 'prompt'}
-        sn_list, sp_list = [], []
-
-        for data in batch:
-            if has_plugin_input:
-                for k, v in data.pop('plugin_input').items():
-                    plugin_input[k].append(v)
-
-            prompt = data.pop('prompt')
-            if len(prompt.shape) == 2:
-                sn_list.append(prompt[0])
-                sp_list.append(prompt[1])
-            else:
-                sp_list.append(prompt)
-
-            for k, v in data.items():
-                datas[k].append(v)
-
-        for k, v in datas.items():
-            datas[k] = torch.stack(v)
-            if k == 'mask':
-                datas[k] = datas[k].unsqueeze(1)
-
-        sn_list += sp_list
-        datas['prompt'] = torch.stack(sn_list)
-        if has_plugin_input:
-            datas['plugin_input'] = {k:torch.stack(v) for k, v in plugin_input.items()}
-
-        return datas
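For reference, the removed TextImagePairDataset.collate_fn can still be exercised in isolation against the 0.9.0 wheel. A minimal sketch with synthetic tensors (shapes and token range are illustrative; assumes hcpdiff==0.9.0 is installed):

# Sketch only: batch synthetic samples to see how prompts, masks and
# DreamArtist-style (Sn, Sp) token pairs are collated in 0.9.0.
import torch
from hcpdiff.data.pair_dataset import TextImagePairDataset  # 0.9.0 module path

batch = [
    {"img": torch.randn(4, 64, 64), "mask": torch.ones(64, 64),
     "prompt": torch.randint(0, 49408, (77,))},    # single positive prompt
    {"img": torch.randn(4, 64, 64), "mask": torch.ones(64, 64),
     "prompt": torch.randint(0, 49408, (2, 77))},  # (Sn, Sp) pair, e.g. DreamArtist
]
out = TextImagePairDataset.collate_fn(batch)
print(out["img"].shape)     # torch.Size([2, 4, 64, 64])
print(out["mask"].shape)    # torch.Size([2, 1, 64, 64]) -- masks gain a channel dim
print(out["prompt"].shape)  # torch.Size([3, 77]) -- negative rows first, then positives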
hcpdiff/data/sampler.py
DELETED
@@ -1,54 +0,0 @@
-import torch
-from torch.utils.data.distributed import DistributedSampler
-from typing import Iterator
-import platform
-import math
-
-class DistributedCycleSampler(DistributedSampler):
-    _cycle = 1000
-
-    def __iter__(self) -> Iterator:
-        def _iter():
-            while True:
-                if self.shuffle:
-                    # deterministically shuffle based on epoch and seed
-                    g = torch.Generator()
-                    g.manual_seed(self.seed + self.epoch)
-                    indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
-                else:
-                    indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
-
-                if not self.drop_last:
-                    # add extra samples to make it evenly divisible
-                    padding_size = self.total_size - len(indices)
-                    if padding_size <= len(indices):
-                        indices += indices[:padding_size]
-                    else:
-                        indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
-                else:
-                    # remove tail of data to make it evenly divisible.
-                    indices = indices[:self.total_size]
-                assert len(indices) == self.total_size
-
-                # subsample
-                indices = indices[self.rank:self.total_size:self.num_replicas]
-                assert len(indices) == self.num_samples
-
-                for idx in indices:
-                    yield idx
-                self.epoch+=1
-
-                if self.epoch>=self._cycle:
-                    break
-
-        return _iter()
-
-    def __len__(self) -> int:
-        return self.num_samples #*self._cycle
-
-def get_sampler():
-    # Fix DataLoader frequently reload bugs in windows
-    if platform.system().lower() == 'windows':
-        return DistributedCycleSampler
-    else:
-        return DistributedSampler
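The removed get_sampler helper only switched the sampler class by platform. A minimal usage sketch against the 0.9.0 layout (rank and num_replicas are passed explicitly so no process group is needed; dataset contents are illustrative):

# Sketch only: both returned classes share the DistributedSampler constructor.
from itertools import islice
import torch
from torch.utils.data import TensorDataset
from hcpdiff.data.sampler import get_sampler  # 0.9.0 module path

dataset = TensorDataset(torch.arange(10))
sampler_cls = get_sampler()  # DistributedCycleSampler on Windows, DistributedSampler otherwise
sampler = sampler_cls(dataset, num_replicas=1, rank=0, shuffle=True, seed=42)
print(list(islice(iter(sampler), 10)))  # first 10 indices; the cycle variant keeps yielding across epochs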
hcpdiff/data/source/base.py
DELETED
@@ -1,30 +0,0 @@
-import os
-from typing import Dict, List, Tuple, Any
-
-class DataSource:
-    def __init__(self, img_root, repeat=1, **kwargs):
-        self.img_root = img_root
-        self.repeat = repeat
-
-    def get_image_list(self) -> List[Tuple[str, "DataSource"]]:
-        raise NotImplementedError()
-
-    def procees_image(self, image):
-        raise NotImplementedError()
-
-    def load_image(self, path) -> Dict[str, Any]:
-        raise NotImplementedError()
-
-    def get_image_name(self, path: str) -> str:
-        img_root, img_name = os.path.split(path)
-        return img_name.rsplit('.')[0]
-
-class ComposeDataSource(DataSource):
-    def __init__(self, source_dict: Dict[str, DataSource]):
-        self.source_dict = source_dict
-
-    def get_image_list(self) -> List[Tuple[str, DataSource]]:
-        img_list = []
-        for source in self.source_dict.values():
-            img_list.extend(source.get_image_list())
-        return img_list
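The removed DataSource base class defines the contract the old bucket/dataset code relied on: get_image_list returns (path, source) pairs, load_image returns a dict with an 'image' entry, and procees_image turns a PIL image into a tensor. A minimal hypothetical subclass as a sketch (the class name, folder layout and transform are illustrative, not part of the package):

# Sketch only: a flat-folder DataSource implementing the 0.9.0 hooks.
import os
from PIL import Image
from torchvision import transforms as T
from hcpdiff.data.source import DataSource  # 0.9.0 module path

class FlatFolderSource(DataSource):
    def __init__(self, img_root, repeat=1, **kwargs):
        super().__init__(img_root, repeat, **kwargs)
        self.trans = T.Compose([T.ToTensor(), T.Normalize([0.5], [0.5])])

    def get_image_list(self):
        files = [f for f in os.listdir(self.img_root) if f.lower().endswith(('.png', '.jpg', '.webp'))]
        return [(os.path.join(self.img_root, f), self) for f in files]*self.repeat

    def load_image(self, path):
        return {'image': Image.open(path).convert('RGB')}

    def procees_image(self, image):  # method name kept as spelled in the removed base class
        return self.trans(image)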
hcpdiff/data/utils.py
DELETED
@@ -1,80 +0,0 @@
-import cv2
-import numpy as np
-from PIL import Image
-from torchvision import transforms as T
-from torchvision.transforms import functional as F
-
-class DualRandomCrop:
-    def __init__(self, size):
-        self.size = size
-
-    def __call__(self, img):
-        crop_params = T.RandomCrop.get_params(img['img'], self.size)
-        img['img'] = F.crop(img['img'], *crop_params)
-        if "mask" in img:
-            img['mask'] = self.crop(img['mask'], *crop_params)
-        if "cond" in img:
-            img['cond'] = F.crop(img['cond'], *crop_params)
-        return img, crop_params[:2]
-
-    @staticmethod
-    def crop(img: np.ndarray, top: int, left: int, height: int, width: int) -> np.ndarray:
-        right = left+width
-        bottom = top+height
-        return img[top:bottom, left:right, ...]
-
-def resize_crop_fix(img, target_size, mask_interp=cv2.INTER_CUBIC):
-    w, h = img['img'].size
-    if w == target_size[0] and h == target_size[1]:
-        return img, [h,w,0,0,h,w]
-
-    ratio_img = w/h
-    if ratio_img>target_size[0]/target_size[1]:
-        new_size = (round(ratio_img*target_size[1]), target_size[1])
-        interp_type = Image.LANCZOS if h>target_size[1] else Image.BICUBIC
-    else:
-        new_size = (target_size[0], round(target_size[0]/ratio_img))
-        interp_type = Image.LANCZOS if w>target_size[0] else Image.BICUBIC
-    img['img'] = img['img'].resize(new_size, interp_type)
-    if "mask" in img:
-        img['mask'] = cv2.resize(img['mask'], new_size, interpolation=mask_interp)
-    if "cond" in img:
-        img['cond'] = img['cond'].resize(new_size, interp_type)
-
-    img, crop_coord = DualRandomCrop(target_size[::-1])(img)
-    return img, [*new_size, *crop_coord[::-1], *target_size]
-
-def pad_crop_fix(img, target_size):
-    w, h = img['img'].size
-    if w == target_size[0] and h == target_size[1]:
-        return img, (h,w,0,0,h,w)
-
-    pad_size = [0, 0, max(target_size[0]-w, 0), max(target_size[1]-h, 0)]
-    if pad_size[2]>0 or pad_size[3]>0:
-        img['img'] = F.pad(img['img'], pad_size)
-        if "mask" in img:
-            img['mask'] = np.pad(img['mask'], ((0, pad_size[3]), (0, pad_size[2])), 'constant', constant_values=(0, 0))
-        if "cond" in img:
-            img['cond'] = F.pad(img['cond'], pad_size)
-
-    if pad_size[2]>0 and pad_size[3]>0:
-        return img, (h,w,0,0,h,w) # No need to crop
-    else:
-        img, crop_coord = DualRandomCrop(target_size[::-1])(img)
-        return img, crop_coord
-
-class CycleData():
-    def __init__(self, data_loader):
-        self.data_loader = data_loader
-
-    def __iter__(self):
-        self.epoch = 0
-
-        def cycle():
-            while True:
-                self.data_loader.dataset.bucket.rest(self.epoch)
-                for data in self.data_loader:
-                    yield data
-                self.epoch += 1
-
-        return cycle()
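A quick sketch of the removed resize_crop_fix helper on a synthetic image (sizes are illustrative; assumes hcpdiff==0.9.0 is installed), showing the (W, H) convention and the returned crop record:

# Sketch only: resize toward the target aspect ratio, then random-crop the rest.
from PIL import Image
from hcpdiff.data.utils import resize_crop_fix  # 0.9.0 module path

data = {'img': Image.new('RGB', (768, 512))}      # PIL size is (W, H)
data, coord = resize_crop_fix(data, (512, 512))   # target_size is also (W, H)
print(data['img'].size)  # (512, 512)
print(coord)             # [new_W, new_H, crop_left, crop_top, target_W, target_H]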
hcpdiff/infer_workflow.py
DELETED
@@ -1,57 +0,0 @@
-import argparse
-
-import torch
-import hydra
-from omegaconf import OmegaConf, DictConfig
-from easydict import EasyDict
-
-from hcpdiff.utils.utils import load_config_with_cli
-from .workflow import MemoryMixin
-from copy import deepcopy
-
-class WorkflowRunner:
-    def __init__(self, cfgs):
-        self.cfgs_raw = deepcopy(cfgs)
-        self.cfgs = OmegaConf.structured(cfgs, flags={"allow_objects": True})
-        OmegaConf.resolve(self.cfgs)
-        self.memory = EasyDict(**hydra.utils.instantiate(self.cfgs.memory))
-        self.attach_memory(self.cfgs)
-
-    def start(self):
-        prepare_actions = hydra.utils.instantiate(self.cfgs.prepare)
-        states = self.run(prepare_actions, {'cfgs': self.cfgs_raw})
-        actions = hydra.utils.instantiate(self.cfgs.actions)
-        states = self.run(actions, states)
-
-    def attach_memory(self, cfgs):
-        if OmegaConf.is_dict(cfgs):
-            if '_target_' in cfgs and cfgs['_target_'].endswith('.from_memory'):
-                cfgs._set_flag('allow_objects', True)
-                cfgs['memory'] = self.memory
-            else:
-                for v in cfgs.values():
-                    self.attach_memory(v)
-        elif OmegaConf.is_list(cfgs):
-            for v in cfgs:
-                self.attach_memory(v)
-
-    @torch.inference_mode()
-    def run(self, actions, states):
-        N_steps = len(actions)
-        for step, act in enumerate(actions):
-            print(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
-            if isinstance(act, MemoryMixin):
-                states = act(memory=self.memory, **states)
-            else:
-                states = act(**states)
-            print(f'states: {", ".join(states.keys())}')
-        return states
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='HCP-Diffusion workflow')
-    parser.add_argument('--cfg', type=str, default='')
-    args, cfg_args = parser.parse_known_args()
-    cfgs = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
-
-    runner = WorkflowRunner(cfgs)
-    runner.start()
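The removed workflow entry point could also be driven programmatically. A minimal sketch, assuming hcpdiff==0.9.0 is installed and one of the cfgs/workflow files shipped in the 0.9.0 data directory (listed above) has been copied next to the script; the relative path is illustrative:

# Sketch only: run a 0.9.0 workflow config without going through the CLI parser.
from hcpdiff.infer_workflow import WorkflowRunner       # removed in 2.x
from hcpdiff.utils.utils import load_config_with_cli

cfgs = load_config_with_cli('cfgs/workflow/text2img.yaml', args_list=[])  # path assumed to exist locally
runner = WorkflowRunner(cfgs)
runner.start()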
hcpdiff/loggers/__init__.py
DELETED
@@ -1,13 +0,0 @@
-from .base_logger import BaseLogger, LoggerGroup
-from .cli_logger import CLILogger
-from .webui_logger import WebUILogger
-
-try:
-    from .tensorboard_logger import TBLogger
-except:
-    print('tensorboard is not available')
-
-try:
-    from .wandb_logger import WanDBLogger
-except:
-    print('wandb is not available')
hcpdiff/loggers/base_logger.py
DELETED
@@ -1,76 +0,0 @@
-from typing import Dict, Any, List
-
-from PIL import Image
-
-from .preview import ImagePreviewer
-
-class BaseLogger:
-    def __init__(self, exp_dir, out_path, enable_log_image=False, log_step=10, image_log_step=200):
-        self.exp_dir = exp_dir
-        self.out_path = out_path
-        self.enable_log_image = enable_log_image
-        self.log_step = log_step
-        self.image_log_step = image_log_step
-        self.enable_log = True
-        self.previewer_list: List[ImagePreviewer] = []
-
-    def enable(self):
-        self.enable_log = True
-
-    def disable(self):
-        self.enable_log = False
-
-    def add_previewer(self, previewer: ImagePreviewer):
-        self.previewer_list.append(previewer)
-
-    def info(self, info):
-        if self.enable_log:
-            self._info(info)
-
-    def _info(self, info):
-        raise NotImplementedError()
-
-    def log(self, datas: Dict[str, Any], step: int = 0):
-        if self.enable_log and step%self.log_step == 0:
-            self._log(datas, step)
-
-    def _log(self, datas: Dict[str, Any], step: int = 0):
-        raise NotImplementedError()
-
-    def log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-        if self.enable_log and self.enable_log_image and step%self.image_log_step == 0:
-            self._log_image(imgs, step)
-
-    def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-        raise NotImplementedError()
-
-class LoggerGroup:
-    def __init__(self, logger_list: List[BaseLogger]):
-        self.logger_list = logger_list
-
-    def enable(self):
-        for logger in self.logger_list:
-            logger.enable()
-
-    def disable(self):
-        for logger in self.logger_list:
-            logger.disable()
-
-    def add_previewer(self, previewer):
-        for logger in self.logger_list:
-            logger.add_previewer(previewer)
-
-    def info(self, info):
-        for logger in self.logger_list:
-            logger.info(info)
-
-    def log(self, datas: Dict[str, Any], step: int = 0):
-        for logger in self.logger_list:
-            logger.log(datas, step)
-
-    def log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-        for logger in self.logger_list:
-            logger.log_image(imgs, step)
-
-    def __len__(self):
-        return len(self.logger_list)
hcpdiff/loggers/cli_logger.py
DELETED
@@ -1,40 +0,0 @@
-import os
-from typing import Dict, Any
-
-from PIL import Image
-from loguru import logger
-
-from .base_logger import BaseLogger
-
-class CLILogger(BaseLogger):
-    def __init__(self, exp_dir, out_path, enable_log_image=False, log_step=10, image_log_step=200,
-                 img_log_dir='preview', img_ext='png', img_quality=95):
-        super().__init__(exp_dir, out_path, enable_log_image, log_step, image_log_step)
-        if exp_dir is not None: # exp_dir is only available in local main process
-            logger.add(os.path.join(exp_dir, out_path))
-            if enable_log_image:
-                self.img_log_dir = os.path.join(exp_dir, img_log_dir)
-                os.makedirs(self.img_log_dir, exist_ok=True)
-                self.img_ext = img_ext
-                self.img_quality = img_quality
-        else:
-            self.disable()
-
-    def enable(self):
-        super(CLILogger, self).enable()
-        logger.enable("__main__")
-
-    def disable(self):
-        super(CLILogger, self).disable()
-        logger.disable("__main__")
-
-    def _info(self, info):
-        logger.info(info)
-
-    def _log(self, datas: Dict[str, Any], step: int = 0):
-        logger.info(', '.join([f"{k} = {v['format'].format(*v['data'])}" for k, v in datas.items()]))
-
-    def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-        logger.info(f'log {len(imgs)} images')
-        for name, img in imgs.items():
-            img.save(os.path.join(self.img_log_dir, f'{step}-{name}.{self.img_ext}'), quality=self.img_quality)
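The removed loggers share a small dict-of-dicts payload: each entry carries a format string and a list of values. A minimal sketch combining CLILogger with LoggerGroup (paths and values are illustrative; assumes hcpdiff==0.9.0 is installed):

# Sketch only: the 0.9.0 logging payload as consumed by CLILogger._log
# (TBLogger._log, shown further below, only keeps single-value entries).
from hcpdiff.loggers import CLILogger, LoggerGroup  # removed in 2.x

loggers = LoggerGroup([CLILogger(exp_dir='exps/demo', out_path='train.log', log_step=10)])
loggers.info('start training')
loggers.log({
    'loss': {'format': '{:.4f}', 'data': [0.1234]},
    'lr':   {'format': '{:.2e}', 'data': [1e-4]},
}, step=10)  # only steps divisible by log_step are written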
hcpdiff/loggers/preview/__init__.py
DELETED
@@ -1 +0,0 @@
-from .image_previewer import ImagePreviewer
hcpdiff/loggers/preview/image_previewer.py
DELETED
@@ -1,149 +0,0 @@
-from contextlib import contextmanager
-from typing import List
-
-import hydra
-import torch
-from accelerate import infer_auto_device_map, dispatch_model
-from accelerate.hooks import remove_hook_from_module
-from diffusers import PNDMScheduler
-from torch.cuda.amp import autocast
-
-from hcpdiff.models import TokenizerHook
-from hcpdiff.utils.net_utils import to_cpu
-from hcpdiff.utils.utils import prepare_seed, load_config, size_to_int, int_to_size
-from hcpdiff.utils.utils import to_validate_file
-from hcpdiff.visualizer import Visualizer
-
-class ImagePreviewer(Visualizer):
-    def __init__(self, infer_cfg, exp_dir, te_hook,
-                 unet, TE, tokenizer, vae, save_cfg=False):
-        self.exp_dir = exp_dir
-        self.cfgs_raw = load_config(infer_cfg)
-        self.cfgs = hydra.utils.instantiate(self.cfgs_raw)
-        self.save_cfg = save_cfg
-        self.offload = 'offload' in self.cfgs and self.cfgs.offload is not None
-        self.dtype = self.dtype_dict[self.cfgs.dtype]
-
-        if getattr(self.cfgs.new_components, 'scheduler', None) is None:
-            scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear')
-        else:
-            scheduler = self.cfgs.new_components.scheduler
-
-        pipe_cls = self.get_pipeline()
-        self.pipe = pipe_cls(vae=vae, text_encoder=TE, tokenizer=tokenizer, unet=unet, scheduler=scheduler, feature_extractor=None,
-                             safety_checker=None, requires_safety_checker=False)
-
-        self.token_ex = TokenizerHook(tokenizer)
-        self.te_hook = te_hook
-
-        if self.cfgs.seed is not None:
-            self.seeds = list(range(self.cfgs.seed, self.cfgs.seed+self.cfgs.num*self.cfgs.bs))
-        else:
-            self.seeds = [None]*(self.cfgs.num*self.cfgs.bs)
-
-    def build_vae_offload(self, offload_cfg):
-        vram = size_to_int(offload_cfg.max_VRAM)
-        if not offload_cfg.vae_cpu:
-            device_map = infer_auto_device_map(self.pipe.vae, max_memory={0:int_to_size(vram >> 5), "cpu":offload_cfg.max_RAM}, dtype=torch.float32)
-            self.pipe.vae = dispatch_model(self.pipe.vae, device_map)
-        else:
-            to_cpu(self.pipe.vae)
-            self.vae_decode_raw = self.pipe.vae.decode
-
-            def vae_decode_offload(latents, return_dict=True, decode_raw=self.pipe.vae.decode):
-                self.pipe.vae.to(dtype=torch.float32)
-                res = decode_raw(latents.cpu().to(dtype=torch.float32), return_dict=return_dict)
-                return res
-
-            self.pipe.vae.decode = vae_decode_offload
-
-            self.vae_encode_raw = self.pipe.vae.encode
-
-            def vae_encode_offload(x, return_dict=True, encode_raw=self.pipe.vae.encode):
-                self.pipe.vae.to(dtype=torch.float32)
-                res = encode_raw(x.cpu().to(dtype=torch.float32), return_dict=return_dict)
-                return res
-
-            self.pipe.vae.encode = vae_encode_offload
-
-    def remove_vae_offload(self, offload_cfg):
-        if not offload_cfg.vae_cpu:
-            remove_hook_from_module(self.pipe.vae, recurse=True)
-        else:
-            self.pipe.vae.encode = self.vae_encode_raw
-            self.pipe.vae.decode = self.vae_decode_raw
-
-    @contextmanager
-    def infer_optimize(self):
-        if getattr(self.cfgs, 'vae_optimize', None) is not None:
-            if self.cfgs.vae_optimize.tiling:
-                self.pipe.vae.enable_tiling()
-            if self.cfgs.vae_optimize.slicing:
-                self.pipe.vae.enable_slicing()
-        vae_device = self.pipe.vae.device
-        if self.offload:
-            self.build_vae_offload(self.cfgs.offload)
-        else:
-            self.pipe.vae.to(self.pipe.unet.device)
-
-        yield
-
-        if self.offload:
-            self.remove_vae_offload(self.cfgs.offload)
-        self.pipe.vae.to(vae_device)
-        self.pipe.vae.disable_tiling()
-        self.pipe.vae.disable_slicing()
-
-    def preview(self):
-        image_list, info_list = [], []
-        with self.infer_optimize():
-            for i in range(self.cfgs.num):
-                prompt = self.cfgs.prompt[i*self.cfgs.bs:(i+1)*self.cfgs.bs] if isinstance(self.cfgs.prompt, list) \
-                    else [self.cfgs.prompt]*self.cfgs.bs
-                negative_prompt = self.cfgs.neg_prompt[i*self.cfgs.bs:(i+1)*self.cfgs.bs] if isinstance(self.cfgs.neg_prompt, list) \
-                    else [self.cfgs.neg_prompt]*self.cfgs.bs
-                seeds = self.seeds[i*self.cfgs.bs:(i+1)*self.cfgs.bs]
-                images = self.vis_images(prompt=prompt, negative_prompt=negative_prompt, seeds=seeds,
-                                         **self.cfgs.infer_args)
-                for prompt_i, negative_prompt_i, seed in zip(prompt, negative_prompt, seeds):
-                    info_list.append({
-                        'prompt':prompt_i,
-                        'negative_prompt':negative_prompt_i,
-                        'seed':seed,
-                    })
-                image_list += images
-
-        return image_list, info_list
-
-    def preview_dict(self):
-        image_list, info_list = self.preview()
-        imgs = {f'{info["seed"]}-{to_validate_file(info["prompt"])}':img for img, info in zip(image_list, info_list)}
-        return imgs
-
-    @torch.no_grad()
-    def vis_images(self, prompt, negative_prompt='', seeds: List[int] = None, **kwargs):
-        G = prepare_seed(seeds or [None]*len(prompt))
-
-        ex_input_dict, pipe_input_dict = self.get_ex_input()
-        kwargs.update(pipe_input_dict)
-
-        mult_p, clean_text_p = self.token_ex.parse_attn_mult(prompt)
-        mult_n, clean_text_n = self.token_ex.parse_attn_mult(negative_prompt)
-        with autocast(enabled=self.cfgs.amp, dtype=self.dtype):
-            emb, pooled_output, attention_mask = self.te_hook.encode_prompt_to_emb(clean_text_n+clean_text_p)
-            if not self.cfgs.encoder_attention_mask:
-                attention_mask = None
-            emb_n, emb_p = emb.chunk(2)
-            emb_p = self.te_hook.mult_attn(emb_p, mult_p)
-            emb_n = self.te_hook.mult_attn(emb_n, mult_n)
-
-            if hasattr(self.pipe.unet, 'input_feeder'):
-                for feeder in self.pipe.unet.input_feeder:
-                    feeder(ex_input_dict)
-
-            if pooled_output is not None:
-                pooled_output = pooled_output[-1]
-
-            images = self.pipe(prompt_embeds=emb_p, negative_prompt_embeds=emb_n, callback=self.inter_callback, generator=G,
-                               pooled_output=pooled_output, encoder_attention_mask=attention_mask, **kwargs).images
-        return images
hcpdiff/loggers/tensorboard_logger.py
DELETED
@@ -1,30 +0,0 @@
-import os
-from typing import Dict, Any
-
-import numpy as np
-from PIL import Image
-from torch.utils.tensorboard import SummaryWriter
-
-from .base_logger import BaseLogger
-
-
-class TBLogger(BaseLogger):
-    def __init__(self, exp_dir, out_path, enable_log_image=False, log_step=10, image_log_step=200):
-        super().__init__(exp_dir, out_path, enable_log_image, log_step, image_log_step)
-        if exp_dir is not None: # exp_dir is only available in local main process
-            self.writer = SummaryWriter(os.path.join(exp_dir, out_path))
-        else:
-            self.writer = None
-            self.disable()
-
-    def _info(self, info):
-        pass
-
-    def _log(self, datas: Dict[str, Any], step: int = 0):
-        for k, v in datas.items():
-            if len(v['data']) == 1:
-                self.writer.add_scalar(k, v['data'][0], global_step=step)
-
-    def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-        for name, img in imgs.items():
-            self.writer.add_image(f'img/{name}', np.array(img), dataformats='HWC', global_step=step)