hcpdiff 2.1__py3-none-any.whl → 2.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/ckpt_manager/__init__.py +1 -1
- hcpdiff/ckpt_manager/format/lora_webui.py +13 -5
- hcpdiff/data/__init__.py +2 -2
- hcpdiff/data/handler/__init__.py +1 -1
- hcpdiff/data/handler/diffusion.py +17 -7
- hcpdiff/data/source/__init__.py +2 -1
- hcpdiff/data/source/text.py +40 -0
- hcpdiff/data/source/text2img.py +1 -1
- hcpdiff/easy/cfg/__init__.py +1 -1
- hcpdiff/easy/cfg/sd15_train.py +12 -6
- hcpdiff/easy/cfg/sdxl_train.py +13 -6
- hcpdiff/easy/cfg/t2i.py +64 -13
- hcpdiff/models/text_emb_ex.py +4 -0
- hcpdiff/trainer_ac.py +0 -7
- hcpdiff/trainer_deepspeed.py +47 -0
- hcpdiff/workflow/diffusion.py +6 -5
- hcpdiff/workflow/text.py +6 -25
- {hcpdiff-2.1.dist-info → hcpdiff-2.2.1.dist-info}/METADATA +22 -4
- {hcpdiff-2.1.dist-info → hcpdiff-2.2.1.dist-info}/RECORD +23 -22
- {hcpdiff-2.1.dist-info → hcpdiff-2.2.1.dist-info}/WHEEL +1 -1
- {hcpdiff-2.1.dist-info → hcpdiff-2.2.1.dist-info}/entry_points.txt +1 -0
- hcpdiff/train_deepspeed.py +0 -69
- {hcpdiff-2.1.dist-info → hcpdiff-2.2.1.dist-info}/licenses/LICENSE +0 -0
- {hcpdiff-2.1.dist-info → hcpdiff-2.2.1.dist-info}/top_level.txt +0 -0
hcpdiff/ckpt_manager/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
1
|
from .format import EmbFormat, DiffusersSD15Format, DiffusersModelFormat, DiffusersSDXLFormat, DiffusersPixArtFormat, OfficialSDXLFormat, \
|
2
|
-
OfficialSD15Format
|
2
|
+
OfficialSD15Format, LoraWebuiFormat
|
3
3
|
from .ckpt import EmbSaver, easy_emb_saver
|
4
4
|
from .loader import HCPLoraLoader
|
@@ -2,7 +2,7 @@ import math
|
|
2
2
|
import re
|
3
3
|
from typing import List, Dict, Any
|
4
4
|
|
5
|
-
from rainbowneko.ckpt_manager.format import CkptFormat
|
5
|
+
from rainbowneko.ckpt_manager.format import CkptFormat, SafeTensorFormat
|
6
6
|
from torch.serialization import FILE_LIKE
|
7
7
|
|
8
8
|
class LoraConverter:
|
@@ -36,7 +36,12 @@ class LoraConverter:
|
|
36
36
|
if auto_scale_alpha:
|
37
37
|
sd_unet = self.alpha_scale_from_webui(sd_unet)
|
38
38
|
sd_TE = self.alpha_scale_from_webui(sd_TE)
|
39
|
-
|
39
|
+
|
40
|
+
sd = {
|
41
|
+
**{f'denoiser.{k}':v for k,v in sd_unet.items()},
|
42
|
+
**{f'TE.{k}':v for k,v in sd_TE.items()},
|
43
|
+
}
|
44
|
+
return {'base': sd}
|
40
45
|
|
41
46
|
def convert_to_webui(self, sd_unet, sd_TE, auto_scale_alpha=False, sdxl=False):
|
42
47
|
sd_unet = self.convert_to_webui_(sd_unet, prefix=self.prefix_unet)
|
@@ -207,9 +212,12 @@ class LoraConverter:
|
|
207
212
|
return state
|
208
213
|
|
209
214
|
class LoraWebuiFormat(CkptFormat):
|
210
|
-
def __init__(self, format, auto_scale_alpha=False):
|
215
|
+
def __init__(self, format=None, auto_scale_alpha=False):
|
211
216
|
self.converter = LoraConverter()
|
212
217
|
self.auto_scale_alpha = auto_scale_alpha
|
218
|
+
|
219
|
+
if format is None:
|
220
|
+
format = SafeTensorFormat()
|
213
221
|
self.format = format
|
214
222
|
|
215
223
|
def save_ckpt(self, sd_model: Dict[str, Any], save_f: FILE_LIKE):
|
@@ -240,5 +248,5 @@ class LoraWebuiFormat(CkptFormat):
|
|
240
248
|
sdxl = True
|
241
249
|
break
|
242
250
|
|
243
|
-
|
244
|
-
return
|
251
|
+
sd_all = self.converter.convert_from_webui(sd_webui, auto_scale_alpha=self.auto_scale_alpha, sdxl=sdxl)
|
252
|
+
return sd_all
|
hcpdiff/data/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
1
|
from .dataset import TextImagePairDataset
|
2
|
-
from .source import Text2ImageSource, Text2ImageLossMapSource, Text2ImageCondSource, T2IFolderClassSource
|
3
|
-
from .handler import StableDiffusionHandler, LossMapHandler, DiffusionImageHandler
|
2
|
+
from .source import Text2ImageSource, Text2ImageLossMapSource, Text2ImageCondSource, T2IFolderClassSource, TextSource
|
3
|
+
from .handler import StableDiffusionHandler, LossMapHandler, DiffusionImageHandler, DiffusionTextHandler
|
4
4
|
from .cache import VaeCache
|
hcpdiff/data/handler/__init__.py
CHANGED
@@ -1,3 +1,3 @@
|
|
1
|
-
from .diffusion import StableDiffusionHandler, DiffusionImageHandler, LossMapHandler
|
1
|
+
from .diffusion import StableDiffusionHandler, DiffusionImageHandler, LossMapHandler, DiffusionTextHandler
|
2
2
|
from .text import TokenizeHandler, TagEraseHandler, TagDropoutHandler, TagShuffleHandler, TemplateFillHandler
|
3
3
|
from .controlnet import ControlNetHandler
|
@@ -49,14 +49,11 @@ class DiffusionImageHandler(DataHandler):
|
|
49
49
|
else:
|
50
50
|
return self.handlers(dict(image=image, image_size=image_size))
|
51
51
|
|
52
|
-
class
|
53
|
-
def __init__(self,
|
54
|
-
|
55
|
-
erase=0.15, dropout=0.0, shuffle=0.0, word_names={}, tokenize=True):
|
52
|
+
class DiffusionTextHandler(DataHandler):
|
53
|
+
def __init__(self, encoder_attention_mask=False, erase=0.0, dropout=0.0, shuffle=0.0, word_names={}, tokenize=True,
|
54
|
+
key_map_in=('prompt -> prompt', ), key_map_out=('prompt -> prompt', )):
|
56
55
|
super().__init__(key_map_in, key_map_out)
|
57
56
|
|
58
|
-
self.image_handlers = DiffusionImageHandler(bucket)
|
59
|
-
|
60
57
|
text_handlers = {}
|
61
58
|
if dropout>0:
|
62
59
|
text_handlers['dropout'] = TagDropoutHandler(p=dropout)
|
@@ -67,7 +64,20 @@ class StableDiffusionHandler(DataHandler):
|
|
67
64
|
text_handlers['fill'] = TemplateFillHandler(word_names)
|
68
65
|
if tokenize:
|
69
66
|
text_handlers['tokenize'] = TokenizeHandler(encoder_attention_mask)
|
70
|
-
self.
|
67
|
+
self.handlers = HandlerChain(**text_handlers)
|
68
|
+
|
69
|
+
def handle(self, prompt: Union[str, Dict[str, str]]):
|
70
|
+
return self.handlers(dict(prompt=prompt))
|
71
|
+
|
72
|
+
class StableDiffusionHandler(DataHandler):
|
73
|
+
def __init__(self, bucket, encoder_attention_mask=False, key_map_in=('image -> image', 'image_size -> image_size', 'prompt -> prompt'),
|
74
|
+
key_map_out=('image -> image', 'coord -> coord', 'prompt -> prompt'),
|
75
|
+
erase=0.0, dropout=0.0, shuffle=0.0, word_names={}, tokenize=True):
|
76
|
+
super().__init__(key_map_in, key_map_out)
|
77
|
+
|
78
|
+
self.image_handlers = DiffusionImageHandler(bucket)
|
79
|
+
self.text_handlers = DiffusionTextHandler(encoder_attention_mask=encoder_attention_mask, erase=erase, dropout=dropout, shuffle=shuffle,
|
80
|
+
word_names=word_names, tokenize=tokenize)
|
71
81
|
|
72
82
|
def handle(self, image: Image.Image, image_size: np.ndarray[int], prompt: str):
|
73
83
|
return dict(**self.image_handlers(dict(image=image, image_size=image_size)), **self.text_handlers(dict(prompt=prompt)))
|
hcpdiff/data/source/__init__.py
CHANGED
@@ -0,0 +1,40 @@
|
|
1
|
+
from rainbowneko.data import UnLabelSource, DataSource
|
2
|
+
from rainbowneko.data.label_loader import BaseLabelLoader, auto_label_loader
|
3
|
+
from typing import Union, Dict, Any
|
4
|
+
import random
|
5
|
+
|
6
|
+
class TextSource(DataSource):
    """Data source that yields caption-only samples (no images).

    Each item is ``{'id': <key>, 'prompt': {'template': <template>, 'caption': <label>}}``.
    The keys and captions come from the label data loaded from ``label_file``. The template
    is drawn at random from the templates loaded from ``prompt_template``.

    Args:
        label_file: one of the following:
            - a path handed to ``auto_label_loader``;
            - an object exposing ``.load()`` (a ``BaseLabelLoader``);
            - ``None`` for an empty source.
        prompt_template: optional path to a UTF-8 text file with one template per line.
            When omitted (or when the file has no non-blank lines) the single template
            ``'{caption}'`` is used.
        repeat: forwarded to ``DataSource.__init__``; the key list is repeated
            ``self.repeat`` times.
        **kwargs: accepted and ignored.
    """

    def __init__(self, label_file, prompt_template=None, repeat=1, **kwargs):
        super().__init__(repeat=repeat)
        self.label_file = label_file
        self.label_dict = self._load_label_data(label_file)
        self.img_ids = self._load_img_ids(self.label_dict)
        self.prompt_template = self.load_template(prompt_template)

    def _load_img_ids(self, label_dict):
        # self.repeat is presumably set by DataSource.__init__ from `repeat` -- confirm in rainbowneko.
        return list(label_dict.keys()) * self.repeat

    def _load_label_data(self, label_file: Union[str, BaseLabelLoader, None]):
        if label_file is None:
            return {}
        elif isinstance(label_file, str):
            return auto_label_loader(label_file).load()
        else:
            return label_file.load()

    def load_template(self, template_file):
        """Return the list of prompt templates, one per non-blank line of ``template_file``."""
        if template_file is None:
            return ['{caption}']
        with open(template_file, 'r', encoding='utf-8') as f:
            # Blank lines are skipped: an empty template would silently yield an empty prompt.
            templates = [line for line in f.read().strip().split('\n') if line.strip()]
        # An empty/whitespace-only file falls back to the default instead of returning [''].
        return templates or ['{caption}']

    def __getitem__(self, index) -> Dict[str, Any]:
        img_name = self.img_ids[index]
        return {
            'id':img_name,
            'prompt':{
                'template':random.choice(self.prompt_template),
                'caption':self.label_dict[img_name],
            }
        }
