hcpdiff-2.1-py3-none-any.whl → hcpdiff-2.2.1-py3-none-any.whl

hcpdiff/ckpt_manager/__init__.py CHANGED
@@ -1,4 +1,4 @@
 from .format import EmbFormat, DiffusersSD15Format, DiffusersModelFormat, DiffusersSDXLFormat, DiffusersPixArtFormat, OfficialSDXLFormat, \
-    OfficialSD15Format
+    OfficialSD15Format, LoraWebuiFormat
 from .ckpt import EmbSaver, easy_emb_saver
 from .loader import HCPLoraLoader
hcpdiff/ckpt_manager/format/lora_webui.py CHANGED
@@ -2,7 +2,7 @@ import math
 import re
 from typing import List, Dict, Any
 
-from rainbowneko.ckpt_manager.format import CkptFormat
+from rainbowneko.ckpt_manager.format import CkptFormat, SafeTensorFormat
 from torch.serialization import FILE_LIKE
 
 class LoraConverter:
@@ -36,7 +36,12 @@ class LoraConverter:
         if auto_scale_alpha:
             sd_unet = self.alpha_scale_from_webui(sd_unet)
             sd_TE = self.alpha_scale_from_webui(sd_TE)
-        return {'plugin':sd_TE}, {'plugin':sd_unet}
+
+        sd = {
+            **{f'denoiser.{k}':v for k,v in sd_unet.items()},
+            **{f'TE.{k}':v for k,v in sd_TE.items()},
+        }
+        return {'base': sd}
 
     def convert_to_webui(self, sd_unet, sd_TE, auto_scale_alpha=False, sdxl=False):
         sd_unet = self.convert_to_webui_(sd_unet, prefix=self.prefix_unet)
@@ -207,9 +212,12 @@ class LoraConverter:
         return state
 
 class LoraWebuiFormat(CkptFormat):
-    def __init__(self, format, auto_scale_alpha=False):
+    def __init__(self, format=None, auto_scale_alpha=False):
         self.converter = LoraConverter()
         self.auto_scale_alpha = auto_scale_alpha
+
+        if format is None:
+            format = SafeTensorFormat()
         self.format = format
 
     def save_ckpt(self, sd_model: Dict[str, Any], save_f: FILE_LIKE):
@@ -240,5 +248,5 @@ class LoraWebuiFormat(CkptFormat):
             sdxl = True
             break
 
-        sd_TE, sd_unet = self.converter.convert_from_webui(sd_webui, auto_scale_alpha=self.auto_scale_alpha, sdxl=sdxl)
-        return sd_TE, sd_unet
+        sd_all = self.converter.convert_from_webui(sd_webui, auto_scale_alpha=self.auto_scale_alpha, sdxl=sdxl)
+        return sd_all
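The hunks above change `convert_from_webui` from returning two separate plugin dicts to returning a single state dict whose keys carry `denoiser.`/`TE.` prefixes, and give `LoraWebuiFormat` a safetensors default. A minimal usage sketch, assuming the module path implied by the file layout; the `lora.safetensors` input path is illustrative:

```python
# Hedged sketch: convert a webui LoRA state dict with the 2.2.1 converter.
from safetensors.torch import load_file
from hcpdiff.ckpt_manager.format.lora_webui import LoraConverter

sd_webui = load_file('lora.safetensors')  # illustrative path
converter = LoraConverter()
# 2.1 returned ({'plugin': sd_TE}, {'plugin': sd_unet}); 2.2.1 returns one dict
# under 'base' with keys prefixed 'denoiser.' (UNet) and 'TE.' (text encoder).
sd_all = converter.convert_from_webui(sd_webui, auto_scale_alpha=True, sdxl=False)
print(list(sd_all['base'])[:5])
```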
hcpdiff/data/__init__.py CHANGED
@@ -1,4 +1,4 @@
 from .dataset import TextImagePairDataset
-from .source import Text2ImageSource, Text2ImageLossMapSource, Text2ImageCondSource, T2IFolderClassSource
-from .handler import StableDiffusionHandler, LossMapHandler, DiffusionImageHandler
+from .source import Text2ImageSource, Text2ImageLossMapSource, Text2ImageCondSource, T2IFolderClassSource, TextSource
+from .handler import StableDiffusionHandler, LossMapHandler, DiffusionImageHandler, DiffusionTextHandler
 from .cache import VaeCache
hcpdiff/data/handler/__init__.py CHANGED
@@ -1,3 +1,3 @@
-from .diffusion import StableDiffusionHandler, DiffusionImageHandler, LossMapHandler
+from .diffusion import StableDiffusionHandler, DiffusionImageHandler, LossMapHandler, DiffusionTextHandler
 from .text import TokenizeHandler, TagEraseHandler, TagDropoutHandler, TagShuffleHandler, TemplateFillHandler
 from .controlnet import ControlNetHandler
hcpdiff/data/handler/diffusion.py CHANGED
@@ -49,14 +49,11 @@ class DiffusionImageHandler(DataHandler):
         else:
             return self.handlers(dict(image=image, image_size=image_size))
 
-class StableDiffusionHandler(DataHandler):
-    def __init__(self, bucket, encoder_attention_mask=False, key_map_in=('image -> image', 'image_size -> image_size', 'prompt -> prompt'),
-                 key_map_out=('image -> image', 'coord -> coord', 'prompt -> prompt'),
-                 erase=0.15, dropout=0.0, shuffle=0.0, word_names={}, tokenize=True):
+class DiffusionTextHandler(DataHandler):
+    def __init__(self, encoder_attention_mask=False, erase=0.0, dropout=0.0, shuffle=0.0, word_names={}, tokenize=True,
+                 key_map_in=('prompt -> prompt', ), key_map_out=('prompt -> prompt', )):
         super().__init__(key_map_in, key_map_out)
 
-        self.image_handlers = DiffusionImageHandler(bucket)
-
         text_handlers = {}
         if dropout>0:
             text_handlers['dropout'] = TagDropoutHandler(p=dropout)
@@ -67,7 +64,20 @@ class StableDiffusionHandler(DataHandler):
             text_handlers['fill'] = TemplateFillHandler(word_names)
         if tokenize:
             text_handlers['tokenize'] = TokenizeHandler(encoder_attention_mask)
-        self.text_handlers = HandlerChain(**text_handlers)
+        self.handlers = HandlerChain(**text_handlers)
+
+    def handle(self, prompt: Union[str, Dict[str, str]]):
+        return self.handlers(dict(prompt=prompt))
+
+class StableDiffusionHandler(DataHandler):
+    def __init__(self, bucket, encoder_attention_mask=False, key_map_in=('image -> image', 'image_size -> image_size', 'prompt -> prompt'),
+                 key_map_out=('image -> image', 'coord -> coord', 'prompt -> prompt'),
+                 erase=0.0, dropout=0.0, shuffle=0.0, word_names={}, tokenize=True):
+        super().__init__(key_map_in, key_map_out)
+
+        self.image_handlers = DiffusionImageHandler(bucket)
+        self.text_handlers = DiffusionTextHandler(encoder_attention_mask=encoder_attention_mask, erase=erase, dropout=dropout, shuffle=shuffle,
+                                                  word_names=word_names, tokenize=tokenize)
 
     def handle(self, image: Image.Image, image_size: np.ndarray[int], prompt: str):
         return dict(**self.image_handlers(dict(image=image, image_size=image_size)), **self.text_handlers(dict(prompt=prompt)))
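This refactor splits the old monolithic `StableDiffusionHandler` into an image part and a reusable text-only `DiffusionTextHandler` (note the `erase` default also drops from 0.15 to 0.0). A hedged sketch of the text handler on its own, using only the constructor and `handle` signatures shown above; the caption is illustrative, and `tokenize=False` keeps a tokenizer out of the chain:

```python
# Hedged sketch: run prompt augmentation without any image processing.
from hcpdiff.data import DiffusionTextHandler

text_handler = DiffusionTextHandler(erase=0.1, shuffle=0.2, tokenize=False)
out = text_handler.handle('a photo of a cat, high quality')  # illustrative caption
print(out)  # dict carrying the augmented prompt
```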
hcpdiff/data/source/__init__.py CHANGED
@@ -1,3 +1,4 @@
 from .text2img import Text2ImageSource, Text2ImageLossMapSource
 from .text2img_cond import Text2ImageCondSource
-from .folder_class import T2IFolderClassSource
+from .folder_class import T2IFolderClassSource
+from .text import TextSource
hcpdiff/data/source/text.py ADDED
@@ -0,0 +1,40 @@
+from rainbowneko.data import UnLabelSource, DataSource
+from rainbowneko.data.label_loader import BaseLabelLoader, auto_label_loader
+from typing import Union, Dict, Any
+import random
+
+class TextSource(DataSource):
+    def __init__(self, label_file, prompt_template=None, repeat=1, **kwargs):
+        super().__init__(repeat=repeat)
+        self.label_file = label_file
+        self.label_dict = self._load_label_data(label_file)
+        self.img_ids = self._load_img_ids(self.label_dict)
+        self.prompt_template = self.load_template(prompt_template)
+
+    def _load_img_ids(self, label_dict):
+        return list(label_dict.keys()) * self.repeat
+
+    def _load_label_data(self, label_file: Union[str, BaseLabelLoader]):
+        if label_file is None:
+            return {}
+        elif isinstance(label_file, str):
+            return auto_label_loader(label_file).load()
+        else:
+            return label_file.load()
+
+    def load_template(self, template_file):
+        if template_file is None:
+            return ['{caption}']
+        else:
+            with open(template_file, 'r', encoding='utf-8') as f:
+                return f.read().strip().split('\n')
+
+    def __getitem__(self, index) -> Dict[str, Any]:
+        img_name = self.img_ids[index]
+        return {
+            'id':img_name,
+            'prompt':{
+                'template':random.choice(self.prompt_template),
+                'caption':self.label_dict[img_name],
+            }
+        }
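The new `TextSource` yields prompt-only items for caption datasets, which is what the "Generate from Prompt Dataset" feature below builds on. A hedged usage sketch; the `captions.json` file (an id-to-caption mapping) is illustrative, and `auto_label_loader` is the rainbowneko helper imported in the file above:

```python
# Hedged sketch: iterate a prompt-only dataset built from a caption file.
from hcpdiff.data import TextSource

source = TextSource('captions.json', prompt_template=None, repeat=2)  # illustrative path
item = source[0]
print(item['id'], item['prompt'])  # {'template': '{caption}', 'caption': ...}
```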
hcpdiff/data/source/text2img.py CHANGED
@@ -25,7 +25,7 @@ class Text2ImageSource(ImageLabelSource):
 
     def __getitem__(self, index) -> Dict[str, Any]:
         img_name = self.img_ids[index]
-        path = os.path.join(self.img_root, img_name)
+        path = self.img_root / img_name
 
         return {
             'id':img_name,
hcpdiff/easy/cfg/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .sd15_train import SD15_lora_train, cfg_data_SD_ARB, cfg_data_SD_resize_crop, SD15_finetuning
 from .sdxl_train import SDXL_lora_train, SDXL_finetuning
-from .t2i import SD15_t2i, SDXL_t2i, SDXL_t2i_lora, SD15_t2i_lora
+from .t2i import SD15_t2i, SDXL_t2i, SDXL_t2i_lora, SD15_t2i_lora, SDXL_t2i_parts, SD15_t2i_parts
hcpdiff/easy/cfg/sd15_train.py CHANGED
@@ -1,9 +1,10 @@
 import torch
-from rainbowneko.ckpt_manager import ckpt_saver, LAYERS_TRAINABLE, plugin_saver
+from rainbowneko.ckpt_manager import ckpt_saver, LAYERS_TRAINABLE, NekoPluginSaver, SafeTensorFormat
 from rainbowneko.data import RatioBucket, FixedBucket
 from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
 from rainbowneko.utils import ConstantLR, Path_Like
 
+from hcpdiff.ckpt_manager import LoraWebuiFormat
 from hcpdiff.data import TextImagePairDataset, Text2ImageSource, StableDiffusionHandler
 from hcpdiff.data import VaeCache
 from hcpdiff.easy import SD15_auto_loader
@@ -46,7 +47,7 @@ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
 
         optimizer=optimizer,
 
-        scheduler=ConstantLR(
+        lr_scheduler=ConstantLR(
            _partial_=True,
            warmup_steps=warmup_steps,
        ),
@@ -69,7 +70,7 @@ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
 @neko_cfg
 def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
                     clip_skip: int = 0, with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0,
-                    name: str = 'SD15'):
+                    name: str = 'SD15', save_webui_format=False):
     with disable_neko_cfg:
         if alpha is None:
             alpha = rank
@@ -95,6 +96,11 @@ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
     else:
         optimizer = torch.optim.AdamW(_partial_=True, betas=(0.9, 0.99))
 
+    if save_webui_format:
+        lora_format = LoraWebuiFormat()
+    else:
+        lora_format = SafeTensorFormat()
+
     from cfgs.train.py.examples import SD_FT
 
     return dict(
@@ -114,8 +120,8 @@ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
 
         ckpt_saver=dict(
            _replace_ = True,
-           lora_unet=plugin_saver(
-               ckpt_type='safetensors',
+           lora_unet=NekoPluginSaver(
+               format=lora_format,
                target_plugin='lora1',
            )
        ),
@@ -126,7 +132,7 @@ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
 
        optimizer=optimizer,
 
-       scheduler=ConstantLR(
+       lr_scheduler=ConstantLR(
           _partial_=True,
          warmup_steps=warmup_steps,
      ),
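With the new `save_webui_format` switch, the SD1.5 LoRA preset saves through `LoraWebuiFormat` instead of plain safetensors. A hedged sketch of a py-style config using it; the base-model path and dataset are illustrative, and `SD15_lora_train` is a `@neko_cfg` config function meant to be resolved by hcpdiff's config parser rather than executed as ordinary Python:

```python
# Hedged config sketch: train an SD1.5 LoRA and save webui-compatible checkpoints.
from hcpdiff.easy.cfg import SD15_lora_train

def make_cfg(dataset):  # dataset built elsewhere, e.g. with cfg_data_SD_ARB
    return SD15_lora_train(
        base_model='ckpts/sd15',   # illustrative path
        train_steps=1000,
        dataset=dataset,
        rank=8,
        save_webui_format=True,    # routes saving through LoraWebuiFormat
    )
```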
hcpdiff/easy/cfg/sdxl_train.py CHANGED
@@ -1,11 +1,12 @@
 import torch
-from rainbowneko.ckpt_manager import ckpt_saver, plugin_saver, LAYERS_TRAINABLE
+from rainbowneko.ckpt_manager import ckpt_saver, NekoPluginSaver, LAYERS_TRAINABLE, SafeTensorFormat
 from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
 from rainbowneko.utils import ConstantLR
 
 from hcpdiff.easy import SDXL_auto_loader
 from hcpdiff.models import SDXLWrapper
 from hcpdiff.models.lora_layers_patch import LoraLayer
+from hcpdiff.ckpt_manager import LoraWebuiFormat
 
 @neko_cfg
 def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5,
@@ -43,7 +44,7 @@ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
 
        optimizer=optimizer,
 
-       scheduler=ConstantLR(
+       lr_scheduler=ConstantLR(
           _partial_=True,
          warmup_steps=warmup_steps,
      ),
@@ -64,7 +65,8 @@ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
 
 @neko_cfg
 def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
-                    with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SD15'):
+                    with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SDXL',
+                    save_webui_format=False):
     with disable_neko_cfg:
         if alpha is None:
             alpha = rank
@@ -90,6 +92,11 @@ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
     else:
         optimizer = torch.optim.AdamW(_partial_=True, betas=(0.9, 0.99))
 
+    if save_webui_format:
+        lora_format = LoraWebuiFormat()
+    else:
+        lora_format = SafeTensorFormat()
+
     from cfgs.train.py.examples import SD_FT
 
     return dict(
@@ -109,8 +116,8 @@ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
 
        ckpt_saver=dict(
           _replace_ = True,
-          lora_unet=plugin_saver(
-              ckpt_type='safetensors',
+          lora_unet=NekoPluginSaver(
+              format=lora_format,
              target_plugin='lora1',
          )
      ),
@@ -121,7 +128,7 @@ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
 
        optimizer=optimizer,
 
-       scheduler=ConstantLR(
+       lr_scheduler=ConstantLR(
          _partial_=True,
         warmup_steps=warmup_steps,
     ),
hcpdiff/easy/cfg/t2i.py CHANGED
@@ -1,5 +1,6 @@
 import torch
 from rainbowneko.infer.workflow import (Actions, PrepareAction, LoopAction, LoadModelAction)
+from rainbowneko.ckpt_manager import NekoModelLoader
 from rainbowneko.parser import neko_cfg, disable_neko_cfg
 from typing import Union, List
 
@@ -25,6 +26,29 @@ def build_model(pretrained_model='ckpts/any5', noise_sampler=Diffusers_SD.dpmpp_
         ),
     ])
 
+@neko_cfg
+def load_parts(info: List[str]) -> Actions:
+    acts = []
+    for i, path in enumerate(info):
+        part_unet = LoadModelAction(cfg={
+            f'part_unet_{i}':NekoModelLoader(
+                path=path,
+                state_prefix='denoiser.'
+            )
+        }, key_map_in=('denoiser -> model', 'in_preview -> in_preview'))
+        part_TE = LoadModelAction(cfg={
+            f'part_TE_{i}':NekoModelLoader(
+                path=path,
+                state_prefix='TE.',
+            )
+        }, key_map_in=('TE -> model', 'in_preview -> in_preview'))
+
+        with disable_neko_cfg:
+            acts.append(part_unet)
+            acts.append(part_TE)
+
+    return Actions(acts)
+
 @neko_cfg
 def load_lora(info: List[List]) -> Actions:
     lora_acts = []
@@ -37,7 +61,7 @@ def load_lora(info: List[List]) -> Actions:
             )
         }, key_map_in=('denoiser -> model', 'in_preview -> in_preview'))
         lora_TE = LoadModelAction(cfg={
-            f'lora_unet_{i}':HCPLoraLoader(
+            f'lora_TE_{i}':HCPLoraLoader(
                 path=item[0],
                 state_prefix='TE.',
                 alpha=item[1],
@@ -59,9 +83,9 @@ def optimize_model() -> Actions:
     ])
 
 @neko_cfg
-def text(prompt, negative_prompt=negative_prompt, bs=4) -> Actions:
+def text(prompt, negative_prompt=negative_prompt, bs=4, N_repeats=1, layer_skip=1) -> Actions:
     return Actions([
-        TextHookAction(N_repeats=1, layer_skip=1),
+        TextHookAction(N_repeats=N_repeats, layer_skip=layer_skip),
         AttnMultTextEncodeAction(
             prompt=prompt,
             negative_prompt=negative_prompt,
@@ -84,9 +108,9 @@ def build_model_SDXL(pretrained_model='ckpts/any5', noise_sampler=Diffusers_SD.d
     ])
 
 @neko_cfg
-def text_SDXL(prompt, negative_prompt=negative_prompt, bs=4) -> Actions:
+def text_SDXL(prompt, negative_prompt=negative_prompt, bs=4, N_repeats=1, layer_skip=1) -> Actions:
     return Actions([
-        TextHookAction(N_repeats=1, layer_skip=1, TE_final_norm=False),
+        TextHookAction(N_repeats=N_repeats, layer_skip=layer_skip, TE_final_norm=False),
         AttnMultTextEncodeAction(
             prompt=prompt,
             negative_prompt=negative_prompt,
@@ -128,11 +152,24 @@ def resize(width=1024, height=1024):
 
 @neko_cfg
 def SD15_t2i(pretrained_model, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=512, height=512,
-             seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/'):
+             seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
+    return dict(workflow=Actions(actions=[
+        build_model(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
+        optimize_model(),
+        text(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
+        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
+        diffusion(guidance_scale=guidance_scale),
+        decode(save_root=save_root)
+    ]))
+
+@neko_cfg
+def SD15_t2i_parts(pretrained_model, parts, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=512, height=512,
+                   seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
     return dict(workflow=Actions(actions=[
         build_model(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
+        load_parts(parts),
         optimize_model(),
-        text(prompt=prompt, negative_prompt=negative_prompt, bs=bs),
+        text(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
         config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
         diffusion(guidance_scale=guidance_scale),
         decode(save_root=save_root)
@@ -140,12 +177,12 @@ def SD15_t2i(pretrained_model, prompt, negative_prompt=negative_prompt, noise_sa
 
 @neko_cfg
 def SD15_t2i_lora(pretrained_model, lora_info, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4,
-                  width=512, height=512, seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/'):
+                  width=512, height=512, seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
     return dict(workflow=Actions(actions=[
         build_model(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
         load_lora(info=lora_info),
         optimize_model(),
-        text(prompt=prompt, negative_prompt=negative_prompt, bs=bs),
+        text(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
         config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
         diffusion(guidance_scale=guidance_scale),
         decode(save_root=save_root)
@@ -153,24 +190,38 @@ def SD15_t2i_lora(pretrained_model, lora_info, prompt, negative_prompt=negative_
 
 @neko_cfg
 def SDXL_t2i(pretrained_model, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=1024, height=1024,
-             seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/'):
+             seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
     return dict(workflow=Actions(actions=[
         build_model_SDXL(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
         optimize_model(),
-        text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs),
+        text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
         config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
         diffusion(guidance_scale=guidance_scale),
         decode(save_root=save_root)
     ]))
 
+@neko_cfg
+def SDXL_t2i_parts(pretrained_model, parts, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=1024, height=1024,
+                   seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
+    return dict(workflow=Actions(actions=[
+        build_model_SDXL(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
+        load_parts(parts),
+        optimize_model(),
+        text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
+        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
+        diffusion(guidance_scale=guidance_scale),
+        decode(save_root=save_root)
+    ]))
+
+
 @neko_cfg
 def SDXL_t2i_lora(pretrained_model, lora_info, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4,
-                  width=1024, height=1024, seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/'):
+                  width=1024, height=1024, seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
     return dict(workflow=Actions(actions=[
         build_model_SDXL(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
         load_lora(info=lora_info),
         optimize_model(),
-        text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs),
+        text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
         config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
         diffusion(guidance_scale=guidance_scale),
         decode(save_root=save_root)
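The new `_parts` entry points insert a `load_parts` step that loads extra model weights (saved with `denoiser.`/`TE.` key prefixes, as produced by the converter changes above) on top of the base model. A hedged sketch using the signature added in this file; the paths are illustrative:

```python
# Hedged workflow-config sketch: generate with extra model parts loaded.
from hcpdiff.easy.cfg import SD15_t2i_parts

cfg = SD15_t2i_parts(
    pretrained_model='ckpts/sd15',                        # illustrative path
    parts=['exps/exp1/ckpts/part-1000.safetensors'],      # illustrative path
    prompt='a painting of a fox in the snow',
    seed=42,
)
```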
hcpdiff/models/text_emb_ex.py CHANGED
@@ -126,6 +126,10 @@ class EmbeddingPTInterpHook(SinglePluginBlock):
         BOS = repeat(inputs_embeds[0,0,:], 'e -> r 1 e', r=self.N_repeats)
         EOS = repeat(inputs_embeds[0,-1,:], 'e -> r 1 e', r=self.N_repeats)
 
+        # make DDP happy
+        if len(self.emb_train) > 0:
+            BOS = BOS + sum(emb.mean()*0 for emb in self.emb_train if emb.requires_grad)
+
         replaced_embeds = []
         for item, rep_idxs, ids_raw in zip(inputs_embeds, rep_idxs_B, self.input_ids):
             # insert pt to embeddings
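For background, the `*0` term is the standard trick for keeping DistributedDataParallel from erroring on parameters that receive no gradient in a step: touch every trainable embedding with a zero-weighted term so it joins the autograd graph without changing the output (the same trick previously lived in `HCPTrainer.get_loss` and is removed in the next hunk). A self-contained illustration in plain PyTorch, not hcpdiff code:

```python
import torch

emb_train = [torch.randn(4, 8, requires_grad=True) for _ in range(2)]
x = torch.ones(8)

# A parameter unused in some forward pass makes DDP error out ("Expected to
# have finished reduction..."). Adding mean()*0 contributes a zero gradient
# instead of no gradient at all, so the reducer sees every parameter.
y = x + sum(e.mean() * 0 for e in emb_train if e.requires_grad)
y.sum().backward()
print([e.grad is not None for e in emb_train])  # [True, True] — all reachable
```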
hcpdiff/trainer_ac.py CHANGED
@@ -42,13 +42,6 @@ class HCPTrainer(Trainer):
     def pt_trainable(self):
         return self.cfgs.emb_pt is not None
 
-    def get_loss(self, ds_name, model_pred, inputs):
-        loss = super().get_loss(ds_name, model_pred, inputs)
-        # make DDP happy
-        if len(self.train_pts)>0:
-            loss = loss+0*sum([emb.mean() for emb in self.train_pts.values()])
-        return loss
-
     def save_model(self, from_raw=False):
         NekoSaver.save_all(
             self.model_raw,
hcpdiff/trainer_deepspeed.py ADDED
@@ -0,0 +1,47 @@
+import argparse
+import warnings
+
+import torch
+from rainbowneko.ckpt_manager import NekoPluginSaver
+from rainbowneko.train.trainer import TrainerDeepspeed
+from rainbowneko.utils import xformers_available
+
+from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
+
+class HCPTrainerDeepspeed(TrainerDeepspeed, HCPTrainer):
+    def config_model(self):
+        if self.cfgs.model.enable_xformers:
+            if xformers_available:
+                self.model_wrapper.enable_xformers()
+            else:
+                warnings.warn("xformers is not available. Make sure it is installed correctly")
+
+        if self.model_wrapper.vae is not None:
+            self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
+            self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)
+
+        if self.cfgs.model.gradient_checkpointing:
+            self.model_wrapper.enable_gradient_checkpointing()
+
+        if self.is_local_main_process:
+            for saver in self.ckpt_saver.values():
+                if isinstance(saver, NekoPluginSaver):
+                    saver.plugin_from_raw = True
+
+def hcp_train():
+    import subprocess
+    parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
+    parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/deepspeed.yaml')
+    args, train_args = parser.parse_known_args()
+
+    subprocess.run(["accelerate", "launch", '--config_file', args.launch_cfg, "-m",
+                    "hcpdiff.trainer_deepspeed"]+train_args, check=True)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='HCP Diffusion Trainer for DeepSpeed')
+    parser.add_argument("--cfg", type=str, default=None, required=True)
+    args, cfg_args = parser.parse_known_args()
+
+    parser, conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
+    trainer = HCPTrainerDeepspeed(parser, conf)
+    trainer.train()
hcpdiff/workflow/diffusion.py CHANGED
@@ -32,14 +32,15 @@ class SeedAction(BasicAction):
         self.seed = seed
         self.bs = bs
 
-    def forward(self, device, gen_step=0, **states):
+    def forward(self, device, seed=None, **states):
         bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
-        if self.seed is None:
+        seed = seed or self.seed
+        if seed is None:
             seeds = [None]*bs
-        elif isinstance(self.seed, int):
-            seeds = list(range(self.seed+gen_step*bs, self.seed+(gen_step+1)*bs))
+        elif isinstance(seed, int):
+            seeds = list(range(seed, seed+bs))
         else:
-            seeds = self.seed
+            seeds = seed
         seeds = [s or random.randint(0, 1 << 30) for s in seeds]
 
         G = prepare_seed(seeds, device=device)
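The seed handling now reads: a seed arriving through the workflow states overrides the action's own, an int fans out to `bs` consecutive seeds (no more `gen_step` offsetting), and a list is used as given. A self-contained restatement of that logic:

```python
import random

def expand_seeds(seed, bs):
    """Mirror of the 2.2.1 SeedAction logic: int -> bs consecutive seeds,
    list -> used as-is, None -> bs fresh random seeds."""
    if seed is None:
        seeds = [None] * bs
    elif isinstance(seed, int):
        seeds = list(range(seed, seed + bs))
    else:
        seeds = seed
    return [s or random.randint(0, 1 << 30) for s in seeds]

print(expand_seeds(42, 4))  # [42, 43, 44, 45]
```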
hcpdiff/workflow/text.py CHANGED
@@ -48,18 +48,9 @@ class TextEncodeAction(BasicAction):
         self.negative_prompt = negative_prompt
         self.bs = bs
 
-    def forward(self, te_hook, TE, dtype: str, device, amp=None, gen_step=None, prompt_all=None, negative_prompt_all=None, model_offload=False,
-                **states):
-        prompt_all = prompt_all or self.prompt
-        negative_prompt_all = negative_prompt_all or self.negative_prompt
-
-        if gen_step is not None:
-            idx = (gen_step*self.bs)%len(prompt_all)
-            prompt = prompt_all[idx:idx+self.bs]
-            negative_prompt = negative_prompt_all[idx:idx+self.bs]
-        else:
-            prompt = prompt_all
-            negative_prompt = negative_prompt_all
+    def forward(self, te_hook, TE, dtype: str, device, amp=None, prompt=None, negative_prompt=None, model_offload=False, **states):
+        prompt = prompt or self.prompt
+        negative_prompt = negative_prompt or self.negative_prompt
 
         if model_offload:
             to_cuda(TE)
@@ -78,19 +69,9 @@ class TextEncodeAction(BasicAction):
             'pooled_output':pooled_output}
 
 class AttnMultTextEncodeAction(TextEncodeAction):
-
-    def forward(self, te_hook, token_ex, TE, dtype: str, device, amp=None, gen_step=None, prompt_all=None, negative_prompt_all=None,
-                model_offload=False, **states):
-        prompt_all = prompt_all if prompt_all is not None else self.prompt
-        negative_prompt_all = negative_prompt_all if negative_prompt_all is not None else self.negative_prompt
-
-        if gen_step is not None:
-            idx = (gen_step*self.bs)%len(prompt_all)
-            prompt = prompt_all[idx:idx+self.bs]
-            negative_prompt = negative_prompt_all[idx:idx+self.bs]
-        else:
-            prompt = prompt_all
-            negative_prompt = negative_prompt_all
+    def forward(self, te_hook, token_ex, TE, dtype: str, device, amp=None, prompt=None, negative_prompt=None, model_offload=False, **states):
+        prompt = prompt or self.prompt
+        negative_prompt = negative_prompt or self.negative_prompt
 
         if model_offload:
             to_cuda(TE)
hcpdiff-2.1.dist-info/METADATA → hcpdiff-2.2.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hcpdiff
-Version: 2.1
+Version: 2.2.1
 Summary: A universal Diffusion toolbox
 Home-page: https://github.com/IrisRainbowNeko/HCP-Diffusion
 Author: Ziyi Dong
@@ -17,7 +17,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: rainbowneko
+Requires-Dist: rainbowneko==1.6
 Requires-Dist: diffusers
 Requires-Dist: matplotlib
 Requires-Dist: pyarrow
@@ -65,6 +65,8 @@ Compared to the original DreamArtist, it offers better stability, image quality,
 
 ## Installation
 
+Install [pytorch](https://pytorch.org/)
+
 Install via pip:
 
 ```bash
@@ -205,6 +207,18 @@ After parsing, the framework will instantiate the components accordingly. This m
 | CCIP Score | 🚧 In Development |
 | Corrupt Score | 🚧 In Development |
 
+---
+
+### ⚡️ Image Generation
+
+| Feature | Status |
+|------------------------------|------------------------------------|
+| Batch Generation | ✅ Supported |
+| Generate from Prompt Dataset | ✅ Supported |
+| Image to Image | ✅ Supported |
+| Inpaint | ✅ Supported |
+| Token Weight | ✅ Supported |
+
 </details>
 
 ---
@@ -248,9 +262,13 @@ hcp_run --cfg cfgs/workflow/text2img_cli.py \
     seed=42
 ```
 
-### Tutorials
+### 📚 Tutorials
 
-🚧 In Development
++ 🧠 [Model Training Guide](https://hcpdiff.readthedocs.io/en/latest/user_guides/train.html)
++ 🔧 [LoRA Training Tutorial](https://hcpdiff.readthedocs.io/en/latest/tutorial/lora.html)
++ 🎨 [Image Generation Guide](https://hcpdiff.readthedocs.io/en/latest/user_guides/workflow.html)
++ ⚙️ [Configuration File Explanation](https://hcpdiff.readthedocs.io/en/latest/user_guides/cfg.html)
++ 🧩 [Model Format Explanation](https://hcpdiff.readthedocs.io/en/latest/user_guides/model_format.html)
 
 ---
 
hcpdiff-2.1.dist-info/RECORD → hcpdiff-2.2.1.dist-info/RECORD RENAMED
@@ -1,27 +1,28 @@
 hcpdiff/__init__.py,sha256=dwNwrEgvG4g60fGMG6b50K3q3AWD1XCfzlIgbxkSUpE,177
 hcpdiff/train_colo.py,sha256=EsuNSzLBvGTZWU_LEk0JpP-F5eNW0lwkawIRAX38jmE,9250
-hcpdiff/train_deepspeed.py,sha256=PwyNukWi0of6TXy_VRDgBQSMLCZBhipO5g3Lq0nCYNk,2988
-hcpdiff/trainer_ac.py,sha256=6KAzo54in7ZRHud_rHjJdwRRZ4uWtc0B4SxVCxgcrmM,2990
+hcpdiff/trainer_ac.py,sha256=scH3FU0onCQtwLiy0-pcrhuowTZob3fLQqRP52iwY0c,2717
 hcpdiff/trainer_ac_single.py,sha256=0PIC5EScqcxp49EaeIWq4KS5K_09OZfKajqbFu-hUb8,1108
-hcpdiff/ckpt_manager/__init__.py,sha256=LfMwz9R4jV4xpiSFt5vhpwaF7-8UHEZ_iDoW-3QGvt0,239
+hcpdiff/trainer_deepspeed.py,sha256=7lGsiAstWuIlmhRMwWTcJCkoxzUaakVxBngKDnJdSJk,1947
+hcpdiff/ckpt_manager/__init__.py,sha256=Mn_5KOC4xbf2GcN6OXg_XdbF5wO9zWeER_1ZO_prKAI,256
 hcpdiff/ckpt_manager/ckpt.py,sha256=Pa3uXQbCi2T99mpV5fYddQ-OGHcpk8r1ll-0lmP_WXk,965
 hcpdiff/ckpt_manager/loader.py,sha256=Ch1xsZmseq4nyPhpox9-nebN-dZB4k0rqBEHos-ZLso,3245
 hcpdiff/ckpt_manager/format/__init__.py,sha256=a3cdKkOTDgdVbDQwSC4mlxOigjX2hBvRb5_X7E3TQWs,237
 hcpdiff/ckpt_manager/format/diffusers.py,sha256=T81WN95Nj1il9DfQp9iioVn0uqFEWOlmdIYs2beNOFU,3769
 hcpdiff/ckpt_manager/format/emb.py,sha256=FrqfTfJ8H7f0Zw17NTWCP2AJtpsJI5oXR5IAd4NekhU,680
-hcpdiff/ckpt_manager/format/lora_webui.py,sha256=j7SpXnSx_Ys8tnWBgojuB1HEJIm46lhCBuNNYLhaF9w,9824
+hcpdiff/ckpt_manager/format/lora_webui.py,sha256=4y_T9RdmFTxWzsXd8guNjCiukmyILa5j4MPrhVIL4Qk,10017
 hcpdiff/ckpt_manager/format/sd_single.py,sha256=LpCAL_7nAVooCHTFznVVsNMku1G3C77NBORxxr8GDtQ,2328
-hcpdiff/data/__init__.py,sha256=-z47HsEQSubc-AfriVComMACbQXlXTWAKMOPBkATHxA,258
+hcpdiff/data/__init__.py,sha256=ZFKtanOoMo3G3eKUJPhysnHXnr8BNARERkcMB6B897U,292
 hcpdiff/data/dataset.py,sha256=1k4GldW13eVyqK_9hrQniqr3_XYAapnWF7iXl_1GXGg,877
 hcpdiff/data/cache/__init__.py,sha256=ToCmokYH6DghlSwm7HJFirPRIWJ0LkgzqVOYlgoAkQw,25
 hcpdiff/data/cache/vae.py,sha256=gB89zs4CdNlvukDXhVYU9QZrY6VTFUWfzjeF2psNQ50,4070
-hcpdiff/data/handler/__init__.py,sha256=D1HyqY0qfrUHgf25itpYj57JUvgn06G6EQ9d2vRRtys,236
+hcpdiff/data/handler/__init__.py,sha256=G8ZTQF91ilkTRmUoWdmAissTSZ7fvNUpm_hBYmXKTtk,258
 hcpdiff/data/handler/controlnet.py,sha256=bRDMD9BP8-VaG5VrxzvcFKfkqeTbChNfrJSZ3vXbQgY,658
-hcpdiff/data/handler/diffusion.py,sha256=8n60UYdGNR08xw45HoI4EB5AaIui03tSGNDfjazO-5w,3516
+hcpdiff/data/handler/diffusion.py,sha256=S-_7o5Z1tm6LmRZVZs21rbJC7iUoq0tHOsSjKK6geVk,4156
 hcpdiff/data/handler/text.py,sha256=gOzqB2oEkEUbiuy0kZWduo0c-w4Buu60KI6q6Nyl3aM,4208
-hcpdiff/data/source/__init__.py,sha256=AB1VicA272KjTm-Q5L6XvDM8CLQhVPylAPuPMtpfw4g,158
+hcpdiff/data/source/__init__.py,sha256=265M8qfWNUE4SKX0pdXhLYjCnCuae5YE4bfZpO-ydXc,187
 hcpdiff/data/source/folder_class.py,sha256=bs4qPMTzwcnT6ZFlT3tpi9sclsRF9a2MBA1pQD-9EYs,961
-hcpdiff/data/source/text2img.py,sha256=MWXqAEbzmK6pkBY40t9u37ngY25mgdKQ2idwNld8-bo,1826
+hcpdiff/data/source/text.py,sha256=VgI5Ouq986Yy1jwD2fZ9iBlsRciPCeARZmOPEZIcaQY,1468
+hcpdiff/data/source/text2img.py,sha256=acYdolQhZUEpkd7tUAdNkCTVnPc1SMJOVTmGqFt9ZpE,1813
 hcpdiff/data/source/text2img_cond.py,sha256=yj1KpARA2rkjENutnnzC4uDkcU2Rye21FL2VdC25Hac,585
 hcpdiff/diffusion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hcpdiff/diffusion/noise/__init__.py,sha256=seBpOtd0YsU53PqMn7Nyl_RtwoC-ONEIOX7v2XLGpZQ,93
@@ -38,10 +39,10 @@ hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py,sha256=2PMIpg2K6CVoxew1y1pIqvC
 hcpdiff/diffusion/sampler/sigma_scheduler/edm.py,sha256=fOPB3lgnS9uVo4oW26Fur_nc8X_wQ6mmUcbkKhnoQjs,1900
 hcpdiff/easy/__init__.py,sha256=-emoyCOZlLCu3KNMI8L4qapUEtEYFSoiGU6-rKv1at4,149
 hcpdiff/easy/sampler.py,sha256=dQSBkeGh71O0DAmZLhTHTbk1bY7XzyUCeW1oJO14A4I,1250
-hcpdiff/easy/cfg/__init__.py,sha256=aVDEDPxHdX5n-aFkP_4ic8ZhQfSeKu8lZOkgW_4m398,221
-hcpdiff/easy/cfg/sd15_train.py,sha256=LRCJLHNU0JEd1m3MC_NFWUCw5LmwztiLiJlV7u_DeKM,6493
-hcpdiff/easy/cfg/sdxl_train.py,sha256=R0wolSVOrRlI9A-vAfz592SzSnwuDd4ku1oc5yRKrfU,4038
-hcpdiff/easy/cfg/t2i.py,sha256=6Pyy4werXNalwoBBHVMBLBg67kMS85Heb7R3t26GJqQ,6871
+hcpdiff/easy/cfg/__init__.py,sha256=SxHMWG6T2CXhX3dP0xizSMd9vFWPaZQDc4Gj4CF__yQ,253
+hcpdiff/easy/cfg/sd15_train.py,sha256=kKdESVqAxNlBhhz12PvwrpHJBea80OUFzDDMHwiulVs,6710
+hcpdiff/easy/cfg/sdxl_train.py,sha256=FUWE_hRJdQc9Qd9J6730jAyK0H4EIKS7-3BSufCItXU,4275
+hcpdiff/easy/cfg/t2i.py,sha256=SnjFjZAKd9orjJr3RW5_N2_EIlW2Ree7JMvdNUAR9gc,9507
 hcpdiff/easy/model/__init__.py,sha256=CA-7r3R2Jgweekk1XNByFYttLolbWyUV2bCnXygcD8w,133
 hcpdiff/easy/model/cnet.py,sha256=m0NTH9V1kLzb5GybwBrSNT0KvTcRpPfGkzUeMz9jZZQ,1084
 hcpdiff/easy/model/loader.py,sha256=Tdx-lhQEYf2NYjVM1A5B8x6ZZpJKcXUkFIPIbr7h7XM,3456
@@ -61,7 +62,7 @@ hcpdiff/models/lora_base.py,sha256=LGwBD9KP6qf4pgTx24i5-JLo4rDBQ6jFfterQKBjTbE,6
 hcpdiff/models/lora_base_patch.py,sha256=WW3CULnROTxKXyynJiqirhHYCKN5JtxLhVpT5b7AUQg,6532
 hcpdiff/models/lora_layers.py,sha256=O9W_Ue71lHj7Y_GbpioF4Hc3h2-z_zOqck93VYUra6s,7777
 hcpdiff/models/lora_layers_patch.py,sha256=GYFYsJD2VSLZfdnLma9CmQEHz09HROFJcc4wc_gs9f0,8198
-hcpdiff/models/text_emb_ex.py,sha256=a5QImxzvj0zWR12qXOPP9kmpESl8J9VLabA0W9D_i_c,7867
+hcpdiff/models/text_emb_ex.py,sha256=O0XZqid01OrB0dHY7hCiBvdU2026SvZ38yfQaF2TWrs,8018
 hcpdiff/models/textencoder_ex.py,sha256=JrTQ30Avx8tPbdr-Q6K5BvEWCEdsu8Z7eSOzMqpUuzg,8270
 hcpdiff/models/tokenizer_ex.py,sha256=zKUn4BY7b3yXwK9PWkZtQKJPyKYwUc07E-hwB9NQybs,2446
 hcpdiff/models/compose/__init__.py,sha256=lTNFTGg5csqvUuys22RqgjmWlk_7Okw6ZTsnTi1pqCg,217
@@ -95,20 +96,20 @@ hcpdiff/utils/net_utils.py,sha256=gdwLYDNKV2t3SP0jBIO3d0HtY6E7jRaf_rmPT8gKZZE,97
 hcpdiff/utils/pipe_hook.py,sha256=-UDX3FtZGl-bxSk13gdbPXc1OvtbCcpk_fvKxLQo3Ag,31987
 hcpdiff/utils/utils.py,sha256=hZnZP1IETgVpScxES0yIuRfc34TnzvAqmgOTK_56ssw,4976
 hcpdiff/workflow/__init__.py,sha256=t7Zyc0XFORdNvcwHp9AsCtEkhJ3l7Hm41ugngIL0Sag,867
-hcpdiff/workflow/diffusion.py,sha256=yrl2cXE2d2FNeVzYZDRQNLjy5-QnVgOWioIHSmszk2Y,8662
+hcpdiff/workflow/diffusion.py,sha256=yzhqKA3019OPu1RKggrLoytMgm919qf6j9S85PYOwjQ,8644
 hcpdiff/workflow/fast.py,sha256=kZt7bKrvpFInSn7GzbkTkpoCSM0Z6IbDjgaDvcbFYf8,1024
 hcpdiff/workflow/flow.py,sha256=FFbFFOAXT4c31L5bHBEB_qeVGuBQDLYhq8kTD1chGNo,2548
 hcpdiff/workflow/io.py,sha256=aTrMR3s44apVJpnSyvZIabW2Op0tslk_Z9JFJl5svm0,2635
 hcpdiff/workflow/model.py,sha256=1gj5yOTefYTnGXVR6JPAfxIwuB69YwN6E-BontRcuyQ,2913
-hcpdiff/workflow/text.py,sha256=FSFUm_zEeZjMeg0qRXZAPplnJkg2pR_2FA3XljpoN2w,5110
+hcpdiff/workflow/text.py,sha256=Z__SJHZyuaKyzkYJ6rbiAzOGRiYcCjwCGeqfpP1Jo7o,4336
 hcpdiff/workflow/utils.py,sha256=xojaMG4lHsymslc8df5uiVXmmBVWpn_Phqka8qzJEWw,2226
 hcpdiff/workflow/vae.py,sha256=cingDPkIOc4qGpOwwhXJK4EQbGoIxO583pm6gGov5t8,3118
 hcpdiff/workflow/daam/__init__.py,sha256=ySIDaxloN-D3qM7OuVaG1BR3D-CibDoXYpoTgw0zUhU,59
 hcpdiff/workflow/daam/act.py,sha256=tHbsFWTYYU4bvcZOo1Bpi_z6ofpJatRYccl4vvf8wIA,2756
 hcpdiff/workflow/daam/hook.py,sha256=z9f9mBjKW21xuUZ-iQxQ0HbWOBXtZrisFB0VNMq6d0U,4383
-hcpdiff-2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-hcpdiff-2.1.dist-info/METADATA,sha256=NpBZuj23d1gTKPQhJ0TBRV8QsfICa4LCGSk6PJNniSw,9248
-hcpdiff-2.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-hcpdiff-2.1.dist-info/entry_points.txt,sha256=86wPOMzsfWWflTJ-sQPLc7WG5Vtu0kGYBH9C_vR3ur8,207
-hcpdiff-2.1.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
-hcpdiff-2.1.dist-info/RECORD,,
+hcpdiff-2.2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+hcpdiff-2.2.1.dist-info/METADATA,sha256=f96Tc90K5WTBbJ35wWJw60G2JR46eGpUvQSaPIysVDg,10323
+hcpdiff-2.2.1.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
+hcpdiff-2.2.1.dist-info/entry_points.txt,sha256=_4VRsEsEWOhHfzBDu9bx8Wh_S8Wi4ZTHpI0n6rU0J-I,258
+hcpdiff-2.2.1.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
+hcpdiff-2.2.1.dist-info/RECORD,,
hcpdiff-2.1.dist-info/WHEEL → hcpdiff-2.2.1.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (78.1.0)
+Generator: setuptools (78.1.1)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
hcpdiff-2.1.dist-info/entry_points.txt → hcpdiff-2.2.1.dist-info/entry_points.txt RENAMED
@@ -2,4 +2,5 @@
 hcp_run = rainbowneko.infer.infer_workflow:run_workflow
 hcp_train = hcpdiff.trainer_ac:hcp_train
 hcp_train_1gpu = hcpdiff.trainer_ac_single:hcp_train
+hcp_train_ds = hcpdiff.trainer_deepspeed:hcp_train
 hcpinit = hcpdiff.tools.init_proj:main
hcpdiff/train_deepspeed.py REMOVED
@@ -1,69 +0,0 @@
-import argparse
-import os
-import sys
-import warnings
-from functools import partial
-
-import torch
-
-from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
-from hcpdiff.train_ac_old import Trainer, load_config_with_cli
-from hcpdiff.utils.net_utils import get_scheduler
-
-class TrainerDeepSpeed(Trainer):
-
-    def build_ckpt_manager(self):
-        self.ckpt_manager = self.ckpt_manager_map[self.cfgs.ckpt_type](plugin_from_raw=True)
-        if self.is_local_main_process:
-            self.ckpt_manager.set_save_dir(os.path.join(self.exp_dir, 'ckpts'), emb_dir=self.cfgs.tokenizer_pt.emb_dir)
-
-    @property
-    def unet_raw(self):
-        return self.accelerator.unwrap_model(self.TE_unet).unet if self.train_TE else self.accelerator.unwrap_model(self.TE_unet.unet)
-
-    @property
-    def TE_raw(self):
-        return self.accelerator.unwrap_model(self.TE_unet).TE if self.train_TE else self.TE_unet.TE
-
-    def get_loss(self, model_pred, target, timesteps, att_mask):
-        if att_mask is None:
-            att_mask = 1.0
-        if getattr(self.criterion, 'need_timesteps', False):
-            loss = (self.criterion(model_pred.float(), target.float(), timesteps)*att_mask).mean()
-        else:
-            loss = (self.criterion(model_pred.float(), target.float())*att_mask).mean()
-        return loss
-
-    def build_optimizer_scheduler(self):
-        # set optimizer
-        parameters, parameters_pt = self.get_param_group_train()
-
-        if len(parameters_pt)>0:  # do prompt-tuning
-            cfg_opt_pt = self.cfgs.train.optimizer_pt
-            # if self.cfgs.train.scale_lr_pt:
-            #     self.scale_lr(parameters_pt)
-            assert isinstance(cfg_opt_pt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
-            weight_decay = cfg_opt_pt.keywords.get('weight_decay', None)
-            if weight_decay is not None:
-                for param in parameters_pt:
-                    param['weight_decay'] = weight_decay
-
-            parameters += parameters_pt
-            warnings.warn('deepspeed dose not support multi optimizer and lr_scheduler. optimizer_pt and scheduler_pt will not work.')
-
-        if len(parameters)>0:
-            cfg_opt = self.cfgs.train.optimizer
-            if self.cfgs.train.scale_lr:
-                self.scale_lr(parameters)
-            assert isinstance(cfg_opt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
-            self.optimizer = cfg_opt(params=parameters)
-            self.lr_scheduler = get_scheduler(self.cfgs.train.scheduler, self.optimizer)
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Stable Diffusion Training')
-    parser.add_argument('--cfg', type=str, default='cfg/train/demo.yaml')
-    args, cfg_args = parser.parse_known_args()
-
-    conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
-    trainer = TrainerDeepSpeed(conf)
-    trainer.train()