hcpdiff 0.3.5__tar.gz → 0.3.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. {hcpdiff-0.3.5/hcpdiff.egg-info → hcpdiff-0.3.7}/PKG-INFO +3 -1
  2. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/README.md +2 -0
  3. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/examples/DreamArtist++.yaml +28 -13
  4. hcpdiff-0.3.7/cfgs/train/examples/Lion_optimizer.yaml +15 -0
  5. hcpdiff-0.3.7/cfgs/train/examples/min_snr.yaml +7 -0
  6. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/train_base.yaml +6 -2
  7. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/data/bucket.py +1 -1
  8. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/data/utils.py +0 -1
  9. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/tools/convert_caption_txt2json.py +9 -5
  10. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/train_ac.py +26 -9
  11. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/utils/net_utils.py +19 -17
  12. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/visualizer.py +6 -5
  13. {hcpdiff-0.3.5 → hcpdiff-0.3.7/hcpdiff.egg-info}/PKG-INFO +3 -1
  14. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff.egg-info/SOURCES.txt +1 -0
  15. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/setup.py +1 -1
  16. hcpdiff-0.3.5/cfgs/train/examples/min_snr.yaml +0 -58
  17. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/LICENSE +0 -0
  18. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/infer/change_vae.yaml +0 -0
  19. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/infer/euler_a.yaml +0 -0
  20. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/infer/img2img.yaml +0 -0
  21. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/infer/img2img_controlnet.yaml +0 -0
  22. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/infer/text2img.yaml +0 -0
  23. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/infer/text2img_DA++.yaml +0 -0
  24. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/infer/webui_model_infer.yaml +0 -0
  25. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/plugins/plugin_controlnet.yaml +0 -0
  26. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/te_struct.txt +0 -0
  27. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/examples/CustomDiffusion.yaml +0 -0
  28. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/examples/DreamArtist.yaml +0 -0
  29. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/examples/DreamBooth.yaml +0 -0
  30. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/examples/TextualInversion.yaml +0 -0
  31. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/examples/controlnet.yaml +0 -0
  32. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/examples/fine-tuning.yaml +0 -0
  33. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/examples/locon.yaml +0 -0
  34. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/examples/lora_conventional.yaml +0 -0
  35. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/tuning_base.yaml +0 -0
  36. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/unet_struct.txt +0 -0
  37. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/__init__.py +0 -0
  38. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/ckpt_manager/__init__.py +0 -0
  39. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/ckpt_manager/ckpt_pkl.py +0 -0
  40. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -0
  41. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/data/__init__.py +0 -0
  42. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/data/cond_pair_dataset.py +0 -0
  43. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/data/pair_dataset.py +0 -0
  44. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/loggers/__init__.py +0 -0
  45. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/loggers/base_logger.py +0 -0
  46. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/loggers/cli_logger.py +0 -0
  47. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/loggers/tensorboard_logger.py +0 -0
  48. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/loggers/wandb_logger.py +0 -0
  49. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/loss/__init__.py +0 -0
  50. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/loss/min_snr_loss.py +0 -0
  51. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/loss/mse_loss.py +0 -0
  52. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/models/__init__.py +0 -0
  53. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/models/cfg_context.py +0 -0
  54. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/models/controlnet.py +0 -0
  55. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/models/layers.py +0 -0
  56. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/models/lora_base.py +0 -0
  57. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/models/lora_layers.py +0 -0
  58. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/models/plugin.py +0 -0
  59. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/models/text_emb_ex.py +0 -0
  60. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/models/textencoder_ex.py +0 -0
  61. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/models/tokenizer_ex.py +0 -0
  62. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/tools/__init__.py +0 -0
  63. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/tools/create_embedding.py +0 -0
  64. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/tools/gen_from_ptlist.py +0 -0
  65. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/tools/init_proj.py +0 -0
  66. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/tools/lora_convert.py +0 -0
  67. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/tools/sd2diffusers.py +0 -0
  68. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/train_ac_single.py +0 -0
  69. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/train_colo.py +0 -0
  70. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/utils/__init__.py +0 -0
  71. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/utils/caption_tools.py +0 -0
  72. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/utils/cfg_net_tools.py +0 -0
  73. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/utils/colo_utils.py +0 -0
  74. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/utils/ema.py +0 -0
  75. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/utils/img_size_tool.py +0 -0
  76. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/utils/utils.py +0 -0
  77. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff.egg-info/dependency_links.txt +0 -0
  78. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff.egg-info/entry_points.txt +0 -0
  79. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff.egg-info/requires.txt +0 -0
  80. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff.egg-info/top_level.txt +0 -0
  81. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/prompt_tuning_template/caption.txt +0 -0
  82. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/prompt_tuning_template/name.txt +0 -0
  83. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/prompt_tuning_template/name_2pt_caption.txt +0 -0
  84. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/prompt_tuning_template/name_caption.txt +0 -0
  85. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/prompt_tuning_template/object.txt +0 -0
  86. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/prompt_tuning_template/object_caption.txt +0 -0
  87. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/prompt_tuning_template/style.txt +0 -0
  88. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/prompt_tuning_template/style_caption.txt +0 -0
  89. {hcpdiff-0.3.5 → hcpdiff-0.3.7}/setup.cfg +0 -0
{hcpdiff-0.3.5/hcpdiff.egg-info → hcpdiff-0.3.7}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: hcpdiff
-Version: 0.3.5
+Version: 0.3.7
 Summary: A universal Stable-Diffusion toolbox
 Home-page: https://github.com/7eu7d7/HCP-Diffusion
 Author: Ziyi Dong
@@ -58,6 +58,8 @@ Compared to DreamArtist, DreamArtist++ is more stable with higher image quality
 * safetensors support
 * Controlnet (support train)
 * Min-SNR loss
+* Custom optimizer (Lion, DAdaptation, pytorch-optimizer, ...)
+* Custom lr scheduler
 
 ## Install
 
{hcpdiff-0.3.5 → hcpdiff-0.3.7}/README.md
@@ -39,6 +39,8 @@ Compared to DreamArtist, DreamArtist++ is more stable with higher image quality
 * safetensors support
 * Controlnet (support train)
 * Min-SNR loss
+* Custom optimizer (Lion, DAdaptation, pytorch-optimizer, ...)
+* Custom lr scheduler
 
 ## Install
 
{hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/examples/DreamArtist++.yaml
@@ -6,15 +6,21 @@ lora_unet:
   - lr: 1e-4
     rank: 0.01875
     branch: p
+    dropout: 0.1
     layers:
-      - 're:.*\.attn.?$'
-      #- 're:.*\.ff\.net\.0$' # Increases fitness, but potentially reduces controllability
+      - 're:.*\.to_k$'
+      - 're:.*\.to_v$'
+      - 're:.*\.ff$'
+      #- 're:.*\.attn.?$' # Increases fitness, but potentially reduces controllability
   - lr: 4e-5 # Low negative unet lr prevents image collapse
     rank: 0.01875
     branch: n
+    dropout: 0.1
     layers:
-      - 're:.*\.attn.?$'
-      #- 're:.*\.ff\.net\.0$' # Increases fitness, but potentially reduces controllability
+      - 're:.*\.to_k$'
+      - 're:.*\.to_v$'
+      - 're:.*\.ff$'
+      #- 're:.*\.attn.?$' # Increases fitness, but potentially reduces controllability
 #  - lr: 1e-4
 #    rank: 0.01875
 #    type: p
@@ -27,23 +33,25 @@ lora_unet:
 #      - 're:.*\.resnets$' # Increases fitness, but potentially reduces controllability and change style
 
 lora_text_encoder:
-  - lr: 1e-5
-    rank: 0.01
+  - lr: 2e-5
+    rank: 2
     branch: p
+    dropout: 0.1
     layers:
       - 're:.*self_attn$'
       - 're:.*mlp$'
-  - lr: 1e-5
-    rank: 0.01
+  - lr: 2e-5
+    rank: 2
     branch: n
+    dropout: 0.1
     layers:
       - 're:.*self_attn$'
      - 're:.*mlp$'
 
 tokenizer_pt:
   train: # prompt tuning embeddings
-    - { name: 'pt-botdog1', lr: 0.003 }
-    - { name: 'pt-botdog1-neg', lr: 0.003 }
+    - { name: 'pt-botdog1', lr: 0.0025 }
+    - { name: 'pt-botdog1-neg', lr: 0.0025 }
 
 train:
   gradient_accumulation_steps: 1
@@ -52,13 +60,19 @@ train:
   #cfg_scale: '1.0-3.0:cos' # dynamic CFG with timestamp
   cfg_scale: '3.0'
 
+  loss:
+    criterion: # min SNR loss
+      _target_: hcpdiff.loss.MinSNRLoss
+      gamma: 2.0
+
   scheduler:
-    name: 'constant_with_warmup'
-    num_warmup_steps: 50
+    name: one_cycle
+    num_warmup_steps: 200
     num_training_steps: 1000
+    scheduler_kwargs: { }
 
   scheduler_pt:
-    name: 'one_cycle'
+    name: one_cycle
     num_warmup_steps: 200
     num_training_steps: 1000
     scheduler_kwargs: {}
@@ -68,6 +82,7 @@ model:
   tokenizer_repeats: 1
   ema_unet: 0
   ema_text_encoder: 0
+  clip_skip: 0
 
 data:
   dataset1:
hcpdiff-0.3.7/cfgs/train/examples/Lion_optimizer.yaml (new file)
@@ -0,0 +1,15 @@
+_base_: [cfgs/train/examples/fine-tuning.yaml]
+
+# Install: pip install lion-pytorch
+
+train:
+  optimizer:
+    _target_: lion_pytorch.Lion
+    _partial_: True
+    weight_decay: 1e-2
+    #use_triton: True # set this to True to use cuda kernel w/ Triton lang (Tillet et al)
+
+  optimizer_pt:
+    _target_: lion_pytorch.Lion
+    _partial_: True
+    weight_decay: 1e-3
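The _partial_: True flag makes the config instantiate to a functools.partial of the optimizer class rather than a ready optimizer, so the trainer can inject the trainable parameters and learning rate later (see the new isinstance(cfg_opt, partial) branch in hcpdiff/train_ac.py below). A minimal sketch of that mechanism, assuming lion-pytorch is installed; it mirrors the code path rather than reproducing it:

    from functools import partial

    import torch
    from lion_pytorch import Lion

    # Roughly what `_target_: lion_pytorch.Lion, _partial_: True, weight_decay: 1e-2` instantiates to
    cfg_opt = partial(Lion, weight_decay=1e-2)

    # The trainer later supplies the trainable parameters and the learning rate
    model = torch.nn.Linear(8, 8)
    optimizer = cfg_opt(params=model.parameters(), lr=1e-4)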
hcpdiff-0.3.7/cfgs/train/examples/min_snr.yaml (new file)
@@ -0,0 +1,7 @@
+_base_: [cfgs/train/examples/fine-tuning.yaml]
+
+train:
+  loss:
+    criterion: # min SNR loss
+      _target_: hcpdiff.loss.MinSNRLoss
+      gamma: 2.0
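For reference, Min-SNR weighting scales the per-timestep loss by min(SNR_t, gamma)/SNR_t, so nearly noise-free timesteps with very large SNR no longer dominate the objective. A rough sketch of that weighting under a DDPM-style noise schedule; it illustrates the idea and is not hcpdiff's MinSNRLoss implementation:

    import torch

    def min_snr_weight(alphas_cumprod: torch.Tensor, timesteps: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t)
        alpha_bar = alphas_cumprod[timesteps]
        snr = alpha_bar / (1.0 - alpha_bar)
        # clamp at gamma so high-SNR (low-noise) steps are down-weighted
        return torch.clamp(snr, max=gamma) / snr

    # usage: weight the per-sample MSE on the predicted noise
    # loss = (min_snr_weight(alphas_cumprod, t, gamma=2.0) * mse_per_sample).mean()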
{hcpdiff-0.3.5 → hcpdiff-0.3.7}/cfgs/train/train_base.yaml
@@ -10,6 +10,7 @@ vis_info:
   negative_prompt: ''
 
 train:
+  train_steps: 1000
   gradient_accumulation_steps: 1
   workers: 4
   max_grad_norm: 1.0
@@ -35,7 +36,10 @@ train:
   optimizer:
     type: adamw
     weight_decay: 1e-3
-    weight_decay_pt: 5e-4
+
+  optimizer_pt:
+    type: adamw
+    weight_decay: 5e-4
 
   scale_lr: True # auto scale lr with total batch size
   scheduler:
@@ -58,7 +62,7 @@ model:
   revision: null
   pretrained_model_name_or_path: null
   tokenizer_name: null
-  tokenizer_repeats: 3
+  tokenizer_repeats: 2
   enable_xformers: True
   gradient_checkpointing: True
   ema_unet: 0 # 0 to disable
{hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/data/bucket.py
@@ -194,7 +194,7 @@ class RatioBucket(BaseBucket):
             rs.shuffle(x)
 
         # shuffle of batches
-        bucket_list = np.hstack(bucket_list).reshape(-1, self.bs)
+        bucket_list = np.hstack(bucket_list).reshape(-1, self.bs).astype(int)
         rs.shuffle(bucket_list)
 
         self.idx_arb = bucket_list.reshape(-1)
{hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/data/utils.py
@@ -72,7 +72,6 @@ def collate_fn_ft(batch):
 
 class CycleData():
     def __init__(self, data_loader):
-        print(data_loader)
         self.data_loader = data_loader
 
     def __iter__(self):
{hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/tools/convert_caption_txt2json.py
@@ -6,6 +6,7 @@ from hcpdiff.utils.img_size_tool import types_support
 
 parser = argparse.ArgumentParser(description='Stable Diffusion Training')
 parser.add_argument('--data_root', type=str, default='')
+parser.add_argument('--with_imgs', action="store_true")
 args = parser.parse_args()
 
 
@@ -16,10 +17,13 @@ def get_txt_caption(path):
 
 captions = {}
 for file in os.listdir(args.data_root):
-    ext_idx = file.rfind('.')
-    file_name = file[:ext_idx]
-    if file[ext_idx + 1:] in types_support:
-        captions[file] = get_txt_caption(os.path.join(args.data_root, f'{file_name}.txt'))
+    file_name, file_ext = file.rsplit('.', 1)
+    if args.with_imgs:
+        if file_ext in types_support:
+            captions[file] = get_txt_caption(os.path.join(args.data_root, f'{file_name}.txt'))
+    else:
+        if file_ext == 'txt':
+            captions[file] = get_txt_caption(os.path.join(args.data_root, f'{file_name}.txt'))
 
 with open(os.path.join(args.data_root, f'image_captions.json'), "w", encoding='utf8') as f:
-    json.dump(captions, f)
+    json.dump(captions, f, indent=2, ensure_ascii=False)
{hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/train_ac.py
@@ -28,6 +28,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel
 from diffusers.utils.import_utils import is_xformers_available
 from omegaconf import OmegaConf
 from transformers import AutoTokenizer
+from functools import partial
 
 from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
 from hcpdiff.data import RatioBucket, DataGroup, collate_fn_ft
@@ -329,12 +330,16 @@ class Trainer:
         # set optimizer
         parameters, parameters_pt = self.get_param_group_train()
 
-        cfg_opt = self.cfgs.train.optimizer
         if len(parameters)>0: # do fine-tuning
+            cfg_opt = self.cfgs.train.optimizer
             if self.cfgs.train.scale_lr:
                 self.scale_lr(parameters)
 
-            if cfg_opt.type == 'adamw_8bit':
+            if isinstance(cfg_opt, partial):
+                if 'type' in cfg_opt.keywords:
+                    del cfg_opt.keywords['type']
+                self.optimizer = cfg_opt(params=parameters, lr=self.lr)
+            elif cfg_opt.type == 'adamw_8bit':
                 import bitsandbytes as bnb
                 self.optimizer = bnb.optim.AdamW8bit(params=parameters, lr=self.lr, weight_decay=cfg_opt.weight_decay)
             elif cfg_opt.type == 'deepspeed' and self.accelerator.state.deepspeed_plugin is not None:
@@ -343,23 +348,35 @@ class Trainer:
             elif cfg_opt.type == 'adamw':
                 self.optimizer = torch.optim.AdamW(params=parameters, lr=self.lr, weight_decay=cfg_opt.weight_decay)
             else:
-                self.optimizer = cfg_opt.optimizer.opt(parameters, lr=self.lr)
+                raise NotImplementedError(f'Unknown optimizer {cfg_opt.type}')
 
-            self.lr_scheduler = get_scheduler(optimizer=self.optimizer, **self.cfgs.train.scheduler)
+            if isinstance(self.cfgs.train.scheduler, partial):
+                self.lr_scheduler = self.cfgs.train.scheduler(optimizer=self.optimizer)
+            else:
+                self.lr_scheduler = get_scheduler(optimizer=self.optimizer, **self.cfgs.train.scheduler)
 
         if len(parameters_pt)>0: # do prompt-tuning
+            cfg_opt_pt = self.cfgs.train.optimizer_pt
             if self.cfgs.train.scale_lr_pt:
                 self.scale_lr(parameters_pt)
+            if isinstance(cfg_opt_pt, partial):
+                if 'type' in cfg_opt_pt.keywords:
+                    del cfg_opt_pt.keywords['type']
+                self.optimizer_pt = cfg_opt_pt(params=parameters_pt, lr=self.lr)
+            else:
+                self.optimizer_pt = torch.optim.AdamW(params=parameters_pt, lr=self.lr, weight_decay=cfg_opt_pt.weight_decay)
 
-            self.optimizer_pt = torch.optim.AdamW(params=parameters_pt, lr=self.lr, weight_decay=cfg_opt.weight_decay_pt)
-            self.lr_scheduler_pt = get_scheduler(optimizer=self.optimizer_pt, **self.cfgs.train.scheduler_pt)
+            if isinstance(self.cfgs.train.scheduler_pt, partial):
+                self.lr_scheduler_pt = self.cfgs.train.scheduler_pt(optimizer=self.optimizer_pt)
+            else:
+                self.lr_scheduler_pt = get_scheduler(optimizer=self.optimizer_pt, **self.cfgs.train.scheduler_pt)
 
     def train(self):
         total_batch_size = sum(self.batch_size_list)*self.world_size*self.cfgs.train.gradient_accumulation_steps
 
         self.loggers.info("***** Running training *****")
         self.loggers.info(f" Num batches each epoch = {len(self.train_loader_group.loader_list[0])}")
-        self.loggers.info(f" Num Steps = {self.cfgs.train.scheduler.num_training_steps}")
+        self.loggers.info(f" Num Steps = {self.cfgs.train.train_steps}")
         self.loggers.info(f" Instantaneous batch size per device = {sum(self.batch_size_list)}")
         self.loggers.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
         self.loggers.info(f" Gradient Accumulation steps = {self.cfgs.train.gradient_accumulation_steps}")
@@ -380,13 +397,13 @@ class Trainer:
                 lr_model = self.lr_scheduler.get_last_lr()[0] if hasattr(self, 'lr_scheduler') else 0.
                 lr_word = self.lr_scheduler_pt.get_last_lr()[0] if hasattr(self, 'lr_scheduler_pt') else 0.
                 self.loggers.log(datas={
-                    'Step':{'format':'[{}/{}]', 'data':[self.global_step, self.cfgs.train.scheduler.num_training_steps]},
+                    'Step':{'format':'[{}/{}]', 'data':[self.global_step, self.cfgs.train.train_steps]},
                    'LR_model':{'format':'{:.2e}', 'data':[lr_model]},
                    'LR_word':{'format':'{:.2e}', 'data':[lr_word]},
                    'Loss':{'format':'{:.5f}', 'data':[loss_sum.mean()]},
                 }, step=self.global_step)
 
-                if self.global_step>=self.cfgs.train.scheduler.num_training_steps:
+                if self.global_step>=self.cfgs.train.train_steps:
                     break
 
         self.wait_for_everyone()
{hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/utils/net_utils.py
@@ -1,12 +1,11 @@
 import os
-from typing import Optional, Union, Tuple, Dict, Callable
+from typing import Optional, Union
 
 import torch
+from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, Optimizer
 from torch import nn
 from torch.optim import lr_scheduler
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, Optimizer
 from transformers import PretrainedConfig
-from collections import OrderedDict
 
 class TEUnetWrapper(nn.Module):
     def __init__(self, unet, TE):
@@ -33,7 +32,7 @@ def get_scheduler(
     optimizer: Optimizer,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
-    scheduler_kwargs = {},
+    scheduler_kwargs={},
 ):
     """
     Unified API to get any scheduler from its name.
@@ -64,11 +63,11 @@ def get_scheduler(
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
 
-    #One Cycle for super convergence
-    if name=='one_cycle':
+    # One Cycle for super convergence
+    if name == 'one_cycle':
         scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=[x['lr'] for x in optimizer.state_dict()['param_groups']],
-                                            steps_per_epoch=num_training_steps, epochs=1,
-                                            pct_start=num_warmup_steps/num_training_steps, **scheduler_kwargs)
+                    steps_per_epoch=num_training_steps, epochs=1,
+                    pct_start=num_warmup_steps/num_training_steps, **scheduler_kwargs)
         return scheduler
 
     name = SchedulerType(name)
@@ -117,44 +116,46 @@ def remove_all_hooks(model: nn.Module) -> None:
             child._backward_hooks.clear()
 
 def remove_layers(model: nn.Module, layer_class):
-    named_modules = {k: v for k, v in model.named_modules()}
-    for k,v in named_modules.items():
+    named_modules = {k:v for k, v in model.named_modules()}
+    for k, v in named_modules.items():
         if isinstance(v, layer_class):
             parent, name = named_modules[k.rsplit('.', 1)]
             delattr(parent, name)
             del v
 
 def load_emb(path):
-    emb=torch.load(path, map_location='cpu')['string_to_param']['*']
+    emb = torch.load(path, map_location='cpu')['string_to_param']['*']
     emb.requires_grad_(False)
     return emb
 
-def save_emb(path, emb:torch.Tensor, replace=False):
+def save_emb(path, emb: torch.Tensor, replace=False):
     name = os.path.basename(path)
     if os.path.exists(path) and not replace:
         raise FileExistsError(f'embedding "{name}" already exist.')
-    name=name[:name.rfind('.')]
+    name = name[:name.rfind('.')]
     torch.save({'string_to_param':{'*':emb}, 'name':name}, path)
 
-
 def hook_compile(model):
     named_modules = {k:v for k, v in model.named_modules()}
 
     for name, block in named_modules.items():
         if len(block._forward_hooks)>0:
-            for hook in block._forward_hooks.values(): # executed in order, front to back
+            for hook in block._forward_hooks.values():  # executed in order, front to back
                 old_forward = block.forward
+
                 def new_forward(*args, **kwargs):
                     result = old_forward(*args, **kwargs)
                     hook_result = hook(block, args, result)
                     if hook_result is not None:
                         result = hook_result
                     return result
+
                 block.forward = new_forward
 
         if len(block._forward_pre_hooks)>0:
-            for hook in list(block._forward_pre_hooks.values())[::-1]: # executed in order, front to back
+            for hook in list(block._forward_pre_hooks.values())[::-1]:  # executed in order, front to back
                 old_forward = block.forward
+
                 def new_forward(*args, **kwargs):
                     result = hook(block, args)
                     if result is not None:
@@ -163,5 +164,6 @@ def hook_compile(model):
                     else:
                         result = args
                     return old_forward(*result, **kwargs)
+
                 block.forward = new_forward
-    remove_all_hooks(model)
+    remove_all_hooks(model)
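For context, the one_cycle branch above wraps torch's OneCycleLR with a single "epoch" spanning the whole run: max_lr is read from each param group and the warmup step count becomes pct_start. A standalone sketch of the same construction with illustrative values:

    import torch
    from torch.optim import lr_scheduler

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    num_warmup_steps, num_training_steps = 200, 1000
    scheduler = lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=[g['lr'] for g in optimizer.state_dict()['param_groups']],
        steps_per_epoch=num_training_steps, epochs=1,  # one "epoch" = the whole training run
        pct_start=num_warmup_steps / num_training_steps,
    )

    for _ in range(num_training_steps):
        optimizer.step()
        scheduler.step()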
{hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff/visualizer.py
@@ -124,12 +124,9 @@ class Visualizer:
         images = self.pipe(prompt_embeds=emb_p, negative_prompt_embeds=emb_n, **kwargs).images
         return images
 
-    @torch.no_grad()
-    def vis_to_dir(self, root, prompt, negative_prompt='', save_cfg=True, **kwargs):
+    def save_images(self, images, root, prompt, negative_prompt='', save_cfg=True):
         os.makedirs(root, exist_ok=True)
-        num_img_exist = max([int(x.split('-',1)[0]) for x in os.listdir(root) if x.rsplit('.', 1)[-1] in types_support])+1
-
-        images = self.vis_images(prompt, negative_prompt, **kwargs)
+        num_img_exist = max([int(x.split('-', 1)[0]) for x in os.listdir(root) if x.rsplit('.', 1)[-1] in types_support]) + 1
 
         for p, pn, img in zip(prompt, negative_prompt, images):
             img.save(os.path.join(root, f"{num_img_exist}-{to_validate_file(prompt[0])}.{self.cfgs.save.image_type}"), quality=self.cfgs.save.quality)
@@ -139,6 +136,10 @@ class Visualizer:
                 f.write(OmegaConf.to_yaml(self.cfgs_raw))
             num_img_exist += 1
 
+    def vis_to_dir(self, root, prompt, negative_prompt='', save_cfg=True, **kwargs):
+        images = self.vis_images(prompt, negative_prompt, **kwargs)
+        self.save_images(images, root, prompt, negative_prompt, save_cfg=save_cfg)
+
     def show_latent(self, prompt, negative_prompt='', **kwargs):
         emb_n, emb_p = self.te_hook.encode_prompt_to_emb(negative_prompt + prompt).chunk(2)
         emb_p = self.te_hook.mult_attn(emb_p, self.token_ex.parse_attn_mult(prompt))
{hcpdiff-0.3.5 → hcpdiff-0.3.7/hcpdiff.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: hcpdiff
-Version: 0.3.5
+Version: 0.3.7
 Summary: A universal Stable-Diffusion toolbox
 Home-page: https://github.com/7eu7d7/HCP-Diffusion
 Author: Ziyi Dong
@@ -58,6 +58,8 @@ Compared to DreamArtist, DreamArtist++ is more stable with higher image quality
 * safetensors support
 * Controlnet (support train)
 * Min-SNR loss
+* Custom optimizer (Lion, DAdaptation, pytorch-optimizer, ...)
+* Custom lr scheduler
 
 ## Install
 
{hcpdiff-0.3.5 → hcpdiff-0.3.7}/hcpdiff.egg-info/SOURCES.txt
@@ -17,6 +17,7 @@ cfgs/train/examples/CustomDiffusion.yaml
 cfgs/train/examples/DreamArtist++.yaml
 cfgs/train/examples/DreamArtist.yaml
 cfgs/train/examples/DreamBooth.yaml
+cfgs/train/examples/Lion_optimizer.yaml
 cfgs/train/examples/TextualInversion.yaml
 cfgs/train/examples/controlnet.yaml
 cfgs/train/examples/fine-tuning.yaml
{hcpdiff-0.3.5 → hcpdiff-0.3.7}/setup.py
@@ -23,7 +23,7 @@ def get_data_files(data_dir, prefix=''):
 setuptools.setup(
     name="hcpdiff",
     py_modules=["hcpdiff"],
-    version="0.3.5",
+    version="0.3.7",
     author="Ziyi Dong",
     author_email="dzy7eu7d7@gmail.com",
     description="A universal Stable-Diffusion toolbox",
hcpdiff-0.3.5/cfgs/train/examples/min_snr.yaml (deleted)
@@ -1,58 +0,0 @@
-_base_: [cfgs/train/train_base.yaml, cfgs/train/tuning_base.yaml]
-
-unet:
-  -
-    lr: 1e-6
-    layers:
-      - '' # fine-tuning all layers in unet
-
-## fine-tuning text-encoder
-#text_encoder:
-#  - lr: 1e-6
-#    layers:
-#      - ''
-
-tokenizer_pt:
-  train: null
-
-train:
-  gradient_accumulation_steps: 1
-  save_step: 100
-
-  loss:
-    criterion: # min SNR loss
-      _target_: hcpdiff.loss.MinSNRLoss
-      gamma: 2.0
-
-  scheduler:
-    name: 'constant_with_warmup'
-    num_warmup_steps: 50
-    num_training_steps: 600
-
-model:
-  pretrained_model_name_or_path: 'runwayml/stable-diffusion-v1-5'
-  tokenizer_repeats: 1
-  ema_unet: 0
-  ema_text_encoder: 0
-
-data:
-  dataset1:
-    batch_size: 4
-    cache_latents: True
-
-    source:
-      data_source1:
-        img_root: 'imgs/'
-        prompt_template: 'prompt_tuning_template/object.txt'
-        caption_file: null # path to image captions (file_words)
-        tag_transforms:
-          transforms:
-            - _target_: hcpdiff.utils.caption_tools.TagShuffle
-            - _target_: hcpdiff.utils.caption_tools.TagDropout
-              p: 0.1
-            - _target_: hcpdiff.utils.caption_tools.TemplateFill
-              word_names: { }
-        bucket:
-          _target_: hcpdiff.data.bucket.RatioBucket.from_files # aspect ratio bucket
-          target_area: {_target_: "builtins.eval", _args_: ['512*512']}
-          num_bucket: 5