|
hcpdiff/data/source/text2img.py
CHANGED
hcpdiff/easy/cfg/__init__.py
CHANGED
@@ -1,3 +1,3 @@
|
|
1
1
|
from .sd15_train import SD15_lora_train, cfg_data_SD_ARB, cfg_data_SD_resize_crop, SD15_finetuning
|
2
2
|
from .sdxl_train import SDXL_lora_train, SDXL_finetuning
|
3
|
-
from .t2i import SD15_t2i, SDXL_t2i, SDXL_t2i_lora, SD15_t2i_lora
|
3
|
+
from .t2i import SD15_t2i, SDXL_t2i, SDXL_t2i_lora, SD15_t2i_lora, SDXL_t2i_parts, SD15_t2i_parts
|
hcpdiff/easy/cfg/sd15_train.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1
1
|
import torch
|
2
|
-
from rainbowneko.ckpt_manager import ckpt_saver, LAYERS_TRAINABLE,
|
2
|
+
from rainbowneko.ckpt_manager import ckpt_saver, LAYERS_TRAINABLE, NekoPluginSaver, SafeTensorFormat
|
3
3
|
from rainbowneko.data import RatioBucket, FixedBucket
|
4
4
|
from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
|
5
5
|
from rainbowneko.utils import ConstantLR, Path_Like
|
6
6
|
|
7
|
+
from hcpdiff.ckpt_manager import LoraWebuiFormat
|
7
8
|
from hcpdiff.data import TextImagePairDataset, Text2ImageSource, StableDiffusionHandler
|
8
9
|
from hcpdiff.data import VaeCache
|
9
10
|
from hcpdiff.easy import SD15_auto_loader
|
@@ -46,7 +47,7 @@ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
|
|
46
47
|
|
47
48
|
optimizer=optimizer,
|
48
49
|
|
49
|
-
|
50
|
+
lr_scheduler=ConstantLR(
|
50
51
|
_partial_=True,
|
51
52
|
warmup_steps=warmup_steps,
|
52
53
|
),
|
@@ -69,7 +70,7 @@ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
|
|
69
70
|
@neko_cfg
|
70
71
|
def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
|
71
72
|
clip_skip: int = 0, with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0,
|
72
|
-
name: str = 'SD15'):
|
73
|
+
name: str = 'SD15', save_webui_format=False):
|
73
74
|
with disable_neko_cfg:
|
74
75
|
if alpha is None:
|
75
76
|
alpha = rank
|
@@ -95,6 +96,11 @@ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
|
|
95
96
|
else:
|
96
97
|
optimizer = torch.optim.AdamW(_partial_=True, betas=(0.9, 0.99))
|
97
98
|
|
99
|
+
if save_webui_format:
|
100
|
+
lora_format = LoraWebuiFormat()
|
101
|
+
else:
|
102
|
+
lora_format = SafeTensorFormat()
|
103
|
+
|
98
104
|
from cfgs.train.py.examples import SD_FT
|
99
105
|
|
100
106
|
return dict(
|
@@ -114,8 +120,8 @@ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
|
|
114
120
|
|
115
121
|
ckpt_saver=dict(
|
116
122
|
_replace_ = True,
|
117
|
-
lora_unet=
|
118
|
-
|
123
|
+
lora_unet=NekoPluginSaver(
|
124
|
+
format=lora_format,
|
119
125
|
target_plugin='lora1',
|
120
126
|
)
|
121
127
|
),
|
@@ -126,7 +132,7 @@ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
|
|
126
132
|
|
127
133
|
optimizer=optimizer,
|
128
134
|
|
129
|
-
|
135
|
+
lr_scheduler=ConstantLR(
|
130
136
|
_partial_=True,
|
131
137
|
warmup_steps=warmup_steps,
|
132
138
|
),
|
hcpdiff/easy/cfg/sdxl_train.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1
1
|
import torch
|
2
|
-
from rainbowneko.ckpt_manager import ckpt_saver,
|
2
|
+
from rainbowneko.ckpt_manager import ckpt_saver, NekoPluginSaver, LAYERS_TRAINABLE, SafeTensorFormat
|
3
3
|
from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
|
4
4
|
from rainbowneko.utils import ConstantLR
|
5
5
|
|
6
6
|
from hcpdiff.easy import SDXL_auto_loader
|
7
7
|
from hcpdiff.models import SDXLWrapper
|
8
8
|
from hcpdiff.models.lora_layers_patch import LoraLayer
|
9
|
+
from hcpdiff.ckpt_manager import LoraWebuiFormat
|
9
10
|
|
10
11
|
@neko_cfg
|
11
12
|
def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5,
|
@@ -43,7 +44,7 @@ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
|
|
43
44
|
|
44
45
|
optimizer=optimizer,
|
45
46
|
|
46
|
-
|
47
|
+
lr_scheduler=ConstantLR(
|
47
48
|
_partial_=True,
|
48
49
|
warmup_steps=warmup_steps,
|
49
50
|
),
|
@@ -64,7 +65,8 @@ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
|
|
64
65
|
|
65
66
|
@neko_cfg
|
66
67
|
def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
|
67
|
-
with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = '
|
68
|
+
with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SDXL',
|
69
|
+
save_webui_format=False):
|
68
70
|
with disable_neko_cfg:
|
69
71
|
if alpha is None:
|
70
72
|
alpha = rank
|
@@ -90,6 +92,11 @@ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
|
|
90
92
|
else:
|
91
93
|
optimizer = torch.optim.AdamW(_partial_=True, betas=(0.9, 0.99))
|
92
94
|
|
95
|
+
if save_webui_format:
|
96
|
+
lora_format = LoraWebuiFormat()
|
97
|
+
else:
|
98
|
+
lora_format = SafeTensorFormat()
|
99
|
+
|
93
100
|
from cfgs.train.py.examples import SD_FT
|
94
101
|
|
95
102
|
return dict(
|
@@ -109,8 +116,8 @@ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
|
|
109
116
|
|
110
117
|
ckpt_saver=dict(
|
111
118
|
_replace_ = True,
|
112
|
-
lora_unet=
|
113
|
-
|
119
|
+
lora_unet=NekoPluginSaver(
|
120
|
+
format=lora_format,
|
114
121
|
target_plugin='lora1',
|
115
122
|
)
|
116
123
|
),
|
@@ -121,7 +128,7 @@ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
|
|
121
128
|
|
122
129
|
optimizer=optimizer,
|
123
130
|
|
124
|
-
|
131
|
+
lr_scheduler=ConstantLR(
|
125
132
|
_partial_=True,
|
126
133
|
warmup_steps=warmup_steps,
|
127
134
|
),
|
hcpdiff/easy/cfg/t2i.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import torch
|
2
2
|
from rainbowneko.infer.workflow import (Actions, PrepareAction, LoopAction, LoadModelAction)
|
3
|
+
from rainbowneko.ckpt_manager import NekoModelLoader
|
3
4
|
from rainbowneko.parser import neko_cfg, disable_neko_cfg
|
4
5
|
from typing import Union, List
|
5
6
|
|
@@ -25,6 +26,29 @@ def build_model(pretrained_model='ckpts/any5', noise_sampler=Diffusers_SD.dpmpp_
|
|
25
26
|
),
|
26
27
|
])
|
27
28
|
|
29
|
+
# Build the workflow steps that load extra model "parts" from the checkpoint paths in `info`.
# Each path produces two LoadModelAction steps:
#   - a NekoModelLoader with state_prefix='denoiser.', fed the `denoiser` state as its model;
#   - a NekoModelLoader with state_prefix='TE.', fed the `TE` state as its model.
# All steps are returned, in path order, wrapped in a single Actions.
# NOTE(review): the same path is handed to two loaders. Confirm this is intended, and that a
# checkpoint without 'TE.' (or 'denoiser.') entries loads cleanly.
@neko_cfg
def load_parts(info: List[str]) -> Actions:
    acts = []
    for i, path in enumerate(info):
        # The index in the config key keeps the entries of different parts distinct.
        part_unet = LoadModelAction(cfg={
            f'part_unet_{i}':NekoModelLoader(
                path=path,
                state_prefix='denoiser.'
            )
        }, key_map_in=('denoiser -> model', 'in_preview -> in_preview'))
        part_TE = LoadModelAction(cfg={
            f'part_TE_{i}':NekoModelLoader(
                path=path,
                state_prefix='TE.',
            )
        }, key_map_in=('TE -> model', 'in_preview -> in_preview'))

        # Presumably this stops @neko_cfg from capturing the list appends, so they run as
        # plain Python. Confirm against rainbowneko.parser.
        with disable_neko_cfg:
            acts.append(part_unet)
            acts.append(part_TE)

    return Actions(acts)
