hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211) hide show
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
@@ -1,932 +0,0 @@
1
- UNet2DConditionModel(
2
- (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
3
- (time_proj): Timesteps()
4
- (time_embedding): TimestepEmbedding(
5
- (linear_1): Linear(in_features=320, out_features=1280, bias=True)
6
- (act): SiLU()
7
- (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
8
- )
9
- (down_blocks): ModuleList(
10
- (0): CrossAttnDownBlock2D(
11
- (attentions): ModuleList(
12
- (0): Transformer2DModel(
13
- (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
14
- (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
15
- (transformer_blocks): ModuleList(
16
- (0): BasicTransformerBlock(
17
- (attn1): CrossAttention(
18
- (to_q): Linear(in_features=320, out_features=320, bias=False)
19
- (to_k): Linear(in_features=320, out_features=320, bias=False)
20
- (to_v): Linear(in_features=320, out_features=320, bias=False)
21
- (to_out): ModuleList(
22
- (0): Linear(in_features=320, out_features=320, bias=True)
23
- (1): Dropout(p=0.0, inplace=False)
24
- )
25
- )
26
- (ff): FeedForward(
27
- (net): ModuleList(
28
- (0): GEGLU(
29
- (proj): Linear(in_features=320, out_features=2560, bias=True)
30
- )
31
- (1): Dropout(p=0.0, inplace=False)
32
- (2): Linear(in_features=1280, out_features=320, bias=True)
33
- )
34
- )
35
- (attn2): CrossAttention(
36
- (to_q): Linear(in_features=320, out_features=320, bias=False)
37
- (to_k): Linear(in_features=768, out_features=320, bias=False)
38
- (to_v): Linear(in_features=768, out_features=320, bias=False)
39
- (to_out): ModuleList(
40
- (0): Linear(in_features=320, out_features=320, bias=True)
41
- (1): Dropout(p=0.0, inplace=False)
42
- )
43
- )
44
- (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
45
- (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
46
- (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
47
- )
48
- )
49
- (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
50
- )
51
- (1): Transformer2DModel(
52
- (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
53
- (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
54
- (transformer_blocks): ModuleList(
55
- (0): BasicTransformerBlock(
56
- (attn1): CrossAttention(
57
- (to_q): Linear(in_features=320, out_features=320, bias=False)
58
- (to_k): Linear(in_features=320, out_features=320, bias=False)
59
- (to_v): Linear(in_features=320, out_features=320, bias=False)
60
- (to_out): ModuleList(
61
- (0): Linear(in_features=320, out_features=320, bias=True)
62
- (1): Dropout(p=0.0, inplace=False)
63
- )
64
- )
65
- (ff): FeedForward(
66
- (net): ModuleList(
67
- (0): GEGLU(
68
- (proj): Linear(in_features=320, out_features=2560, bias=True)
69
- )
70
- (1): Dropout(p=0.0, inplace=False)
71
- (2): Linear(in_features=1280, out_features=320, bias=True)
72
- )
73
- )
74
- (attn2): CrossAttention(
75
- (to_q): Linear(in_features=320, out_features=320, bias=False)
76
- (to_k): Linear(in_features=768, out_features=320, bias=False)
77
- (to_v): Linear(in_features=768, out_features=320, bias=False)
78
- (to_out): ModuleList(
79
- (0): Linear(in_features=320, out_features=320, bias=True)
80
- (1): Dropout(p=0.0, inplace=False)
81
- )
82
- )
83
- (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
84
- (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
85
- (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
86
- )
87
- )
88
- (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
89
- )
90
- )
91
- (resnets): ModuleList(
92
- (0): ResnetBlock2D(
93
- (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
94
- (conv1): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
95
- (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)
96
- (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
97
- (dropout): Dropout(p=0.0, inplace=False)
98
- (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
99
- (nonlinearity): SiLU()
100
- )
101
- (1): ResnetBlock2D(
102
- (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
103
- (conv1): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
104
- (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)
105
- (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
106
- (dropout): Dropout(p=0.0, inplace=False)
107
- (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
108
- (nonlinearity): SiLU()
109
- )
110
- )
111
- (downsamplers): ModuleList(
112
- (0): Downsample2D(
113
- (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
114
- )
115
- )
116
- )
117
- (1): CrossAttnDownBlock2D(
118
- (attentions): ModuleList(
119
- (0): Transformer2DModel(
120
- (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
121
- (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
122
- (transformer_blocks): ModuleList(
123
- (0): BasicTransformerBlock(
124
- (attn1): CrossAttention(
125
- (to_q): Linear(in_features=640, out_features=640, bias=False)
126
- (to_k): Linear(in_features=640, out_features=640, bias=False)
127
- (to_v): Linear(in_features=640, out_features=640, bias=False)
128
- (to_out): ModuleList(
129
- (0): Linear(in_features=640, out_features=640, bias=True)
130
- (1): Dropout(p=0.0, inplace=False)
131
- )
132
- )
133
- (ff): FeedForward(
134
- (net): ModuleList(
135
- (0): GEGLU(
136
- (proj): Linear(in_features=640, out_features=5120, bias=True)
137
- )
138
- (1): Dropout(p=0.0, inplace=False)
139
- (2): Linear(in_features=2560, out_features=640, bias=True)
140
- )
141
- )
142
- (attn2): CrossAttention(
143
- (to_q): Linear(in_features=640, out_features=640, bias=False)
144
- (to_k): Linear(in_features=768, out_features=640, bias=False)
145
- (to_v): Linear(in_features=768, out_features=640, bias=False)
146
- (to_out): ModuleList(
147
- (0): Linear(in_features=640, out_features=640, bias=True)
148
- (1): Dropout(p=0.0, inplace=False)
149
- )
150
- )
151
- (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
152
- (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
153
- (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
154
- )
155
- )
156
- (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
157
- )
158
- (1): Transformer2DModel(
159
- (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
160
- (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
161
- (transformer_blocks): ModuleList(
162
- (0): BasicTransformerBlock(
163
- (attn1): CrossAttention(
164
- (to_q): Linear(in_features=640, out_features=640, bias=False)
165
- (to_k): Linear(in_features=640, out_features=640, bias=False)
166
- (to_v): Linear(in_features=640, out_features=640, bias=False)
167
- (to_out): ModuleList(
168
- (0): Linear(in_features=640, out_features=640, bias=True)
169
- (1): Dropout(p=0.0, inplace=False)
170
- )
171
- )
172
- (ff): FeedForward(
173
- (net): ModuleList(
174
- (0): GEGLU(
175
- (proj): Linear(in_features=640, out_features=5120, bias=True)
176
- )
177
- (1): Dropout(p=0.0, inplace=False)
178
- (2): Linear(in_features=2560, out_features=640, bias=True)
179
- )
180
- )
181
- (attn2): CrossAttention(
182
- (to_q): Linear(in_features=640, out_features=640, bias=False)
183
- (to_k): Linear(in_features=768, out_features=640, bias=False)
184
- (to_v): Linear(in_features=768, out_features=640, bias=False)
185
- (to_out): ModuleList(
186
- (0): Linear(in_features=640, out_features=640, bias=True)
187
- (1): Dropout(p=0.0, inplace=False)
188
- )
189
- )
190
- (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
191
- (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
192
- (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
193
- )
194
- )
195
- (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
196
- )
197
- )
198
- (resnets): ModuleList(
199
- (0): ResnetBlock2D(
200
- (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
201
- (conv1): Conv2d(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
202
- (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
203
- (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
204
- (dropout): Dropout(p=0.0, inplace=False)
205
- (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
206
- (nonlinearity): SiLU()
207
- (conv_shortcut): Conv2d(320, 640, kernel_size=(1, 1), stride=(1, 1))
208
- )
209
- (1): ResnetBlock2D(
210
- (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
211
- (conv1): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
212
- (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
213
- (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
214
- (dropout): Dropout(p=0.0, inplace=False)
215
- (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
216
- (nonlinearity): SiLU()
217
- )
218
- )
219
- (downsamplers): ModuleList(
220
- (0): Downsample2D(
221
- (conv): Conv2d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
222
- )
223
- )
224
- )
225
- (2): CrossAttnDownBlock2D(
226
- (attentions): ModuleList(
227
- (0): Transformer2DModel(
228
- (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
229
- (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
230
- (transformer_blocks): ModuleList(
231
- (0): BasicTransformerBlock(
232
- (attn1): CrossAttention(
233
- (to_q): Linear(in_features=1280, out_features=1280, bias=False)
234
- (to_k): Linear(in_features=1280, out_features=1280, bias=False)
235
- (to_v): Linear(in_features=1280, out_features=1280, bias=False)
236
- (to_out): ModuleList(
237
- (0): Linear(in_features=1280, out_features=1280, bias=True)
238
- (1): Dropout(p=0.0, inplace=False)
239
- )
240
- )
241
- (ff): FeedForward(
242
- (net): ModuleList(
243
- (0): GEGLU(
244
- (proj): Linear(in_features=1280, out_features=10240, bias=True)
245
- )
246
- (1): Dropout(p=0.0, inplace=False)
247
- (2): Linear(in_features=5120, out_features=1280, bias=True)
248
- )
249
- )
250
- (attn2): CrossAttention(
251
- (to_q): Linear(in_features=1280, out_features=1280, bias=False)
252
- (to_k): Linear(in_features=768, out_features=1280, bias=False)
253
- (to_v): Linear(in_features=768, out_features=1280, bias=False)
254
- (to_out): ModuleList(
255
- (0): Linear(in_features=1280, out_features=1280, bias=True)
256
- (1): Dropout(p=0.0, inplace=False)
257
- )
258
- )
259
- (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
260
- (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
261
- (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
262
- )
263
- )
264
- (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
265
- )
266
- (1): Transformer2DModel(
267
- (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
268
- (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
269
- (transformer_blocks): ModuleList(
270
- (0): BasicTransformerBlock(
271
- (attn1): CrossAttention(
272
- (to_q): Linear(in_features=1280, out_features=1280, bias=False)
273
- (to_k): Linear(in_features=1280, out_features=1280, bias=False)
274
- (to_v): Linear(in_features=1280, out_features=1280, bias=False)
275
- (to_out): ModuleList(
276
- (0): Linear(in_features=1280, out_features=1280, bias=True)
277
- (1): Dropout(p=0.0, inplace=False)
278
- )
279
- )
280
- (ff): FeedForward(
281
- (net): ModuleList(
282
- (0): GEGLU(
283
- (proj): Linear(in_features=1280, out_features=10240, bias=True)
284
- )
285
- (1): Dropout(p=0.0, inplace=False)
286
- (2): Linear(in_features=5120, out_features=1280, bias=True)
287
- )
288
- )
289
- (attn2): CrossAttention(
290
- (to_q): Linear(in_features=1280, out_features=1280, bias=False)
291
- (to_k): Linear(in_features=768, out_features=1280, bias=False)
292
- (to_v): Linear(in_features=768, out_features=1280, bias=False)
293
- (to_out): ModuleList(
294
- (0): Linear(in_features=1280, out_features=1280, bias=True)
295
- (1): Dropout(p=0.0, inplace=False)
296
- )
297
- )
298
- (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
299
- (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
300
- (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
301
- )
302
- )
303
- (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
304
- )
305
- )
306
- (resnets): ModuleList(
307
- (0): ResnetBlock2D(
308
- (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
309
- (conv1): Conv2d(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
310
- (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
311
- (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
312
- (dropout): Dropout(p=0.0, inplace=False)
313
- (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
314
- (nonlinearity): SiLU()
315
- (conv_shortcut): Conv2d(640, 1280, kernel_size=(1, 1), stride=(1, 1))
316
- )
317
- (1): ResnetBlock2D(
318
- (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
319
- (conv1): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
320
- (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
321
- (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
322
- (dropout): Dropout(p=0.0, inplace=False)
323
- (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
324
- (nonlinearity): SiLU()
325
- )
326
- )
327
- (downsamplers): ModuleList(
328
- (0): Downsample2D(
329
- (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
330
- )
331
- )
332
- )
333
- (3): DownBlock2D(
334
- (resnets): ModuleList(
335
- (0): ResnetBlock2D(
336
- (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
337
- (conv1): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
338
- (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
339
- (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
340
- (dropout): Dropout(p=0.0, inplace=False)
341
- (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
342
- (nonlinearity): SiLU()
343
- )
344
- (1): ResnetBlock2D(
345
- (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
346
- (conv1): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
347
- (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
348
- (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
349
- (dropout): Dropout(p=0.0, inplace=False)
350
- (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
351
- (nonlinearity): SiLU()
352
- )
353
- )
354
- )
355
- )
356
- (up_blocks): ModuleList(
357
- (0): UpBlock2D(
358
- (resnets): ModuleList(
359
- (0): ResnetBlock2D(
360
- (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)
361
- (conv1): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
362
- (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
363
- (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
364
- (dropout): Dropout(p=0.0, inplace=False)
365
- (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
366
- (nonlinearity): SiLU()
367
- (conv_shortcut): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
368
- )
369
- (1): ResnetBlock2D(
370
- (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)
371
- (conv1): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
372
- (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
373
- (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
374
- (dropout): Dropout(p=0.0, inplace=False)
375
- (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
376
- (nonlinearity): SiLU()
377
- (conv_shortcut): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
378
- )
379
- (2): ResnetBlock2D(
380
- (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)
381
- (conv1): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
382
- (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
383
- (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
384
- (dropout): Dropout(p=0.0, inplace=False)
385
- (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
386
- (nonlinearity): SiLU()
387
- (conv_shortcut): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
388
- )
389
- )
390
- (upsamplers): ModuleList(
391
- (0): Upsample2D(
392
- (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
393
- )
394
- )
395
- )
396
- (1): CrossAttnUpBlock2D(
397
- (attentions): ModuleList(
398
- (0): Transformer2DModel(
399
- (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
400
- (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
401
- (transformer_blocks): ModuleList(
402
- (0): BasicTransformerBlock(
403
- (attn1): CrossAttention(
404
- (to_q): Linear(in_features=1280, out_features=1280, bias=False)
405
- (to_k): Linear(in_features=1280, out_features=1280, bias=False)
406
- (to_v): Linear(in_features=1280, out_features=1280, bias=False)
407
- (to_out): ModuleList(
408
- (0): Linear(in_features=1280, out_features=1280, bias=True)
409
- (1): Dropout(p=0.0, inplace=False)
410
- )
411
- )
412
- (ff): FeedForward(
413
- (net): ModuleList(
414
- (0): GEGLU(
415
- (proj): Linear(in_features=1280, out_features=10240, bias=True)
416
- )
417
- (1): Dropout(p=0.0, inplace=False)
418
- (2): Linear(in_features=5120, out_features=1280, bias=True)
419
- )
420
- )
421
- (attn2): CrossAttention(
422
- (to_q): Linear(in_features=1280, out_features=1280, bias=False)
423
- (to_k): Linear(in_features=768, out_features=1280, bias=False)
424
- (to_v): Linear(in_features=768, out_features=1280, bias=False)
425
- (to_out): ModuleList(
426
- (0): Linear(in_features=1280, out_features=1280, bias=True)
427
- (1): Dropout(p=0.0, inplace=False)
428
- )
429
- )
430
- (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
431
- (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
432
- (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
433
- )
434
- )
435
- (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
436
- )
437
- (1): Transformer2DModel(
438
- (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
439
- (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
440
- (transformer_blocks): ModuleList(
441
- (0): BasicTransformerBlock(
442
- (attn1): CrossAttention(
443
- (to_q): Linear(in_features=1280, out_features=1280, bias=False)
444
- (to_k): Linear(in_features=1280, out_features=1280, bias=False)
445
- (to_v): Linear(in_features=1280, out_features=1280, bias=False)
446
- (to_out): ModuleList(
447
- (0): Linear(in_features=1280, out_features=1280, bias=True)
448
- (1): Dropout(p=0.0, inplace=False)
449
- )
450
- )
451
- (ff): FeedForward(
452
- (net): ModuleList(
453
- (0): GEGLU(
454
- (proj): Linear(in_features=1280, out_features=10240, bias=True)
455
- )
456
- (1): Dropout(p=0.0, inplace=False)
457
- (2): Linear(in_features=5120, out_features=1280, bias=True)
458
- )
459
- )
460
- (attn2): CrossAttention(
461
- (to_q): Linear(in_features=1280, out_features=1280, bias=False)
462
- (to_k): Linear(in_features=768, out_features=1280, bias=False)
463
- (to_v): Linear(in_features=768, out_features=1280, bias=False)
464
- (to_out): ModuleList(
465
- (0): Linear(in_features=1280, out_features=1280, bias=True)
466
- (1): Dropout(p=0.0, inplace=False)
467
- )
468
- )
469
- (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
470
- (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
471
- (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
472
- )
473
- )
474
- (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
475
- )
476
- (2): Transformer2DModel(
477
- (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
478
- (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
479
- (transformer_blocks): ModuleList(
480
- (0): BasicTransformerBlock(
481
- (attn1): CrossAttention(
482
- (to_q): Linear(in_features=1280, out_features=1280, bias=False)
483
- (to_k): Linear(in_features=1280, out_features=1280, bias=False)
484
- (to_v): Linear(in_features=1280, out_features=1280, bias=False)
485
- (to_out): ModuleList(
486
- (0): Linear(in_features=1280, out_features=1280, bias=True)
487
- (1): Dropout(p=0.0, inplace=False)
488
- )
489
- )
490
- (ff): FeedForward(
491
- (net): ModuleList(
492
- (0): GEGLU(
493
- (proj): Linear(in_features=1280, out_features=10240, bias=True)
494
- )
495
- (1): Dropout(p=0.0, inplace=False)
496
- (2): Linear(in_features=5120, out_features=1280, bias=True)
497
- )
498
- )
499
- (attn2): CrossAttention(
500
- (to_q): Linear(in_features=1280, out_features=1280, bias=False)
501
- (to_k): Linear(in_features=768, out_features=1280, bias=False)
502
- (to_v): Linear(in_features=768, out_features=1280, bias=False)
503
- (to_out): ModuleList(
504
- (0): Linear(in_features=1280, out_features=1280, bias=True)
505
- (1): Dropout(p=0.0, inplace=False)
506
- )
507
- )
508
- (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
509
- (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
510
- (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
511
- )
512
- )
513
- (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
514
- )
515
- )
516
- (resnets): ModuleList(
517
- (0): ResnetBlock2D(
518
- (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)
519
- (conv1): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
520
- (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
521
- (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
522
- (dropout): Dropout(p=0.0, inplace=False)
523
- (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
524
- (nonlinearity): SiLU()
525
- (conv_shortcut): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
526
- )
527
- (1): ResnetBlock2D(
528
- (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)
529
- (conv1): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
530
- (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
531
- (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
532
- (dropout): Dropout(p=0.0, inplace=False)
533
- (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
534
- (nonlinearity): SiLU()
535
- (conv_shortcut): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
536
- )
537
- (2): ResnetBlock2D(
538
- (norm1): GroupNorm(32, 1920, eps=1e-05, affine=True)
539
- (conv1): Conv2d(1920, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
540
- (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
541
- (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
542
- (dropout): Dropout(p=0.0, inplace=False)
543
- (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
544
- (nonlinearity): SiLU()
545
- (conv_shortcut): Conv2d(1920, 1280, kernel_size=(1, 1), stride=(1, 1))
546
- )
547
- )
548
- (upsamplers): ModuleList(
549
- (0): Upsample2D(
550
- (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
551
- )
552
- )
553
- )
554
- (2): CrossAttnUpBlock2D(
555
- (attentions): ModuleList(
556
- (0): Transformer2DModel(
557
- (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
558
- (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
559
- (transformer_blocks): ModuleList(
560
- (0): BasicTransformerBlock(
561
- (attn1): CrossAttention(
562
- (to_q): Linear(in_features=640, out_features=640, bias=False)
563
- (to_k): Linear(in_features=640, out_features=640, bias=False)
564
- (to_v): Linear(in_features=640, out_features=640, bias=False)
565
- (to_out): ModuleList(
566
- (0): Linear(in_features=640, out_features=640, bias=True)
567
- (1): Dropout(p=0.0, inplace=False)
568
- )
569
- )
570
- (ff): FeedForward(
571
- (net): ModuleList(
572
- (0): GEGLU(
573
- (proj): Linear(in_features=640, out_features=5120, bias=True)
574
- )
575
- (1): Dropout(p=0.0, inplace=False)
576
- (2): Linear(in_features=2560, out_features=640, bias=True)
577
- )
578
- )
579
- (attn2): CrossAttention(
580
- (to_q): Linear(in_features=640, out_features=640, bias=False)
581
- (to_k): Linear(in_features=768, out_features=640, bias=False)
582
- (to_v): Linear(in_features=768, out_features=640, bias=False)
583
- (to_out): ModuleList(
584
- (0): Linear(in_features=640, out_features=640, bias=True)
585
- (1): Dropout(p=0.0, inplace=False)
586
- )
587
- )
588
- (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
589
- (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
590
- (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
591
- )
592
- )
593
- (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
594
- )
595
- (1): Transformer2DModel(
596
- (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
597
- (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
598
- (transformer_blocks): ModuleList(
599
- (0): BasicTransformerBlock(
600
- (attn1): CrossAttention(
601
- (to_q): Linear(in_features=640, out_features=640, bias=False)
602
- (to_k): Linear(in_features=640, out_features=640, bias=False)
603
- (to_v): Linear(in_features=640, out_features=640, bias=False)
604
- (to_out): ModuleList(
605
- (0): Linear(in_features=640, out_features=640, bias=True)
606
- (1): Dropout(p=0.0, inplace=False)
607
- )
608
- )
609
- (ff): FeedForward(
610
- (net): ModuleList(
611
- (0): GEGLU(
612
- (proj): Linear(in_features=640, out_features=5120, bias=True)
613
- )
614
- (1): Dropout(p=0.0, inplace=False)
615
- (2): Linear(in_features=2560, out_features=640, bias=True)
616
- )
617
- )
618
- (attn2): CrossAttention(
619
- (to_q): Linear(in_features=640, out_features=640, bias=False)
620
- (to_k): Linear(in_features=768, out_features=640, bias=False)
621
- (to_v): Linear(in_features=768, out_features=640, bias=False)
622
- (to_out): ModuleList(
623
- (0): Linear(in_features=640, out_features=640, bias=True)
624
- (1): Dropout(p=0.0, inplace=False)
625
- )
626
- )
627
- (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
628
- (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
629
- (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
630
- )
631
- )
632
- (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
633
- )
634
- (2): Transformer2DModel(
635
- (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
636
- (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
637
- (transformer_blocks): ModuleList(
638
- (0): BasicTransformerBlock(
639
- (attn1): CrossAttention(
640
- (to_q): Linear(in_features=640, out_features=640, bias=False)
641
- (to_k): Linear(in_features=640, out_features=640, bias=False)
642
- (to_v): Linear(in_features=640, out_features=640, bias=False)
643
- (to_out): ModuleList(
644
- (0): Linear(in_features=640, out_features=640, bias=True)
645
- (1): Dropout(p=0.0, inplace=False)
646
- )
647
- )
648
- (ff): FeedForward(
649
- (net): ModuleList(
650
- (0): GEGLU(
651
- (proj): Linear(in_features=640, out_features=5120, bias=True)
652
- )
653
- (1): Dropout(p=0.0, inplace=False)
654
- (2): Linear(in_features=2560, out_features=640, bias=True)
655
- )
656
- )
657
- (attn2): CrossAttention(
658
- (to_q): Linear(in_features=640, out_features=640, bias=False)
659
- (to_k): Linear(in_features=768, out_features=640, bias=False)
660
- (to_v): Linear(in_features=768, out_features=640, bias=False)
661
- (to_out): ModuleList(
662
- (0): Linear(in_features=640, out_features=640, bias=True)
663
- (1): Dropout(p=0.0, inplace=False)
664
- )
665
- )
666
- (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
667
- (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
668
- (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
669
- )
670
- )
671
- (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
672
- )
673
- )
674
- (resnets): ModuleList(
675
- (0): ResnetBlock2D(
676
- (norm1): GroupNorm(32, 1920, eps=1e-05, affine=True)
677
- (conv1): Conv2d(1920, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
678
- (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
679
- (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
680
- (dropout): Dropout(p=0.0, inplace=False)
681
- (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
682
- (nonlinearity): SiLU()
683
- (conv_shortcut): Conv2d(1920, 640, kernel_size=(1, 1), stride=(1, 1))
684
- )
685
- (1): ResnetBlock2D(
686
- (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
687
- (conv1): Conv2d(1280, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
688
- (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
689
- (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
690
- (dropout): Dropout(p=0.0, inplace=False)
691
- (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
692
- (nonlinearity): SiLU()
693
- (conv_shortcut): Conv2d(1280, 640, kernel_size=(1, 1), stride=(1, 1))
694
- )
695
- (2): ResnetBlock2D(
696
- (norm1): GroupNorm(32, 960, eps=1e-05, affine=True)
697
- (conv1): Conv2d(960, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
698
- (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
699
- (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
700
- (dropout): Dropout(p=0.0, inplace=False)
701
- (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
702
- (nonlinearity): SiLU()
703
- (conv_shortcut): Conv2d(960, 640, kernel_size=(1, 1), stride=(1, 1))
704
- )
705
- )
706
- (upsamplers): ModuleList(
707
- (0): Upsample2D(
708
- (conv): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
709
- )
710
- )
711
- )
712
- (3): CrossAttnUpBlock2D(
713
- (attentions): ModuleList(
714
- (0): Transformer2DModel(
715
- (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
716
- (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
717
- (transformer_blocks): ModuleList(
718
- (0): BasicTransformerBlock(
719
- (attn1): CrossAttention(
720
- (to_q): Linear(in_features=320, out_features=320, bias=False)
721
- (to_k): Linear(in_features=320, out_features=320, bias=False)
722
- (to_v): Linear(in_features=320, out_features=320, bias=False)
723
- (to_out): ModuleList(
724
- (0): Linear(in_features=320, out_features=320, bias=True)
725
- (1): Dropout(p=0.0, inplace=False)
726
- )
727
- )
728
- (ff): FeedForward(
729
- (net): ModuleList(
730
- (0): GEGLU(
731
- (proj): Linear(in_features=320, out_features=2560, bias=True)
732
- )
733
- (1): Dropout(p=0.0, inplace=False)
734
- (2): Linear(in_features=1280, out_features=320, bias=True)
735
- )
736
- )
737
- (attn2): CrossAttention(
738
- (to_q): Linear(in_features=320, out_features=320, bias=False)
739
- (to_k): Linear(in_features=768, out_features=320, bias=False)
740
- (to_v): Linear(in_features=768, out_features=320, bias=False)
741
- (to_out): ModuleList(
742
- (0): Linear(in_features=320, out_features=320, bias=True)
743
- (1): Dropout(p=0.0, inplace=False)
744
- )
745
- )
746
- (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
747
- (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
748
- (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
749
- )
750
- )
751
- (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
752
- )
753
- (1): Transformer2DModel(
754
- (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
755
- (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
756
- (transformer_blocks): ModuleList(
757
- (0): BasicTransformerBlock(
758
- (attn1): CrossAttention(
759
- (to_q): Linear(in_features=320, out_features=320, bias=False)
760
- (to_k): Linear(in_features=320, out_features=320, bias=False)
761
- (to_v): Linear(in_features=320, out_features=320, bias=False)
762
- (to_out): ModuleList(
763
- (0): Linear(in_features=320, out_features=320, bias=True)
764
- (1): Dropout(p=0.0, inplace=False)
765
- )
766
- )
767
- (ff): FeedForward(
768
- (net): ModuleList(
769
- (0): GEGLU(
770
- (proj): Linear(in_features=320, out_features=2560, bias=True)
771
- )
772
- (1): Dropout(p=0.0, inplace=False)
773
- (2): Linear(in_features=1280, out_features=320, bias=True)
774
- )
775
- )
776
- (attn2): CrossAttention(
777
- (to_q): Linear(in_features=320, out_features=320, bias=False)
778
- (to_k): Linear(in_features=768, out_features=320, bias=False)
779
- (to_v): Linear(in_features=768, out_features=320, bias=False)
780
- (to_out): ModuleList(
781
- (0): Linear(in_features=320, out_features=320, bias=True)
782
- (1): Dropout(p=0.0, inplace=False)
783
- )
784
- )
785
- (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
786
- (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
787
- (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
788
- )
789
- )
790
- (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
791
- )
792
- (2): Transformer2DModel(
793
- (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
794
- (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
795
- (transformer_blocks): ModuleList(
796
- (0): BasicTransformerBlock(
797
- (attn1): CrossAttention(
798
- (to_q): Linear(in_features=320, out_features=320, bias=False)
799
- (to_k): Linear(in_features=320, out_features=320, bias=False)
800
- (to_v): Linear(in_features=320, out_features=320, bias=False)
801
- (to_out): ModuleList(
802
- (0): Linear(in_features=320, out_features=320, bias=True)
803
- (1): Dropout(p=0.0, inplace=False)
804
- )
805
- )
806
- (ff): FeedForward(
807
- (net): ModuleList(
808
- (0): GEGLU(
809
- (proj): Linear(in_features=320, out_features=2560, bias=True)
810
- )
811
- (1): Dropout(p=0.0, inplace=False)
812
- (2): Linear(in_features=1280, out_features=320, bias=True)
813
- )
814
- )
815
- (attn2): CrossAttention(
816
- (to_q): Linear(in_features=320, out_features=320, bias=False)
817
- (to_k): Linear(in_features=768, out_features=320, bias=False)
818
- (to_v): Linear(in_features=768, out_features=320, bias=False)
819
- (to_out): ModuleList(
820
- (0): Linear(in_features=320, out_features=320, bias=True)
821
- (1): Dropout(p=0.0, inplace=False)
822
- )
823
- )
824
- (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
825
- (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
826
- (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
827
- )
828
- )
829
- (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
830
- )
831
- )
832
- (resnets): ModuleList(
833
- (0): ResnetBlock2D(
834
- (norm1): GroupNorm(32, 960, eps=1e-05, affine=True)
835
- (conv1): Conv2d(960, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
836
- (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)
837
- (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
838
- (dropout): Dropout(p=0.0, inplace=False)
839
- (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
840
- (nonlinearity): SiLU()
841
- (conv_shortcut): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1))
842
- )
843
- (1): ResnetBlock2D(
844
- (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
845
- (conv1): Conv2d(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
846
- (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)
847
- (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
848
- (dropout): Dropout(p=0.0, inplace=False)
849
- (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
850
- (nonlinearity): SiLU()
851
- (conv_shortcut): Conv2d(640, 320, kernel_size=(1, 1), stride=(1, 1))
852
- )
853
- (2): ResnetBlock2D(
854
- (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
855
- (conv1): Conv2d(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
856
- (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)
857
- (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
858
- (dropout): Dropout(p=0.0, inplace=False)
859
- (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
860
- (nonlinearity): SiLU()
861
- (conv_shortcut): Conv2d(640, 320, kernel_size=(1, 1), stride=(1, 1))
862
- )
863
- )
864
- )
865
- )
866
- (mid_block): UNetMidBlock2DCrossAttn(
867
- (attentions): ModuleList(
868
- (0): Transformer2DModel(
869
- (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
870
- (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
871
- (transformer_blocks): ModuleList(
872
- (0): BasicTransformerBlock(
873
- (attn1): CrossAttention(
874
- (to_q): Linear(in_features=1280, out_features=1280, bias=False)
875
- (to_k): Linear(in_features=1280, out_features=1280, bias=False)
876
- (to_v): Linear(in_features=1280, out_features=1280, bias=False)
877
- (to_out): ModuleList(
878
- (0): Linear(in_features=1280, out_features=1280, bias=True)
879
- (1): Dropout(p=0.0, inplace=False)
880
- )
881
- )
882
- (ff): FeedForward(
883
- (net): ModuleList(
884
- (0): GEGLU(
885
- (proj): Linear(in_features=1280, out_features=10240, bias=True)
886
- )
887
- (1): Dropout(p=0.0, inplace=False)
888
- (2): Linear(in_features=5120, out_features=1280, bias=True)
889
- )
890
- )
891
- (attn2): CrossAttention(
892
- (to_q): Linear(in_features=1280, out_features=1280, bias=False)
893
- (to_k): Linear(in_features=768, out_features=1280, bias=False)
894
- (to_v): Linear(in_features=768, out_features=1280, bias=False)
895
- (to_out): ModuleList(
896
- (0): Linear(in_features=1280, out_features=1280, bias=True)
897
- (1): Dropout(p=0.0, inplace=False)
898
- )
899
- )
900
- (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
901
- (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
902
- (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
903
- )
904
- )
905
- (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
906
- )
907
- )
908
- (resnets): ModuleList(
909
- (0): ResnetBlock2D(
910
- (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
911
- (conv1): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
912
- (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
913
- (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
914
- (dropout): Dropout(p=0.0, inplace=False)
915
- (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
916
- (nonlinearity): SiLU()
917
- )
918
- (1): ResnetBlock2D(
919
- (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
920
- (conv1): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
921
- (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
922
- (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
923
- (dropout): Dropout(p=0.0, inplace=False)
924
- (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
925
- (nonlinearity): SiLU()
926
- )
927
- )
928
- )
929
- (conv_norm_out): GroupNorm(32, 320, eps=1e-05, affine=True)
930
- (conv_act): SiLU()
931
- (conv_out): Conv2d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
932
- )