hcpdiff 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. hcpdiff/data/bucket.py +1 -1
  2. hcpdiff/data/utils.py +0 -1
  3. hcpdiff/tools/convert_caption_txt2json.py +9 -5
  4. hcpdiff/train_ac.py +26 -9
  5. hcpdiff/utils/net_utils.py +19 -17
  6. hcpdiff/visualizer.py +6 -5
  7. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +28 -13
  8. hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +15 -0
  9. hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +7 -0
  10. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/train_base.yaml +6 -2
  11. {hcpdiff-0.3.5.dist-info → hcpdiff-0.3.7.dist-info}/METADATA +3 -1
  12. {hcpdiff-0.3.5.dist-info → hcpdiff-0.3.7.dist-info}/RECORD +43 -42
  13. hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -58
  14. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -0
  15. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -0
  16. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/infer/img2img.yaml +0 -0
  17. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -0
  18. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/infer/text2img.yaml +0 -0
  19. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -0
  20. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/infer/webui_model_infer.yaml +0 -0
  21. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -0
  22. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/te_struct.txt +0 -0
  23. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -0
  24. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -0
  25. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -0
  26. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -0
  27. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -0
  28. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -0
  29. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -0
  30. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -0
  31. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -0
  32. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/unet_struct.txt +0 -0
  33. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/prompt_tuning_template/caption.txt +0 -0
  34. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/prompt_tuning_template/name.txt +0 -0
  35. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -0
  36. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -0
  37. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/prompt_tuning_template/object.txt +0 -0
  38. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -0
  39. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/prompt_tuning_template/style.txt +0 -0
  40. {hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -0
  41. {hcpdiff-0.3.5.dist-info → hcpdiff-0.3.7.dist-info}/LICENSE +0 -0
  42. {hcpdiff-0.3.5.dist-info → hcpdiff-0.3.7.dist-info}/WHEEL +0 -0
  43. {hcpdiff-0.3.5.dist-info → hcpdiff-0.3.7.dist-info}/entry_points.txt +0 -0
  44. {hcpdiff-0.3.5.dist-info → hcpdiff-0.3.7.dist-info}/top_level.txt +0 -0
hcpdiff/data/bucket.py CHANGED
@@ -194,7 +194,7 @@ class RatioBucket(BaseBucket):
         rs.shuffle(x)
 
     # shuffle of batches
-    bucket_list = np.hstack(bucket_list).reshape(-1, self.bs)
+    bucket_list = np.hstack(bucket_list).reshape(-1, self.bs).astype(int)
     rs.shuffle(bucket_list)
 
     self.idx_arb = bucket_list.reshape(-1)
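The only change here is the `.astype(int)` cast, most likely guarding against a NumPy promotion quirk: if any bucket ends up empty, `np.hstack` promotes the concatenated index array to float64, and float arrays cannot be used for indexing. A minimal sketch of the failure mode (illustrative values, not package code):

    import numpy as np

    data = np.arange(10)
    buckets = [np.array([0, 1, 2]), np.array([])]  # one bucket ended up empty

    idx = np.hstack(buckets).reshape(-1, 1)
    print(idx.dtype)        # float64 -- the empty array promotes the result
    # data[idx]             # IndexError: arrays used as indices must be integer

    idx = idx.astype(int)   # the 0.3.7 fix: force an integer index array
    print(data[idx.reshape(-1)])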
hcpdiff/data/utils.py CHANGED
@@ -72,7 +72,6 @@ def collate_fn_ft(batch):
 
 class CycleData():
     def __init__(self, data_loader):
-        print(data_loader)
         self.data_loader = data_loader
 
     def __iter__(self):
hcpdiff/tools/convert_caption_txt2json.py CHANGED
@@ -6,6 +6,7 @@ from hcpdiff.utils.img_size_tool import types_support
 
 parser = argparse.ArgumentParser(description='Stable Diffusion Training')
 parser.add_argument('--data_root', type=str, default='')
+parser.add_argument('--with_imgs', action="store_true")
 args = parser.parse_args()
 
 
@@ -16,10 +17,13 @@ def get_txt_caption(path):
 
 captions = {}
 for file in os.listdir(args.data_root):
-    ext_idx = file.rfind('.')
-    file_name = file[:ext_idx]
-    if file[ext_idx + 1:] in types_support:
-        captions[file] = get_txt_caption(os.path.join(args.data_root, f'{file_name}.txt'))
+    file_name, file_ext = file.rsplit('.', 1)
+    if args.with_imgs:
+        if file_ext in types_support:
+            captions[file] = get_txt_caption(os.path.join(args.data_root, f'{file_name}.txt'))
+    else:
+        if file_ext == 'txt':
+            captions[file] = get_txt_caption(os.path.join(args.data_root, f'{file_name}.txt'))
 
 with open(os.path.join(args.data_root, f'image_captions.json'), "w", encoding='utf8') as f:
-    json.dump(captions, f)
+    json.dump(captions, f, indent=2, ensure_ascii=False)
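Beyond the new `--with_imgs` mode (key captions by image files rather than by the `.txt` files themselves), the dump now pretty-prints and keeps non-ASCII captions readable. A quick illustration of the `indent`/`ensure_ascii=False` effect, with a hypothetical caption:

    import json

    captions = {'dog.png': '柴犬の写真'}  # hypothetical non-ASCII caption
    print(json.dumps(captions, indent=2, ensure_ascii=False))
    # {
    #   "dog.png": "柴犬の写真"
    # }
    print(json.dumps(captions))  # old behavior: {"dog.png": "\u67f4\u72ac\u306e\u5199\u771f"}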
hcpdiff/train_ac.py CHANGED
@@ -28,6 +28,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel
 from diffusers.utils.import_utils import is_xformers_available
 from omegaconf import OmegaConf
 from transformers import AutoTokenizer
+from functools import partial
 
 from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
 from hcpdiff.data import RatioBucket, DataGroup, collate_fn_ft
@@ -329,12 +330,16 @@ class Trainer:
         # set optimizer
         parameters, parameters_pt = self.get_param_group_train()
 
-        cfg_opt = self.cfgs.train.optimizer
         if len(parameters)>0: # do fine-tuning
+            cfg_opt = self.cfgs.train.optimizer
             if self.cfgs.train.scale_lr:
                 self.scale_lr(parameters)
 
-            if cfg_opt.type == 'adamw_8bit':
+            if isinstance(cfg_opt, partial):
+                if 'type' in cfg_opt.keywords:
+                    del cfg_opt.keywords['type']
+                self.optimizer = cfg_opt(params=parameters, lr=self.lr)
+            elif cfg_opt.type == 'adamw_8bit':
                 import bitsandbytes as bnb
                 self.optimizer = bnb.optim.AdamW8bit(params=parameters, lr=self.lr, weight_decay=cfg_opt.weight_decay)
             elif cfg_opt.type == 'deepspeed' and self.accelerator.state.deepspeed_plugin is not None:
@@ -343,23 +348,35 @@ class Trainer:
             elif cfg_opt.type == 'adamw':
                 self.optimizer = torch.optim.AdamW(params=parameters, lr=self.lr, weight_decay=cfg_opt.weight_decay)
             else:
-                self.optimizer = cfg_opt.optimizer.opt(parameters, lr=self.lr)
+                raise NotImplementedError(f'Unknown optimizer {cfg_opt.type}')
 
-            self.lr_scheduler = get_scheduler(optimizer=self.optimizer, **self.cfgs.train.scheduler)
+            if isinstance(self.cfgs.train.scheduler, partial):
+                self.lr_scheduler = self.cfgs.train.scheduler(optimizer=self.optimizer)
+            else:
+                self.lr_scheduler = get_scheduler(optimizer=self.optimizer, **self.cfgs.train.scheduler)
 
         if len(parameters_pt)>0: # do prompt-tuning
+            cfg_opt_pt = self.cfgs.train.optimizer_pt
             if self.cfgs.train.scale_lr_pt:
                 self.scale_lr(parameters_pt)
+            if isinstance(cfg_opt_pt, partial):
+                if 'type' in cfg_opt_pt.keywords:
+                    del cfg_opt_pt.keywords['type']
+                self.optimizer_pt = cfg_opt_pt(params=parameters_pt, lr=self.lr)
+            else:
+                self.optimizer_pt = torch.optim.AdamW(params=parameters_pt, lr=self.lr, weight_decay=cfg_opt_pt.weight_decay)
 
-            self.optimizer_pt = torch.optim.AdamW(params=parameters_pt, lr=self.lr, weight_decay=cfg_opt.weight_decay_pt)
-            self.lr_scheduler_pt = get_scheduler(optimizer=self.optimizer_pt, **self.cfgs.train.scheduler_pt)
+            if isinstance(self.cfgs.train.scheduler_pt, partial):
+                self.lr_scheduler_pt = self.cfgs.train.scheduler_pt(optimizer=self.optimizer_pt)
+            else:
+                self.lr_scheduler_pt = get_scheduler(optimizer=self.optimizer_pt, **self.cfgs.train.scheduler_pt)
 
     def train(self):
         total_batch_size = sum(self.batch_size_list)*self.world_size*self.cfgs.train.gradient_accumulation_steps
 
         self.loggers.info("***** Running training *****")
         self.loggers.info(f" Num batches each epoch = {len(self.train_loader_group.loader_list[0])}")
-        self.loggers.info(f" Num Steps = {self.cfgs.train.scheduler.num_training_steps}")
+        self.loggers.info(f" Num Steps = {self.cfgs.train.train_steps}")
         self.loggers.info(f" Instantaneous batch size per device = {sum(self.batch_size_list)}")
         self.loggers.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
         self.loggers.info(f" Gradient Accumulation steps = {self.cfgs.train.gradient_accumulation_steps}")
@@ -380,13 +397,13 @@ class Trainer:
                 lr_model = self.lr_scheduler.get_last_lr()[0] if hasattr(self, 'lr_scheduler') else 0.
                 lr_word = self.lr_scheduler_pt.get_last_lr()[0] if hasattr(self, 'lr_scheduler_pt') else 0.
                 self.loggers.log(datas={
-                    'Step':{'format':'[{}/{}]', 'data':[self.global_step, self.cfgs.train.scheduler.num_training_steps]},
+                    'Step':{'format':'[{}/{}]', 'data':[self.global_step, self.cfgs.train.train_steps]},
                     'LR_model':{'format':'{:.2e}', 'data':[lr_model]},
                     'LR_word':{'format':'{:.2e}', 'data':[lr_word]},
                     'Loss':{'format':'{:.5f}', 'data':[loss_sum.mean()]},
                 }, step=self.global_step)
 
-            if self.global_step>=self.cfgs.train.scheduler.num_training_steps:
+            if self.global_step>=self.cfgs.train.train_steps:
                 break
 
         self.wait_for_everyone()
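The optimizer/scheduler rework above is the core change of this release: `train.optimizer`, `train.optimizer_pt`, `train.scheduler`, and `train.scheduler_pt` may now arrive as partials (hydra-style `_partial_: True` entries), which the trainer completes with the runtime arguments. A standalone sketch of that calling convention, with torch classes standing in for whatever the config names:

    from functools import partial

    import torch
    from torch.optim.lr_scheduler import OneCycleLR

    # what an `_partial_: True` config entry resolves to (illustrative classes)
    cfg_opt = partial(torch.optim.AdamW, weight_decay=1e-2)
    cfg_scheduler = partial(OneCycleLR, max_lr=1e-4, total_steps=1000)

    parameters = [torch.nn.Parameter(torch.zeros(4))]
    if isinstance(cfg_opt, partial):
        optimizer = cfg_opt(params=parameters, lr=1e-4)    # trainer supplies params and lr
    if isinstance(cfg_scheduler, partial):
        lr_scheduler = cfg_scheduler(optimizer=optimizer)  # trainer supplies the optimizer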
hcpdiff/utils/net_utils.py CHANGED
@@ -1,12 +1,11 @@
 import os
-from typing import Optional, Union, Tuple, Dict, Callable
+from typing import Optional, Union
 
 import torch
+from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, Optimizer
 from torch import nn
 from torch.optim import lr_scheduler
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, Optimizer
 from transformers import PretrainedConfig
-from collections import OrderedDict
 
 class TEUnetWrapper(nn.Module):
     def __init__(self, unet, TE):
@@ -33,7 +32,7 @@ def get_scheduler(
     optimizer: Optimizer,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
-    scheduler_kwargs = {},
+    scheduler_kwargs={},
 ):
     """
     Unified API to get any scheduler from its name.
@@ -64,11 +63,11 @@ def get_scheduler(
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
 
-    #One Cycle for super convergence
-    if name=='one_cycle':
+    # One Cycle for super convergence
+    if name == 'one_cycle':
         scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=[x['lr'] for x in optimizer.state_dict()['param_groups']],
-                                            steps_per_epoch=num_training_steps, epochs=1,
-                                            pct_start=num_warmup_steps/num_training_steps, **scheduler_kwargs)
+                                            steps_per_epoch=num_training_steps, epochs=1,
+                                            pct_start=num_warmup_steps/num_training_steps, **scheduler_kwargs)
         return scheduler
 
     name = SchedulerType(name)
@@ -117,44 +116,46 @@ def remove_all_hooks(model: nn.Module) -> None:
         child._backward_hooks.clear()
 
 def remove_layers(model: nn.Module, layer_class):
-    named_modules = {k: v for k, v in model.named_modules()}
-    for k,v in named_modules.items():
+    named_modules = {k:v for k, v in model.named_modules()}
+    for k, v in named_modules.items():
         if isinstance(v, layer_class):
             parent, name = named_modules[k.rsplit('.', 1)]
             delattr(parent, name)
             del v
 
 def load_emb(path):
-    emb=torch.load(path, map_location='cpu')['string_to_param']['*']
+    emb = torch.load(path, map_location='cpu')['string_to_param']['*']
     emb.requires_grad_(False)
     return emb
 
-def save_emb(path, emb:torch.Tensor, replace=False):
+def save_emb(path, emb: torch.Tensor, replace=False):
     name = os.path.basename(path)
     if os.path.exists(path) and not replace:
         raise FileExistsError(f'embedding "{name}" already exist.')
-    name=name[:name.rfind('.')]
+    name = name[:name.rfind('.')]
     torch.save({'string_to_param':{'*':emb}, 'name':name}, path)
 
-
 def hook_compile(model):
     named_modules = {k:v for k, v in model.named_modules()}
 
     for name, block in named_modules.items():
         if len(block._forward_hooks)>0:
-            for hook in block._forward_hooks.values(): # executed front to back
+            for hook in block._forward_hooks.values():  # executed front to back
                 old_forward = block.forward
+
                 def new_forward(*args, **kwargs):
                     result = old_forward(*args, **kwargs)
                     hook_result = hook(block, args, result)
                     if hook_result is not None:
                         result = hook_result
                     return result
+
                 block.forward = new_forward
 
         if len(block._forward_pre_hooks)>0:
-            for hook in list(block._forward_pre_hooks.values())[::-1]: # executed front to back
+            for hook in list(block._forward_pre_hooks.values())[::-1]:  # executed front to back
                 old_forward = block.forward
+
                 def new_forward(*args, **kwargs):
                     result = hook(block, args)
                     if result is not None:
@@ -163,5 +164,6 @@ def hook_compile(model):
                     else:
                         result = args
                     return old_forward(*result, **kwargs)
+
                 block.forward = new_forward
-    remove_all_hooks(model)
+    remove_all_hooks(model)
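For context on `load_emb`/`save_emb` (only reformatted in this release): they read and write the `{'string_to_param': {'*': tensor}}` layout used by WebUI-style textual-inversion embeddings, as the hunk shows. A roundtrip sketch with illustrative names and sizes:

    import torch

    emb = torch.randn(2, 768)  # e.g. 2 tokens x 768-dim text-encoder width
    torch.save({'string_to_param': {'*': emb}, 'name': 'pt-example'}, 'pt-example.pt')

    loaded = torch.load('pt-example.pt', map_location='cpu')['string_to_param']['*']
    loaded.requires_grad_(False)  # as load_emb does
    assert torch.equal(emb, loaded)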
hcpdiff/visualizer.py CHANGED
@@ -124,12 +124,9 @@ class Visualizer:
         images = self.pipe(prompt_embeds=emb_p, negative_prompt_embeds=emb_n, **kwargs).images
         return images
 
-    @torch.no_grad()
-    def vis_to_dir(self, root, prompt, negative_prompt='', save_cfg=True, **kwargs):
+    def save_images(self, images, root, prompt, negative_prompt='', save_cfg=True):
         os.makedirs(root, exist_ok=True)
-        num_img_exist = max([int(x.split('-',1)[0]) for x in os.listdir(root) if x.rsplit('.', 1)[-1] in types_support])+1
-
-        images = self.vis_images(prompt, negative_prompt, **kwargs)
+        num_img_exist = max([int(x.split('-', 1)[0]) for x in os.listdir(root) if x.rsplit('.', 1)[-1] in types_support]) + 1
 
         for p, pn, img in zip(prompt, negative_prompt, images):
             img.save(os.path.join(root, f"{num_img_exist}-{to_validate_file(prompt[0])}.{self.cfgs.save.image_type}"), quality=self.cfgs.save.quality)
@@ -139,6 +136,10 @@ class Visualizer:
                 f.write(OmegaConf.to_yaml(self.cfgs_raw))
             num_img_exist += 1
 
+    def vis_to_dir(self, root, prompt, negative_prompt='', save_cfg=True, **kwargs):
+        images = self.vis_images(prompt, negative_prompt, **kwargs)
+        self.save_images(images, root, prompt, negative_prompt, save_cfg=save_cfg)
+
     def show_latent(self, prompt, negative_prompt='', **kwargs):
         emb_n, emb_p = self.te_hook.encode_prompt_to_emb(negative_prompt + prompt).chunk(2)
         emb_p = self.te_hook.mult_attn(emb_p, self.token_ex.parse_attn_mult(prompt))
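The refactor separates generation from disk I/O: `vis_images` runs the pipeline, the new `save_images` numbers and writes the files (plus the config YAML when `save_cfg=True`), and `vis_to_dir` is now just their composition. That lets callers render once and save to several locations. A hypothetical helper built on the new split (paths and prompts illustrative):

    from hcpdiff.visualizer import Visualizer

    def render_and_save(vis: Visualizer, prompt, negative_prompt=''):
        images = vis.vis_images(prompt, negative_prompt)  # sample once
        vis.save_images(images, 'exps/run_a', prompt, negative_prompt)
        vis.save_images(images, 'exps/run_b', prompt, negative_prompt, save_cfg=False)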
{hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml CHANGED
@@ -6,15 +6,21 @@ lora_unet:
   - lr: 1e-4
     rank: 0.01875
     branch: p
+    dropout: 0.1
     layers:
-      - 're:.*\.attn.?$'
-      #- 're:.*\.ff\.net\.0$' # Increases fitness, but potentially reduces controllability
+      - 're:.*\.to_k$'
+      - 're:.*\.to_v$'
+      - 're:.*\.ff$'
+      #- 're:.*\.attn.?$' # Increases fitness, but potentially reduces controllability
   - lr: 4e-5 # Low negative unet lr prevents image collapse
     rank: 0.01875
     branch: n
+    dropout: 0.1
     layers:
-      - 're:.*\.attn.?$'
-      #- 're:.*\.ff\.net\.0$' # Increases fitness, but potentially reduces controllability
+      - 're:.*\.to_k$'
+      - 're:.*\.to_v$'
+      - 're:.*\.ff$'
+      #- 're:.*\.attn.?$' # Increases fitness, but potentially reduces controllability
   # - lr: 1e-4
   #   rank: 0.01875
   #   type: p
@@ -27,23 +33,25 @@ lora_unet:
   #     - 're:.*\.resnets$' # Increases fitness, but potentially reduces controllability and change style
 
 lora_text_encoder:
-  - lr: 1e-5
-    rank: 0.01
+  - lr: 2e-5
+    rank: 2
     branch: p
+    dropout: 0.1
     layers:
       - 're:.*self_attn$'
       - 're:.*mlp$'
-  - lr: 1e-5
-    rank: 0.01
+  - lr: 2e-5
+    rank: 2
     branch: n
+    dropout: 0.1
     layers:
      - 're:.*self_attn$'
      - 're:.*mlp$'
 
 tokenizer_pt:
   train: # prompt tuning embeddings
-    - { name: 'pt-botdog1', lr: 0.003 }
-    - { name: 'pt-botdog1-neg', lr: 0.003 }
+    - { name: 'pt-botdog1', lr: 0.0025 }
+    - { name: 'pt-botdog1-neg', lr: 0.0025 }
 
 train:
   gradient_accumulation_steps: 1
@@ -52,13 +60,19 @@ train:
   #cfg_scale: '1.0-3.0:cos' # dynamic CFG with timestamp
   cfg_scale: '3.0'
 
+  loss:
+    criterion: # min SNR loss
+      _target_: hcpdiff.loss.MinSNRLoss
+      gamma: 2.0
+
   scheduler:
-    name: 'constant_with_warmup'
-    num_warmup_steps: 50
+    name: one_cycle
+    num_warmup_steps: 200
     num_training_steps: 1000
+    scheduler_kwargs: { }
 
   scheduler_pt:
-    name: 'one_cycle'
+    name: one_cycle
     num_warmup_steps: 200
     num_training_steps: 1000
     scheduler_kwargs: {}
@@ -68,6 +82,7 @@ model:
   tokenizer_repeats: 1
   ema_unet: 0
   ema_text_encoder: 0
+  clip_skip: 0
 
 data:
   dataset1:
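The retargeted `layers` lists swap whole attention modules for the `to_k`/`to_v` projections and feed-forward blocks. Assuming, as the syntax suggests, that each `'re:'` entry is a regular expression matched against `named_modules()` names, selection works like this sketch (module names abbreviated from a diffusers UNet):

    import re

    names = [  # abbreviated from unet.named_modules()
        'down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k',
        'down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q',
        'down_blocks.0.attentions.0.transformer_blocks.0.ff',
    ]
    patterns = [r'.*\.to_k$', r'.*\.to_v$', r'.*\.ff$']  # p-branch layers from the hunk above
    print([n for n in names if any(re.fullmatch(p, n) for p in patterns)])
    # to_k and ff match; to_q does not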
hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml ADDED
@@ -0,0 +1,15 @@
+_base_: [cfgs/train/examples/fine-tuning.yaml]
+
+# Install: pip install lion-pytorch
+
+train:
+  optimizer:
+    _target_: lion_pytorch.Lion
+    _partial_: True
+    weight_decay: 1e-2
+    #use_triton: True # set this to True to use cuda kernel w/ Triton lang (Tillet et al)
+
+  optimizer_pt:
+    _target_: lion_pytorch.Lion
+    _partial_: True
+    weight_decay: 1e-3
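This new example is the first consumer of the partial-optimizer path in train_ac.py. Under hydra-style instantiation (which the `_target_`/`_partial_` keys imply), the entry resolves to a `functools.partial` that the trainer later completes. An illustrative resolution, assuming hydra-core and lion-pytorch are installed:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        '_target_': 'lion_pytorch.Lion',  # pip install lion-pytorch
        '_partial_': True,
        'weight_decay': 1e-2,
    })
    cfg_opt = instantiate(cfg)  # -> functools.partial(Lion, weight_decay=0.01)
    # the trainer then calls: cfg_opt(params=parameters, lr=self.lr)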
hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml ADDED
@@ -0,0 +1,7 @@
+_base_: [cfgs/train/examples/fine-tuning.yaml]
+
+train:
+  loss:
+    criterion: # min SNR loss
+      _target_: hcpdiff.loss.MinSNRLoss
+      gamma: 2.0
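`MinSNRLoss` refers to the Min-SNR weighting strategy (Hang et al., 2023): the per-timestep MSE is scaled by min(SNR(t), gamma)/SNR(t), damping the easy low-noise timesteps. A minimal sketch of the weighting for epsilon-prediction (not the package's exact class):

    import torch

    def min_snr_weight(alphas_cumprod: torch.Tensor, t: torch.Tensor, gamma: float = 2.0):
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) under the DDPM forward process
        snr = alphas_cumprod[t] / (1 - alphas_cumprod[t])
        return snr.clamp(max=gamma) / snr

    # usage: scale each sample's noise-prediction MSE before averaging
    # loss = (min_snr_weight(alphas_cumprod, t, gamma=2.0) * mse_per_sample).mean()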
{hcpdiff-0.3.5.data → hcpdiff-0.3.7.data}/data/hcpdiff/cfgs/train/train_base.yaml CHANGED
@@ -10,6 +10,7 @@ vis_info:
   negative_prompt: ''
 
 train:
+  train_steps: 1000
   gradient_accumulation_steps: 1
   workers: 4
   max_grad_norm: 1.0
@@ -35,7 +36,10 @@ train:
   optimizer:
     type: adamw
     weight_decay: 1e-3
-    weight_decay_pt: 5e-4
+
+  optimizer_pt:
+    type: adamw
+    weight_decay: 5e-4
 
   scale_lr: True # auto scale lr with total batch size
   scheduler:
@@ -58,7 +62,7 @@ model:
   revision: null
   pretrained_model_name_or_path: null
   tokenizer_name: null
-  tokenizer_repeats: 3
+  tokenizer_repeats: 2
   enable_xformers: True
   gradient_checkpointing: True
   ema_unet: 0 # 0 to disable
{hcpdiff-0.3.5.dist-info → hcpdiff-0.3.7.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: hcpdiff
-Version: 0.3.5
+Version: 0.3.7
 Summary: A universal Stable-Diffusion toolbox
 Home-page: https://github.com/7eu7d7/HCP-Diffusion
 Author: Ziyi Dong
@@ -78,6 +78,8 @@ Compared to DreamArtist, DreamArtist++ is more stable with higher image quality
 * safetensors support
 * Controlnet (support train)
 * Min-SNR loss
+* Custom optimizer (Lion, DAdaptation, pytorch-optimizer, ...)
+* Custom lr scheduler
 
 ## Install
 
{hcpdiff-0.3.5.dist-info → hcpdiff-0.3.7.dist-info}/RECORD CHANGED
@@ -1,16 +1,16 @@
 hcpdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-hcpdiff/train_ac.py,sha256=03SQFub2-rSkVZqFBDoh7vAYP8VGqz9m7nvDHoP7lD8,26244
+hcpdiff/train_ac.py,sha256=1ZHWKryEy-1JPzmJsnoSzn47M9pKdbjDGxWMSNDLOYs,27108
 hcpdiff/train_ac_single.py,sha256=v35IWDwla9tunDksozwMBijas2aMZuhDofAKHX1RHzE,4785
 hcpdiff/train_colo.py,sha256=gApEmLYETlpMlngkBaKvb1GvwCH3EIqElZn1k8TSmCo,9319
-hcpdiff/visualizer.py,sha256=uNR9J6wiyJA_4fMRp7HVJxbTBzivU4C4zS9zHmbKVLc,7395
+hcpdiff/visualizer.py,sha256=6R3C-MdDJlMK3UCnw7is2jGfF18ens9rwaYYURKD1Ec,7544
 hcpdiff/ckpt_manager/__init__.py,sha256=AQOsLRmYOVXVw-yhzuYadiMKE3vUhvDr86srqOuhoW0,200
 hcpdiff/ckpt_manager/ckpt_pkl.py,sha256=WX-79vYZCeNnFixs7B3BDeo9659jUwVC9CZ5kVCjEpE,4173
 hcpdiff/ckpt_manager/ckpt_safetensor.py,sha256=UHrekZO9UfK9pa7lOjimbGxCHN65YX1NfC4w8N56rrM,1835
 hcpdiff/data/__init__.py,sha256=0enIUGK7Oo4k1FuSaLizWvPk1uI3QbvcPZuQlM9ZpaE,795
-hcpdiff/data/bucket.py,sha256=-p9Yr9ANMVPsksVmKmu8UunHkO0Q9SonBolPKZnrnGA,8554
+hcpdiff/data/bucket.py,sha256=5Wd5TRCZAhV6vLf2IGyPiV1lvLNrTC1yQGjxXn8iadY,8566
 hcpdiff/data/cond_pair_dataset.py,sha256=v12acr_m2l0k42yaXs-utyuvi8sqRSIuEOBG7Uqsll0,2214
 hcpdiff/data/pair_dataset.py,sha256=32H-wUWufXL7qtN5G1dfc2wQEBSEv70erisQzsecH1I,6104
-hcpdiff/data/utils.py,sha256=zPbf9VzJsieCCVY6nGU8YJzkyo9h2q0yydkcsVj7J8E,2615
+hcpdiff/data/utils.py,sha256=HpqzeiDOHT1sUWP0_Sm90W5l-6Xv67MWDChR3LeHL7Q,2588
 hcpdiff/loggers/__init__.py,sha256=mdVbijYrSxpQPC7WnSD6FKDPdc9NyE2_GueB1yGNJ6M,275
 hcpdiff/loggers/base_logger.py,sha256=E8L3hDuP_2QahUrtWPcKqSbr67Xtt5Qv47IOexFSPHc,1845
 hcpdiff/loggers/cli_logger.py,sha256=pzWjPr1M-7UleZJWW863Mu4StXY8rC1nR9tMYM16ulo,1276
@@ -30,7 +30,7 @@ hcpdiff/models/text_emb_ex.py,sha256=JKd4Dta-CU0YATBVos6MWvLu-nUYGx_Qdjig-2alkvQ
 hcpdiff/models/textencoder_ex.py,sha256=zI_SXbvzZvsjxQb4m9vC_c82Y7JVjktn7MsBnfmYRqU,5674
 hcpdiff/models/tokenizer_ex.py,sha256=lb6_k0BZMskXHnYUXliqA47tjCVY44OvqATUrGn7DhM,2365
 hcpdiff/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-hcpdiff/tools/convert_caption_txt2json.py,sha256=FGolVVlJpJ60MOLb7AqqI1HQXV_VKx56Nzv68vZAE0c,728
+hcpdiff/tools/convert_caption_txt2json.py,sha256=ToKhrfAcN4Vzxc2gTwQOr8RQwKkPBgvpRndUlHzJOOE,955
 hcpdiff/tools/create_embedding.py,sha256=tnPpeSY1vW1P-jsjH_XvnDK83_zGx5RLr34lD38WCz0,1870
 hcpdiff/tools/gen_from_ptlist.py,sha256=LyNyc7u1fWNX1f92AgVYH3w_ePhOXmAM9HEWWSB6_T4,1691
 hcpdiff/tools/init_proj.py,sha256=7W2bsD_Yt4q85s3-dG2hSatgqn5MQb0JpuUXwmNpSto,756
@@ -42,41 +42,42 @@ hcpdiff/utils/cfg_net_tools.py,sha256=vFVyg5_cY6ygCovyNp4rXnRRW3BRZy8J4AOYzNy0CK
 hcpdiff/utils/colo_utils.py,sha256=JyLUvVnISa48CnryNLrgVxMo-jxu2UhBq70eYPrkjuI,837
 hcpdiff/utils/ema.py,sha256=-ZEp0wsaG9yiH-Xy6KE4jj2VKFkaH0Nbz_u9RuTPyfo,2993
 hcpdiff/utils/img_size_tool.py,sha256=Yh3_jZnGyXjb_5fKc-GkCVbotWarTrr5B90A0GETv5I,8433
-hcpdiff/utils/net_utils.py,sha256=-eAsiM3IPcm1oU0TUh7TEqL_58lCJ6wvI5E8SgP-_Wg,7052
+hcpdiff/utils/net_utils.py,sha256=7Zg47SJw6BCBpl2Nd8E5taxcvnur1NyWSZld8mziLEs,7029
 hcpdiff/utils/utils.py,sha256=zv54E2DtCctHAr1fdod--68fSiDl0FrfYyposkrOF6E,2756
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/te_struct.txt,sha256=yL4_mZ1MFVYsXgZc2amL3HcfHhVU97huy0Kufd4Nm3c,9778
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/unet_struct.txt,sha256=tsyLnBNkBvOCgz9BguOkaD_WDgzDUSAOTREUh3lF1Ns,45190
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/infer/change_vae.yaml,sha256=vMqGIwHTNlI-8_UZ4GskH7ILOMRSRrxwlCRfmkkVVrw,174
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/infer/euler_a.yaml,sha256=_NVVhFhPLXNm1ywxxq6HXURAuSnPjutO9rx7zPFfZl8,646
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/infer/img2img.yaml,sha256=UxOxUMShKYk6uR5w5aTRITINwkfleGrp-YQUJcmdD64,1608
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml,sha256=ucOHq0TD6CXWmrBKm1fzixH8jum3a092Iqgu5X3PpUA,960
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/infer/text2img.yaml,sha256=FlsXEhjAYaByOOHNEp_SloVVkVv2Rq9ANoV6s2eWFmk,1216
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml,sha256=xHHkMRklE4uH15WUBAy-BgzCcZp6hOKApEofIrcME80,1542
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/infer/webui_model_infer.yaml,sha256=A9id_rgF2MO9Hu2yFYZeqaWK2PmVHN4QC67c1R6fDPk,1330
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml,sha256=paoKVoYqxmMc4P65qVWTIBnPAcKvaWtWQSD9b6WrnTg,426
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/train_base.yaml,sha256=U1snV7Nc31nE4fNuBtaK8fZJknrpsQ63vfH-XGkslEM,2849
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/tuning_base.yaml,sha256=3sDKNvwe1ztqpQnE5kTvVXDNARdzz9_v7dhhyRZU6-E,984
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml,sha256=0ILjxlNRCrsadKj8d8md4hlHBAfpj0EUp-kQ4d7yBUU,2412
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml,sha256=m4VtlXiOxHRu55JK07mKLfpKyG_50vaK3wW1rTDGAQc,3404
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml,sha256=GEoSf2lsOdL2rkE5vpjysL6dc6FBcpTnGhB2dVJnTkU,1431
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml,sha256=88jbKuEYmvCaZxtBtVZk4C4YZawPlRg08xWJd3xcyB4,2219
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml,sha256=ZGf-8oz2-X0ELTihy9TF2qurn1mCB9auMOYt7hbP3Ak,1224
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml,sha256=XeDgjcjmor_gY4Qsy86et0tYm9VR8RVAJaoW-jF3UvQ,1141
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml,sha256=_FV-ixk5UacI-opd7PJNO1rZs8Be0ofBGy9QAl_E8vg,1263
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/examples/locon.yaml,sha256=1v3RxbwuG5iZmljcSqycq8IYvuTVC7pNlzmGb1RO2A8,1471
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml,sha256=0i1cl31v0n91Jmcq8LYuIr5ugvZ7pxmZNw0w3udkPtY,1310
-hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml,sha256=tdYy3TIS8AHb09y0wftnxCBQUUgd7OFJOpABRLd1C1s,1359
-hcpdiff-0.3.5.data/data/hcpdiff/prompt_tuning_template/caption.txt,sha256=BFF5IJffv3lvSo6Axoc9tNLeby0T84zYSFN49wHqgWc,14
-hcpdiff-0.3.5.data/data/hcpdiff/prompt_tuning_template/name.txt,sha256=M5ZazRDxHC0Tj4HDe8iZ2BDklywaiRRk1UG0G_Sf01g,14
-hcpdiff-0.3.5.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt,sha256=3hP46lbjvV480jkaHyjJ6kOh-RBuBt-qJ56XkdfFUts,55
-hcpdiff-0.3.5.data/data/hcpdiff/prompt_tuning_template/name_caption.txt,sha256=r0J196gVkTGguSrkClVdQVrOi0m7UYHmjE-JPNS1eWg,25
-hcpdiff-0.3.5.data/data/hcpdiff/prompt_tuning_template/object.txt,sha256=AAQFaMPewZpoCbz_GLqqb3G3BR9pKAXCSj3PBec2iNU,889
-hcpdiff-0.3.5.data/data/hcpdiff/prompt_tuning_template/object_caption.txt,sha256=KykccgokzxIxIvJqUcqKZLNlmuWFQ2m__7eeTGRAeCo,1186
-hcpdiff-0.3.5.data/data/hcpdiff/prompt_tuning_template/style.txt,sha256=JiYeLSWrDlDZ1xCiyFaoQUNDYsVXJwUPNCcW4VDlj2w,872
-hcpdiff-0.3.5.data/data/hcpdiff/prompt_tuning_template/style_caption.txt,sha256=p3LUkXf_AxO9U5fsdNxDoSYjvTkF1LxPwUX3KRio8P0,1081
-hcpdiff-0.3.5.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-hcpdiff-0.3.5.dist-info/METADATA,sha256=ZZhmR0gtimlND5bCGzX0W9depz7KismiyLleIGKfz3M,6219
-hcpdiff-0.3.5.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
-hcpdiff-0.3.5.dist-info/entry_points.txt,sha256=TZu3cbc9T3KmoNHx6JfTdHV_Ff89yn4XD7DoluIjZWk,57
-hcpdiff-0.3.5.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
-hcpdiff-0.3.5.dist-info/RECORD,,
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/te_struct.txt,sha256=yL4_mZ1MFVYsXgZc2amL3HcfHhVU97huy0Kufd4Nm3c,9778
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/unet_struct.txt,sha256=tsyLnBNkBvOCgz9BguOkaD_WDgzDUSAOTREUh3lF1Ns,45190
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/infer/change_vae.yaml,sha256=vMqGIwHTNlI-8_UZ4GskH7ILOMRSRrxwlCRfmkkVVrw,174
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/infer/euler_a.yaml,sha256=_NVVhFhPLXNm1ywxxq6HXURAuSnPjutO9rx7zPFfZl8,646
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/infer/img2img.yaml,sha256=UxOxUMShKYk6uR5w5aTRITINwkfleGrp-YQUJcmdD64,1608
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml,sha256=ucOHq0TD6CXWmrBKm1fzixH8jum3a092Iqgu5X3PpUA,960
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/infer/text2img.yaml,sha256=FlsXEhjAYaByOOHNEp_SloVVkVv2Rq9ANoV6s2eWFmk,1216
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml,sha256=xHHkMRklE4uH15WUBAy-BgzCcZp6hOKApEofIrcME80,1542
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/infer/webui_model_infer.yaml,sha256=A9id_rgF2MO9Hu2yFYZeqaWK2PmVHN4QC67c1R6fDPk,1330
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml,sha256=paoKVoYqxmMc4P65qVWTIBnPAcKvaWtWQSD9b6WrnTg,426
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/train_base.yaml,sha256=xxGXHZ5LiOzMzk59N6zCRWbGfsPnkcQCvYNQtHLLM70,2899
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/tuning_base.yaml,sha256=3sDKNvwe1ztqpQnE5kTvVXDNARdzz9_v7dhhyRZU6-E,984
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml,sha256=0ILjxlNRCrsadKj8d8md4hlHBAfpj0EUp-kQ4d7yBUU,2412
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml,sha256=fIiLT2bslMHNMm2YVi401-iuFVdvp7ORI9gDT1M-4FY,3667
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml,sha256=GEoSf2lsOdL2rkE5vpjysL6dc6FBcpTnGhB2dVJnTkU,1431
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml,sha256=88jbKuEYmvCaZxtBtVZk4C4YZawPlRg08xWJd3xcyB4,2219
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml,sha256=06wR6KFhhzYJ9Pn2LFI1s5iUIwaZX229E7OrzqT_CsA,361
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml,sha256=ZGf-8oz2-X0ELTihy9TF2qurn1mCB9auMOYt7hbP3Ak,1224
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml,sha256=XeDgjcjmor_gY4Qsy86et0tYm9VR8RVAJaoW-jF3UvQ,1141
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml,sha256=_FV-ixk5UacI-opd7PJNO1rZs8Be0ofBGy9QAl_E8vg,1263
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/locon.yaml,sha256=1v3RxbwuG5iZmljcSqycq8IYvuTVC7pNlzmGb1RO2A8,1471
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml,sha256=0i1cl31v0n91Jmcq8LYuIr5ugvZ7pxmZNw0w3udkPtY,1310
+hcpdiff-0.3.7.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml,sha256=ZftPXDHzkHOQ_6NtvZ1dbSdHw0-RGyr4B7uE8TAcdOE,149
+hcpdiff-0.3.7.data/data/hcpdiff/prompt_tuning_template/caption.txt,sha256=BFF5IJffv3lvSo6Axoc9tNLeby0T84zYSFN49wHqgWc,14
+hcpdiff-0.3.7.data/data/hcpdiff/prompt_tuning_template/name.txt,sha256=M5ZazRDxHC0Tj4HDe8iZ2BDklywaiRRk1UG0G_Sf01g,14
+hcpdiff-0.3.7.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt,sha256=3hP46lbjvV480jkaHyjJ6kOh-RBuBt-qJ56XkdfFUts,55
+hcpdiff-0.3.7.data/data/hcpdiff/prompt_tuning_template/name_caption.txt,sha256=r0J196gVkTGguSrkClVdQVrOi0m7UYHmjE-JPNS1eWg,25
+hcpdiff-0.3.7.data/data/hcpdiff/prompt_tuning_template/object.txt,sha256=AAQFaMPewZpoCbz_GLqqb3G3BR9pKAXCSj3PBec2iNU,889
+hcpdiff-0.3.7.data/data/hcpdiff/prompt_tuning_template/object_caption.txt,sha256=KykccgokzxIxIvJqUcqKZLNlmuWFQ2m__7eeTGRAeCo,1186
+hcpdiff-0.3.7.data/data/hcpdiff/prompt_tuning_template/style.txt,sha256=JiYeLSWrDlDZ1xCiyFaoQUNDYsVXJwUPNCcW4VDlj2w,872
+hcpdiff-0.3.7.data/data/hcpdiff/prompt_tuning_template/style_caption.txt,sha256=p3LUkXf_AxO9U5fsdNxDoSYjvTkF1LxPwUX3KRio8P0,1081
+hcpdiff-0.3.7.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+hcpdiff-0.3.7.dist-info/METADATA,sha256=MfXHvfiw7MJmu95uMfkfeONiBN1f--kVZhchBgFSZPU,6304
+hcpdiff-0.3.7.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
+hcpdiff-0.3.7.dist-info/entry_points.txt,sha256=TZu3cbc9T3KmoNHx6JfTdHV_Ff89yn4XD7DoluIjZWk,57
+hcpdiff-0.3.7.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
+hcpdiff-0.3.7.dist-info/RECORD,,
hcpdiff-0.3.5.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml DELETED
@@ -1,58 +0,0 @@
-_base_: [cfgs/train/train_base.yaml, cfgs/train/tuning_base.yaml]
-
-unet:
-  -
-    lr: 1e-6
-    layers:
-      - '' # fine-tuning all layers in unet
-
-## fine-tuning text-encoder
-#text_encoder:
-#  - lr: 1e-6
-#    layers:
-#      - ''
-
-tokenizer_pt:
-  train: null
-
-train:
-  gradient_accumulation_steps: 1
-  save_step: 100
-
-  loss:
-    criterion: # min SNR loss
-      _target_: hcpdiff.loss.MinSNRLoss
-      gamma: 2.0
-
-  scheduler:
-    name: 'constant_with_warmup'
-    num_warmup_steps: 50
-    num_training_steps: 600
-
-model:
-  pretrained_model_name_or_path: 'runwayml/stable-diffusion-v1-5'
-  tokenizer_repeats: 1
-  ema_unet: 0
-  ema_text_encoder: 0
-
-data:
-  dataset1:
-    batch_size: 4
-    cache_latents: True
-
-    source:
-      data_source1:
-        img_root: 'imgs/'
-        prompt_template: 'prompt_tuning_template/object.txt'
-        caption_file: null # path to image captions (file_words)
-        tag_transforms:
-          transforms:
-            - _target_: hcpdiff.utils.caption_tools.TagShuffle
-            - _target_: hcpdiff.utils.caption_tools.TagDropout
-              p: 0.1
-            - _target_: hcpdiff.utils.caption_tools.TemplateFill
-              word_names: { }
-    bucket:
-      _target_: hcpdiff.data.bucket.RatioBucket.from_files # aspect ratio bucket
-      target_area: {_target_: "builtins.eval", _args_: ['512*512']}
-      num_bucket: 5