|
51
|
+
|
28
52
|
@neko_cfg
|
29
53
|
def load_lora(info: List[List]) -> Actions:
|
30
54
|
lora_acts = []
|
@@ -37,7 +61,7 @@ def load_lora(info: List[List]) -> Actions:
|
|
37
61
|
)
|
38
62
|
}, key_map_in=('denoiser -> model', 'in_preview -> in_preview'))
|
39
63
|
lora_TE = LoadModelAction(cfg={
|
40
|
-
f'
|
64
|
+
f'lora_TE_{i}':HCPLoraLoader(
|
41
65
|
path=item[0],
|
42
66
|
state_prefix='TE.',
|
43
67
|
alpha=item[1],
|
@@ -59,9 +83,9 @@ def optimize_model() -> Actions:
|
|
59
83
|
])
|
60
84
|
|
61
85
|
@neko_cfg
|
62
|
-
def text(prompt, negative_prompt=negative_prompt, bs=4) -> Actions:
|
86
|
+
def text(prompt, negative_prompt=negative_prompt, bs=4, N_repeats=1, layer_skip=1) -> Actions:
|
63
87
|
return Actions([
|
64
|
-
TextHookAction(N_repeats=
|
88
|
+
TextHookAction(N_repeats=N_repeats, layer_skip=layer_skip),
|
65
89
|
AttnMultTextEncodeAction(
|
66
90
|
prompt=prompt,
|
67
91
|
negative_prompt=negative_prompt,
|
@@ -84,9 +108,9 @@ def build_model_SDXL(pretrained_model='ckpts/any5', noise_sampler=Diffusers_SD.d
|
|
84
108
|
])
|
85
109
|
|
86
110
|
@neko_cfg
|
87
|
-
def text_SDXL(prompt, negative_prompt=negative_prompt, bs=4) -> Actions:
|
111
|
+
def text_SDXL(prompt, negative_prompt=negative_prompt, bs=4, N_repeats=1, layer_skip=1) -> Actions:
|
88
112
|
return Actions([
|
89
|
-
TextHookAction(N_repeats=
|
113
|
+
TextHookAction(N_repeats=N_repeats, layer_skip=layer_skip, TE_final_norm=False),
|
90
114
|
AttnMultTextEncodeAction(
|
91
115
|
prompt=prompt,
|
92
116
|
negative_prompt=negative_prompt,
|
@@ -128,11 +152,24 @@ def resize(width=1024, height=1024):
|
|
128
152
|
|
129
153
|
@neko_cfg
|
130
154
|
def SD15_t2i(pretrained_model, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=512, height=512,
|
131
|
-
seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/'):
|
155
|
+
seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
|
156
|
+
return dict(workflow=Actions(actions=[
|
157
|
+
build_model(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
|
158
|
+
optimize_model(),
|
159
|
+
text(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
|
160
|
+
config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
|
161
|
+
diffusion(guidance_scale=guidance_scale),
|
162
|
+
decode(save_root=save_root)
|
163
|
+
]))
|
164
|
+
|
165
|
+
@neko_cfg
|
166
|
+
def SD15_t2i_parts(pretrained_model, parts, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=512, height=512,
|
167
|
+
seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
|
132
168
|
return dict(workflow=Actions(actions=[
|
133
169
|
build_model(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
|
170
|
+
load_parts(parts),
|
134
171
|
optimize_model(),
|
135
|
-
text(prompt=prompt, negative_prompt=negative_prompt, bs=bs),
|
172
|
+
text(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
|
136
173
|
config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
|
137
174
|
diffusion(guidance_scale=guidance_scale),
|
138
175
|
decode(save_root=save_root)
|
@@ -140,12 +177,12 @@ def SD15_t2i(pretrained_model, prompt, negative_prompt=negative_prompt, noise_sa
|
|
140
177
|
|
141
178
|
@neko_cfg
|
142
179
|
def SD15_t2i_lora(pretrained_model, lora_info, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4,
|
143
|
-
width=512, height=512, seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/'):
|
180
|
+
width=512, height=512, seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
|
144
181
|
return dict(workflow=Actions(actions=[
|
145
182
|
build_model(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
|
146
183
|
load_lora(info=lora_info),
|
147
184
|
optimize_model(),
|
148
|
-
text(prompt=prompt, negative_prompt=negative_prompt, bs=bs),
|
185
|
+
text(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
|
149
186
|
config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
|
150
187
|
diffusion(guidance_scale=guidance_scale),
|
151
188
|
decode(save_root=save_root)
|
@@ -153,24 +190,38 @@ def SD15_t2i_lora(pretrained_model, lora_info, prompt, negative_prompt=negative_
|
|
153
190
|
|
154
191
|
@neko_cfg
|
155
192
|
def SDXL_t2i(pretrained_model, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=1024, height=1024,
|
156
|
-
seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/'):
|
193
|
+
seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
|
157
194
|
return dict(workflow=Actions(actions=[
|
158
195
|
build_model_SDXL(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
|
159
196
|
optimize_model(),
|
160
|
-
text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs),
|
197
|
+
text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
|
161
198
|
config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
|
162
199
|
diffusion(guidance_scale=guidance_scale),
|
163
200
|
decode(save_root=save_root)
|
164
201
|
]))
|
165
202
|
|
203
|
+
# SDXL text-to-image workflow config that also loads extra model parts (checkpoint paths,
# passed to load_parts) on top of the pretrained model.
# Step order: build SDXL model -> load parts -> optimize_model -> text_SDXL prompt
# encoding -> config_diffusion -> diffusion -> decode (written under save_root).
# Returns {'workflow': Actions([...])}.
@neko_cfg
def SDXL_t2i_parts(pretrained_model, parts, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=1024, height=1024,
                   seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    return dict(workflow=Actions(actions=[
        build_model_SDXL(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        # Parts are applied right after the base model is built, before optimize_model().
        load_parts(parts),
        optimize_model(),
        text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
|
215
|
+
|
216
|
+
|
166
217
|
@neko_cfg
|
167
218
|
def SDXL_t2i_lora(pretrained_model, lora_info, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4,
|
168
|
-
width=1024, height=1024, seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/'):
|
219
|
+
width=1024, height=1024, seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
|
169
220
|
return dict(workflow=Actions(actions=[
|
170
221
|
build_model_SDXL(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
|
171
222
|
load_lora(info=lora_info),
|
172
223
|
optimize_model(),
|
173
|
-
text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs),
|
224
|
+
text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
|
174
225
|
config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
|
175
226
|
diffusion(guidance_scale=guidance_scale),
|
176
227
|
decode(save_root=save_root)
|
hcpdiff/models/text_emb_ex.py
CHANGED
@@ -126,6 +126,10 @@ class EmbeddingPTInterpHook(SinglePluginBlock):
|
|
126
126
|
BOS = repeat(inputs_embeds[0,0,:], 'e -> r 1 e', r=self.N_repeats)
|
127
127
|
EOS = repeat(inputs_embeds[0,-1,:], 'e -> r 1 e', r=self.N_repeats)
|
128
128
|
|
129
|
+
# make DDP happy
|
130
|
+
if len(self.emb_train) > 0:
|
131
|
+
BOS = BOS + sum(emb.mean()*0 for emb in self.emb_train if emb.requires_grad)
|
132
|
+
|
129
133
|
replaced_embeds = []
|
130
134
|
for item, rep_idxs, ids_raw in zip(inputs_embeds, rep_idxs_B, self.input_ids):
|
131
135
|
# insert pt to embeddings
|
hcpdiff/trainer_ac.py
CHANGED
@@ -42,13 +42,6 @@ class HCPTrainer(Trainer):
|
|
42
42
|
def pt_trainable(self):
|
43
43
|
return self.cfgs.emb_pt is not None
|
44
44
|
|
45
|
-
def get_loss(self, ds_name, model_pred, inputs):
|
46
|
-
loss = super().get_loss(ds_name, model_pred, inputs)
|
47
|
-
# make DDP happy
|
48
|
-
if len(self.train_pts)>0:
|
49
|
-
loss = loss+0*sum([emb.mean() for emb in self.train_pts.values()])
|
50
|
-
return loss
|
51
|
-
|
52
45
|
def save_model(self, from_raw=False):
|
53
46
|
NekoSaver.save_all(
|
54
47
|
self.model_raw,
|
@@ -0,0 +1,47 @@
|
|
1
|
+
import argparse
|
2
|
+
import warnings
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from rainbowneko.ckpt_manager import NekoPluginSaver
|
6
|
+
from rainbowneko.train.trainer import TrainerDeepspeed
|
7
|
+
from rainbowneko.utils import xformers_available
|
8
|
+
|
9
|
+
from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
|
10
|
+
|
11
|
+
class HCPTrainerDeepspeed(TrainerDeepspeed, HCPTrainer):
    """HCP-Diffusion trainer built on rainbowneko's ``TrainerDeepspeed``.

    ``config_model`` is overridden to apply the diffusion-specific model options. It also
    marks plugin checkpoint savers so they save from the raw model.
    """

    def config_model(self):
        """Apply the options in ``self.cfgs.model`` to ``self.model_wrapper``.

        Steps, in order:
        1. Enable xformers if requested; warn instead when xformers is unavailable.
        2. If a VAE is present, resolve ``vae_dtype`` through ``weight_dtype_map``
           (falling back to ``torch.float32``) and set the model dtypes.
        3. Enable gradient checkpointing if requested.
        4. On the local main process only, set ``plugin_from_raw = True`` on every
           ``NekoPluginSaver`` in ``self.ckpt_saver``.
        """
        model_cfg = self.cfgs.model
        wrapper = self.model_wrapper

        if model_cfg.enable_xformers:
            if not xformers_available:
                warnings.warn("xformers is not available. Make sure it is installed correctly")
            else:
                wrapper.enable_xformers()

        if wrapper.vae is not None:
            requested_vae_dtype = model_cfg.get('vae_dtype', None)
            self.vae_dtype = self.weight_dtype_map.get(requested_vae_dtype, torch.float32)
            wrapper.set_dtype(self.weight_dtype, self.vae_dtype)

        if model_cfg.gradient_checkpointing:
            wrapper.enable_gradient_checkpointing()

        if not self.is_local_main_process:
            return
        plugin_savers = (s for s in self.ckpt_saver.values() if isinstance(s, NekoPluginSaver))
        for plugin_saver in plugin_savers:
            plugin_saver.plugin_from_raw = True
|
30
|
+
|
31
|
+
def hcp_train():
    """Launch DeepSpeed training through ``accelerate launch``.

    ``--launch_cfg`` selects the accelerate config file. Every other command-line
    argument is forwarded unchanged to ``-m hcpdiff.trainer_deepspeed``.

    Raises:
        subprocess.CalledProcessError: if the launched process exits with a non-zero status.
    """
    import subprocess

    launcher = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
    launcher.add_argument('--launch_cfg', type=str, default='cfgs/launcher/deepspeed.yaml')
    known, forwarded = launcher.parse_known_args()

    command = ["accelerate", "launch", '--config_file', known.launch_cfg, "-m", "hcpdiff.trainer_deepspeed"]
    command.extend(forwarded)
    subprocess.run(command, check=True)
|
39
|
+
|
40
|
+
if __name__ == '__main__':
    # Script entry point (this is the module hcp_train launches with `-m hcpdiff.trainer_deepspeed`).
    # --cfg names the config file. Every argument not consumed here is passed to
    # load_config_with_cli as `args_list` (presumably config overrides).
    cli = argparse.ArgumentParser(description='HCP Diffusion Trainer for DeepSpeed')
    cli.add_argument("--cfg", type=str, default=None, required=True)
    cli_args, cfg_overrides = cli.parse_known_args()

    parser, conf = load_config_with_cli(cli_args.cfg, args_list=cfg_overrides)
    HCPTrainerDeepspeed(parser, conf).train()
|
hcpdiff/workflow/diffusion.py
CHANGED
@@ -32,14 +32,15 @@ class SeedAction(BasicAction):
|
|
32
32
|
self.seed = seed
|
33
33
|
self.bs = bs
|
34
34
|
|
35
|
-
def forward(self, device,
|
35
|
+
def forward(self, device, seed=None, **states):
|
36
36
|
bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
|
37
|
-
|
37
|
+
seed = seed or self.seed
|
38
|
+
if seed is None:
|
38
39
|
seeds = [None]*bs
|
39
|
-
elif isinstance(
|
40
|
-
seeds = list(range(
|
40
|
+
elif isinstance(seed, int):
|
41
|
+
seeds = list(range(seed, seed+bs))
|
41
42
|
else:
|
42
|
-
seeds =
|
43
|
+
seeds = seed
|
43
44
|
seeds = [s or random.randint(0, 1 << 30) for s in seeds]
|
44
45
|
|
45
46
|
G = prepare_seed(seeds, device=device)
|
hcpdiff/workflow/text.py
CHANGED
@@ -48,18 +48,9 @@ class TextEncodeAction(BasicAction):
|
|
48
48
|
self.negative_prompt = negative_prompt
|
49
49
|
self.bs = bs
|
50
50
|
|
51
|
-
def forward(self, te_hook, TE, dtype: str, device, amp=None,
|
52
|
-
|
53
|
-
|
54
|
-
negative_prompt_all = negative_prompt_all or self.negative_prompt
|
55
|
-
|
56
|
-
if gen_step is not None:
|
57
|
-
idx = (gen_step*self.bs)%len(prompt_all)
|
58
|
-
prompt = prompt_all[idx:idx+self.bs]
|
59
|
-
negative_prompt = negative_prompt_all[idx:idx+self.bs]
|
60
|
-
else:
|
61
|
-
prompt = prompt_all
|
62
|
-
negative_prompt = negative_prompt_all
|
51
|
+
def forward(self, te_hook, TE, dtype: str, device, amp=None, prompt=None, negative_prompt=None, model_offload=False, **states):
|
52
|
+
prompt = prompt or self.prompt
|
53
|
+
negative_prompt = negative_prompt or self.negative_prompt
|
63
54
|
|
64
55
|
if model_offload:
|
65
56
|
to_cuda(TE)
|
@@ -78,19 +69,9 @@ class TextEncodeAction(BasicAction):
|
|
78
69
|
'pooled_output':pooled_output}
|
79
70
|
|
80
71
|
class AttnMultTextEncodeAction(TextEncodeAction):
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
prompt_all = prompt_all if prompt_all is not None else self.prompt
|
85
|
-
negative_prompt_all = negative_prompt_all if negative_prompt_all is not None else self.negative_prompt
|
86
|
-
|
87
|
-
if gen_step is not None:
|
88
|
-
idx = (gen_step*self.bs)%len(prompt_all)
|
89
|
-
prompt = prompt_all[idx:idx+self.bs]
|
90
|
-
negative_prompt = negative_prompt_all[idx:idx+self.bs]
|
91
|
-
else:
|
92
|
-
prompt = prompt_all
|
93
|
-
negative_prompt = negative_prompt_all
|
72
|
+
def forward(self, te_hook, token_ex, TE, dtype: str, device, amp=None, prompt=None, negative_prompt=None, model_offload=False, **states):
|
73
|
+
prompt = prompt or self.prompt
|
74
|
+
negative_prompt = negative_prompt or self.negative_prompt
|
94
75
|
|
95
76
|
if model_offload:
|
96
77
|
to_cuda(TE)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: hcpdiff
|
3
|
-
Version: 2.1
|
3
|
+
Version: 2.2.1
|
4
4
|
Summary: A universal Diffusion toolbox
|
5
5
|
Home-page: https://github.com/IrisRainbowNeko/HCP-Diffusion
|
6
6
|
Author: Ziyi Dong
|
@@ -17,7 +17,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
17
|
Requires-Python: >=3.8
|
18
18
|
Description-Content-Type: text/markdown
|
19
19
|
License-File: LICENSE
|
20
|
-
Requires-Dist: rainbowneko
|
20
|
+
Requires-Dist: rainbowneko==1.6
|
21
21
|
Requires-Dist: diffusers
|
22
22
|
Requires-Dist: matplotlib
|
23
23
|
Requires-Dist: pyarrow
|
@@ -65,6 +65,8 @@ Compared to the original DreamArtist, it offers better stability, image quality,
|
|
65
65
|
|
66
66
|
## Installation
|
67
67
|
|
68
|
+
Install [PyTorch](https://pytorch.org/) first.
|
69
|
+
|
68
70
|
Install via pip:
|
69
71
|
|
70
72
|
```bash
|
@@ -205,6 +207,18 @@ After parsing, the framework will instantiate the components accordingly. This m
|
|
205
207
|
| CCIP Score | 🚧 In Development |
|
206
208
|
| Corrupt Score | 🚧 In Development |
|
207
209
|
|
210
|
+
---
|
211
|
+
|
212
|
+
### ⚡️ Image Generation
|
213
|
+
|
214
|
+
| Feature | Description / Support |
|
215
|
+
|------------------------------|------------------------------------|
|
216
|
+
| Batch Generation | ✅ Supported |
|
217
|
+
| Generate from Prompt Dataset | ✅ Supported |
|
218
|
+
| Image to Image | ✅ Supported |
|
219
|
+
| Inpaint | ✅ Supported |
|
220
|
+
| Token Weight | ✅ Supported |
|
221
|
+
|
208
222
|
</details>
|
209
223
|
|
210
224
|
---
|
@@ -248,9 +262,13 @@ hcp_run --cfg cfgs/workflow/text2img_cli.py \
|
|
248
262
|
seed=42
|
249
263
|
```
|
250
264
|
|
251
|
-
### Tutorials
|
265
|
+
### 📚 Tutorials
|
252
266
|
|
253
|
-
|
267
|
+
+ 🧠 [Model Training Guide](https://hcpdiff.readthedocs.io/en/latest/user_guides/train.html)
|
268
|
+
+ 🔧 [LoRA Training Tutorial](https://hcpdiff.readthedocs.io/en/latest/tutorial/lora.html)
|
269
|
+
+ 🎨 [Image Generation Guide](https://hcpdiff.readthedocs.io/en/latest/user_guides/workflow.html)
|
270
|
+
+ ⚙️ [Configuration File Explanation](https://hcpdiff.readthedocs.io/en/latest/user_guides/cfg.html)
|
271
|
+
+ 🧩 [Model Format Explanation](https://hcpdiff.readthedocs.io/en/latest/user_guides/model_format.html)
|
254
272
|
|
255
273
|
---
|
256
274
|
|
@@ -1,27 +1,28 @@
|
|
1
1
|
hcpdiff/__init__.py,sha256=dwNwrEgvG4g60fGMG6b50K3q3AWD1XCfzlIgbxkSUpE,177
|
2
2
|
hcpdiff/train_colo.py,sha256=EsuNSzLBvGTZWU_LEk0JpP-F5eNW0lwkawIRAX38jmE,9250
|
3
|
-
hcpdiff/
|
4
|
-
hcpdiff/trainer_ac.py,sha256=6KAzo54in7ZRHud_rHjJdwRRZ4uWtc0B4SxVCxgcrmM,2990
|
3
|
+
hcpdiff/trainer_ac.py,sha256=scH3FU0onCQtwLiy0-pcrhuowTZob3fLQqRP52iwY0c,2717
|
5
4
|
hcpdiff/trainer_ac_single.py,sha256=0PIC5EScqcxp49EaeIWq4KS5K_09OZfKajqbFu-hUb8,1108
|
6
|
-
hcpdiff/
|
5
|
+
hcpdiff/trainer_deepspeed.py,sha256=7lGsiAstWuIlmhRMwWTcJCkoxzUaakVxBngKDnJdSJk,1947
|
6
|
+
hcpdiff/ckpt_manager/__init__.py,sha256=Mn_5KOC4xbf2GcN6OXg_XdbF5wO9zWeER_1ZO_prKAI,256
|
7
7
|
hcpdiff/ckpt_manager/ckpt.py,sha256=Pa3uXQbCi2T99mpV5fYddQ-OGHcpk8r1ll-0lmP_WXk,965
|
8
8
|
hcpdiff/ckpt_manager/loader.py,sha256=Ch1xsZmseq4nyPhpox9-nebN-dZB4k0rqBEHos-ZLso,3245
|
9
9
|
hcpdiff/ckpt_manager/format/__init__.py,sha256=a3cdKkOTDgdVbDQwSC4mlxOigjX2hBvRb5_X7E3TQWs,237
|
10
10
|
hcpdiff/ckpt_manager/format/diffusers.py,sha256=T81WN95Nj1il9DfQp9iioVn0uqFEWOlmdIYs2beNOFU,3769
|
11
11
|
hcpdiff/ckpt_manager/format/emb.py,sha256=FrqfTfJ8H7f0Zw17NTWCP2AJtpsJI5oXR5IAd4NekhU,680
|
12
|
-
hcpdiff/ckpt_manager/format/lora_webui.py,sha256=
|
12
|
+
hcpdiff/ckpt_manager/format/lora_webui.py,sha256=4y_T9RdmFTxWzsXd8guNjCiukmyILa5j4MPrhVIL4Qk,10017
|
13
13
|
hcpdiff/ckpt_manager/format/sd_single.py,sha256=LpCAL_7nAVooCHTFznVVsNMku1G3C77NBORxxr8GDtQ,2328
|
14
|
-
hcpdiff/data/__init__.py,sha256
|
14
|
+
hcpdiff/data/__init__.py,sha256=ZFKtanOoMo3G3eKUJPhysnHXnr8BNARERkcMB6B897U,292
|
15
15
|
hcpdiff/data/dataset.py,sha256=1k4GldW13eVyqK_9hrQniqr3_XYAapnWF7iXl_1GXGg,877
|
16
16
|
hcpdiff/data/cache/__init__.py,sha256=ToCmokYH6DghlSwm7HJFirPRIWJ0LkgzqVOYlgoAkQw,25
|
17
17
|
hcpdiff/data/cache/vae.py,sha256=gB89zs4CdNlvukDXhVYU9QZrY6VTFUWfzjeF2psNQ50,4070
|
18
|
-
hcpdiff/data/handler/__init__.py,sha256=
|
18
|
+
hcpdiff/data/handler/__init__.py,sha256=G8ZTQF91ilkTRmUoWdmAissTSZ7fvNUpm_hBYmXKTtk,258
|
19
19
|
hcpdiff/data/handler/controlnet.py,sha256=bRDMD9BP8-VaG5VrxzvcFKfkqeTbChNfrJSZ3vXbQgY,658
|
20
|
-
hcpdiff/data/handler/diffusion.py,sha256=
|
20
|
+
hcpdiff/data/handler/diffusion.py,sha256=S-_7o5Z1tm6LmRZVZs21rbJC7iUoq0tHOsSjKK6geVk,4156
|
21
21
|
hcpdiff/data/handler/text.py,sha256=gOzqB2oEkEUbiuy0kZWduo0c-w4Buu60KI6q6Nyl3aM,4208
|
22
|
-
hcpdiff/data/source/__init__.py,sha256=
|
22
|
+
hcpdiff/data/source/__init__.py,sha256=265M8qfWNUE4SKX0pdXhLYjCnCuae5YE4bfZpO-ydXc,187
|
23
23
|
hcpdiff/data/source/folder_class.py,sha256=bs4qPMTzwcnT6ZFlT3tpi9sclsRF9a2MBA1pQD-9EYs,961
|
24
|
-
hcpdiff/data/source/
|
24
|
+
hcpdiff/data/source/text.py,sha256=VgI5Ouq986Yy1jwD2fZ9iBlsRciPCeARZmOPEZIcaQY,1468
|
25
|
+
hcpdiff/data/source/text2img.py,sha256=acYdolQhZUEpkd7tUAdNkCTVnPc1SMJOVTmGqFt9ZpE,1813
|
25
26
|
hcpdiff/data/source/text2img_cond.py,sha256=yj1KpARA2rkjENutnnzC4uDkcU2Rye21FL2VdC25Hac,585
|
26
27
|
hcpdiff/diffusion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
27
28
|
hcpdiff/diffusion/noise/__init__.py,sha256=seBpOtd0YsU53PqMn7Nyl_RtwoC-ONEIOX7v2XLGpZQ,93
|
@@ -38,10 +39,10 @@ hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py,sha256=2PMIpg2K6CVoxew1y1pIqvC
|
|
38
39
|
hcpdiff/diffusion/sampler/sigma_scheduler/edm.py,sha256=fOPB3lgnS9uVo4oW26Fur_nc8X_wQ6mmUcbkKhnoQjs,1900
|
39
40
|
hcpdiff/easy/__init__.py,sha256=-emoyCOZlLCu3KNMI8L4qapUEtEYFSoiGU6-rKv1at4,149
|
40
41
|
hcpdiff/easy/sampler.py,sha256=dQSBkeGh71O0DAmZLhTHTbk1bY7XzyUCeW1oJO14A4I,1250
|
41
|
-
hcpdiff/easy/cfg/__init__.py,sha256=
|
42
|
-
hcpdiff/easy/cfg/sd15_train.py,sha256=
|
43
|
-
hcpdiff/easy/cfg/sdxl_train.py,sha256=
|
44
|
-
hcpdiff/easy/cfg/t2i.py,sha256=
|
42
|
+
hcpdiff/easy/cfg/__init__.py,sha256=SxHMWG6T2CXhX3dP0xizSMd9vFWPaZQDc4Gj4CF__yQ,253
|
43
|
+
hcpdiff/easy/cfg/sd15_train.py,sha256=kKdESVqAxNlBhhz12PvwrpHJBea80OUFzDDMHwiulVs,6710
|
44
|
+
hcpdiff/easy/cfg/sdxl_train.py,sha256=FUWE_hRJdQc9Qd9J6730jAyK0H4EIKS7-3BSufCItXU,4275
|
45
|
+
hcpdiff/easy/cfg/t2i.py,sha256=SnjFjZAKd9orjJr3RW5_N2_EIlW2Ree7JMvdNUAR9gc,9507
|
45
46
|
hcpdiff/easy/model/__init__.py,sha256=CA-7r3R2Jgweekk1XNByFYttLolbWyUV2bCnXygcD8w,133
|
46
47
|
hcpdiff/easy/model/cnet.py,sha256=m0NTH9V1kLzb5GybwBrSNT0KvTcRpPfGkzUeMz9jZZQ,1084
|
47
48
|
hcpdiff/easy/model/loader.py,sha256=Tdx-lhQEYf2NYjVM1A5B8x6ZZpJKcXUkFIPIbr7h7XM,3456
|
@@ -61,7 +62,7 @@ hcpdiff/models/lora_base.py,sha256=LGwBD9KP6qf4pgTx24i5-JLo4rDBQ6jFfterQKBjTbE,6
|
|
61
62
|
hcpdiff/models/lora_base_patch.py,sha256=WW3CULnROTxKXyynJiqirhHYCKN5JtxLhVpT5b7AUQg,6532
|
62
63
|
hcpdiff/models/lora_layers.py,sha256=O9W_Ue71lHj7Y_GbpioF4Hc3h2-z_zOqck93VYUra6s,7777
|
63
64
|
hcpdiff/models/lora_layers_patch.py,sha256=GYFYsJD2VSLZfdnLma9CmQEHz09HROFJcc4wc_gs9f0,8198
|
64
|
-
hcpdiff/models/text_emb_ex.py,sha256=
|
65
|
+
hcpdiff/models/text_emb_ex.py,sha256=O0XZqid01OrB0dHY7hCiBvdU2026SvZ38yfQaF2TWrs,8018
|
65
66
|
hcpdiff/models/textencoder_ex.py,sha256=JrTQ30Avx8tPbdr-Q6K5BvEWCEdsu8Z7eSOzMqpUuzg,8270
|
66
67
|
hcpdiff/models/tokenizer_ex.py,sha256=zKUn4BY7b3yXwK9PWkZtQKJPyKYwUc07E-hwB9NQybs,2446
|
67
68
|
hcpdiff/models/compose/__init__.py,sha256=lTNFTGg5csqvUuys22RqgjmWlk_7Okw6ZTsnTi1pqCg,217
|
@@ -95,20 +96,20 @@ hcpdiff/utils/net_utils.py,sha256=gdwLYDNKV2t3SP0jBIO3d0HtY6E7jRaf_rmPT8gKZZE,97
|
|
95
96
|
hcpdiff/utils/pipe_hook.py,sha256=-UDX3FtZGl-bxSk13gdbPXc1OvtbCcpk_fvKxLQo3Ag,31987
|
96
97
|
hcpdiff/utils/utils.py,sha256=hZnZP1IETgVpScxES0yIuRfc34TnzvAqmgOTK_56ssw,4976
|
97
98
|
hcpdiff/workflow/__init__.py,sha256=t7Zyc0XFORdNvcwHp9AsCtEkhJ3l7Hm41ugngIL0Sag,867
|
98
|
-
hcpdiff/workflow/diffusion.py,sha256=
|
99
|
+
hcpdiff/workflow/diffusion.py,sha256=yzhqKA3019OPu1RKggrLoytMgm919qf6j9S85PYOwjQ,8644
|
99
100
|
hcpdiff/workflow/fast.py,sha256=kZt7bKrvpFInSn7GzbkTkpoCSM0Z6IbDjgaDvcbFYf8,1024
|
100
101
|
hcpdiff/workflow/flow.py,sha256=FFbFFOAXT4c31L5bHBEB_qeVGuBQDLYhq8kTD1chGNo,2548
|
101
102
|
hcpdiff/workflow/io.py,sha256=aTrMR3s44apVJpnSyvZIabW2Op0tslk_Z9JFJl5svm0,2635
|
102
103
|
hcpdiff/workflow/model.py,sha256=1gj5yOTefYTnGXVR6JPAfxIwuB69YwN6E-BontRcuyQ,2913
|
103
|
-
hcpdiff/workflow/text.py,sha256=
|
104
|
+
hcpdiff/workflow/text.py,sha256=Z__SJHZyuaKyzkYJ6rbiAzOGRiYcCjwCGeqfpP1Jo7o,4336
|
104
105
|
hcpdiff/workflow/utils.py,sha256=xojaMG4lHsymslc8df5uiVXmmBVWpn_Phqka8qzJEWw,2226
|
105
106
|
hcpdiff/workflow/vae.py,sha256=cingDPkIOc4qGpOwwhXJK4EQbGoIxO583pm6gGov5t8,3118
|
106
107
|
hcpdiff/workflow/daam/__init__.py,sha256=ySIDaxloN-D3qM7OuVaG1BR3D-CibDoXYpoTgw0zUhU,59
|
107
108
|
hcpdiff/workflow/daam/act.py,sha256=tHbsFWTYYU4bvcZOo1Bpi_z6ofpJatRYccl4vvf8wIA,2756
|
108
109
|
hcpdiff/workflow/daam/hook.py,sha256=z9f9mBjKW21xuUZ-iQxQ0HbWOBXtZrisFB0VNMq6d0U,4383
|
109
|
-
hcpdiff-2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
110
|
-
hcpdiff-2.1.dist-info/METADATA,sha256=
|
111
|
-
hcpdiff-2.1.dist-info/WHEEL,sha256=
|
112
|
-
hcpdiff-2.1.dist-info/entry_points.txt,sha256=
|
113
|
-
hcpdiff-2.1.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
|
114
|
-
hcpdiff-2.1.dist-info/RECORD,,
|
110
|
+
hcpdiff-2.2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
111
|
+
hcpdiff-2.2.1.dist-info/METADATA,sha256=f96Tc90K5WTBbJ35wWJw60G2JR46eGpUvQSaPIysVDg,10323
|
112
|
+
hcpdiff-2.2.1.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
|
113
|
+
hcpdiff-2.2.1.dist-info/entry_points.txt,sha256=_4VRsEsEWOhHfzBDu9bx8Wh_S8Wi4ZTHpI0n6rU0J-I,258
|
114
|
+
hcpdiff-2.2.1.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
|
115
|
+
hcpdiff-2.2.1.dist-info/RECORD,,
|
hcpdiff/train_deepspeed.py
DELETED
@@ -1,69 +0,0 @@
|
|
1
|
-
import argparse
|
2
|
-
import os
|
3
|
-
import sys
|
4
|
-
import warnings
|
5
|
-
from functools import partial
|
6
|
-
|
7
|
-
import torch
|
8
|
-
|
9
|
-
from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
|
10
|
-
from hcpdiff.train_ac_old import Trainer, load_config_with_cli
|
11
|
-
from hcpdiff.utils.net_utils import get_scheduler
|
12
|
-
|
13
|
-
class TrainerDeepSpeed(Trainer):
|
14
|
-
|
15
|
-
def build_ckpt_manager(self):
|
16
|
-
self.ckpt_manager = self.ckpt_manager_map[self.cfgs.ckpt_type](plugin_from_raw=True)
|
17
|
-
if self.is_local_main_process:
|
18
|
-
self.ckpt_manager.set_save_dir(os.path.join(self.exp_dir, 'ckpts'), emb_dir=self.cfgs.tokenizer_pt.emb_dir)
|
19
|
-
|
20
|
-
@property
|
21
|
-
def unet_raw(self):
|
22
|
-
return self.accelerator.unwrap_model(self.TE_unet).unet if self.train_TE else self.accelerator.unwrap_model(self.TE_unet.unet)
|
23
|
-
|
24
|
-
@property
|
25
|
-
def TE_raw(self):
|
26
|
-
return self.accelerator.unwrap_model(self.TE_unet).TE if self.train_TE else self.TE_unet.TE
|
27
|
-
|
28
|
-
def get_loss(self, model_pred, target, timesteps, att_mask):
|
29
|
-
if att_mask is None:
|
30
|
-
att_mask = 1.0
|
31
|
-
if getattr(self.criterion, 'need_timesteps', False):
|
32
|
-
loss = (self.criterion(model_pred.float(), target.float(), timesteps)*att_mask).mean()
|
33
|
-
else:
|
34
|
-
loss = (self.criterion(model_pred.float(), target.float())*att_mask).mean()
|
35
|
-
return loss
|
36
|
-
|
37
|
-
def build_optimizer_scheduler(self):
|
38
|
-
# set optimizer
|
39
|
-
parameters, parameters_pt = self.get_param_group_train()
|
40
|
-
|
41
|
-
if len(parameters_pt)>0: # do prompt-tuning
|
42
|
-
cfg_opt_pt = self.cfgs.train.optimizer_pt
|
43
|
-
# if self.cfgs.train.scale_lr_pt:
|
44
|
-
# self.scale_lr(parameters_pt)
|
45
|
-
assert isinstance(cfg_opt_pt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
|
46
|
-
weight_decay = cfg_opt_pt.keywords.get('weight_decay', None)
|
47
|
-
if weight_decay is not None:
|
48
|
-
for param in parameters_pt:
|
49
|
-
param['weight_decay'] = weight_decay
|
50
|
-
|
51
|
-
parameters += parameters_pt
|
52
|
-
warnings.warn('deepspeed dose not support multi optimizer and lr_scheduler. optimizer_pt and scheduler_pt will not work.')
|
53
|
-
|
54
|
-
if len(parameters)>0:
|
55
|
-
cfg_opt = self.cfgs.train.optimizer
|
56
|
-
if self.cfgs.train.scale_lr:
|
57
|
-
self.scale_lr(parameters)
|
58
|
-
assert isinstance(cfg_opt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
|
59
|
-
self.optimizer = cfg_opt(params=parameters)
|
60
|
-
self.lr_scheduler = get_scheduler(self.cfgs.train.scheduler, self.optimizer)
|
61
|
-
|
62
|
-
if __name__ == '__main__':
|
63
|
-
parser = argparse.ArgumentParser(description='Stable Diffusion Training')
|
64
|
-
parser.add_argument('--cfg', type=str, default='cfg/train/demo.yaml')
|
65
|
-
args, cfg_args = parser.parse_known_args()
|
66
|
-
|
67
|
-
conf = load_config_with_cli(args.cfg, args_list=cfg_args) # skip --cfg
|
68
|
-
trainer = TrainerDeepSpeed(conf)
|
69
|
-
trainer.train()
|
File without changes
|
File without changes
|