ai-edge-torch-nightly 0.2.0.dev20240611__py3-none-any.whl → 0.2.0.dev20240619__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (24)
  1. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
  2. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
  3. ai_edge_torch/debug/__init__.py +1 -0
  4. ai_edge_torch/debug/culprit.py +70 -29
  5. ai_edge_torch/debug/test/test_search_model.py +50 -0
  6. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
  7. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
  8. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
  9. ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
  10. ai_edge_torch/generative/layers/attention.py +154 -26
  11. ai_edge_torch/generative/layers/model_config.py +3 -0
  12. ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
  13. ai_edge_torch/generative/layers/unet/builder.py +20 -2
  14. ai_edge_torch/generative/layers/unet/model_config.py +157 -5
  15. ai_edge_torch/generative/test/test_model_conversion.py +24 -0
  16. ai_edge_torch/generative/test/test_quantize.py +1 -0
  17. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
  18. ai_edge_torch/generative/utilities/t5_loader.py +33 -17
  19. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/RECORD +23 -22
  21. ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
  22. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/top_level.txt +0 -0
@@ -15,230 +15,551 @@
 
 import torch
 from torch import nn
-from torch.nn import functional as F
 
-from ai_edge_torch.generative.examples.stable_diffusion.attention import CrossAttention  # NOQA
-from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
-
-
-class TimeEmbedding(nn.Module):
-
-  def __init__(self, n_embd):
-    super().__init__()
-    self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
-    self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
-
-  def forward(self, x):
-    x = self.linear_1(x)
-    x = F.silu(x)
-    x = self.linear_2(x)
-    return x
-
-
-class ResidualBlock(nn.Module):
-
-  def __init__(self, in_channels, out_channels, n_time=1280):
-    super().__init__()
-    self.groupnorm_feature = nn.GroupNorm(32, in_channels)
-    self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
-    self.linear_time = nn.Linear(n_time, out_channels)
-
-    self.groupnorm_merged = nn.GroupNorm(32, out_channels)
-    self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
-
-    if in_channels == out_channels:
-      self.residual_layer = nn.Identity()
-    else:
-      self.residual_layer = nn.Conv2d(
-          in_channels, out_channels, kernel_size=1, padding=0
-      )
-
-  def forward(self, feature, time):
-    residue = feature
-
-    feature = self.groupnorm_feature(feature)
-    feature = F.silu(feature)
-    feature = self.conv_feature(feature)
-
-    time = F.silu(time)
-    time = self.linear_time(time)
+import ai_edge_torch.generative.layers.builder as layers_builder
+import ai_edge_torch.generative.layers.model_config as layers_cfg
+import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
+import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
+import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+
+_down_encoder_blocks_tensor_names = [
+    stable_diffusion_loader.DownEncoderBlockTensorNames(
+        residual_block_tensor_names=[
+            stable_diffusion_loader.ResidualBlockTensorNames(
+                norm_1=f"unet.encoders.{i*3+j+1}.0.groupnorm_feature",
+                conv_1=f"unet.encoders.{i*3+j+1}.0.conv_feature",
+                norm_2=f"unet.encoders.{i*3+j+1}.0.groupnorm_merged",
+                conv_2=f"unet.encoders.{i*3+j+1}.0.conv_merged",
+                time_embedding=f"unet.encoders.{i*3+j+1}.0.linear_time",
+                residual_layer=f"unet.encoders.{i*3+j+1}.0.residual_layer"
+                if (i * 3 + j + 1) in [4, 7]
+                else None,
+            )
+            for j in range(2)
+        ],
+        transformer_block_tensor_names=[
+            stable_diffusion_loader.TransformerBlockTensorNames(
+                pre_conv_norm=f"unet.encoders.{i*3+j+1}.1.groupnorm",
+                conv_in=f"unet.encoders.{i*3+j+1}.1.conv_input",
+                conv_out=f"unet.encoders.{i*3+j+1}.1.conv_output",
+                self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
+                    norm=f"unet.encoders.{i*3+j+1}.1.layernorm_1",
+                    fused_qkv_proj=f"unet.encoders.{i*3+j+1}.1.attention_1.in_proj",
+                    output_proj=f"unet.encoders.{i*3+j+1}.1.attention_1.out_proj",
+                ),
+                cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
+                    norm=f"unet.encoders.{i*3+j+1}.1.layernorm_2",
+                    q_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.q_proj",
+                    k_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.k_proj",
+                    v_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.v_proj",
+                    output_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.out_proj",
+                ),
+                feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
+                    norm=f"unet.encoders.{i*3+j+1}.1.layernorm_3",
+                    ge_glu=f"unet.encoders.{i*3+j+1}.1.linear_geglu_1",
+                    w2=f"unet.encoders.{i*3+j+1}.1.linear_geglu_2",
+                ),
+            )
+            for j in range(2)
+        ]
+        if i < 3
+        else None,
+        downsample_conv=f"unet.encoders.{i*3+3}.0" if i < 3 else None,
+    )
+    for i in range(4)
+]
+
+_mid_block_tensor_names = stable_diffusion_loader.MidBlockTensorNames(
+    residual_block_tensor_names=[
+        stable_diffusion_loader.ResidualBlockTensorNames(
+            norm_1=f"unet.bottleneck.{i}.groupnorm_feature",
+            conv_1=f"unet.bottleneck.{i}.conv_feature",
+            norm_2=f"unet.bottleneck.{i}.groupnorm_merged",
+            conv_2=f"unet.bottleneck.{i}.conv_merged",
+            time_embedding=f"unet.bottleneck.{i}.linear_time",
+        )
+        for i in [0, 2]
+    ],
+    transformer_block_tensor_names=[
+        stable_diffusion_loader.TransformerBlockTensorNames(
+            pre_conv_norm=f"unet.bottleneck.{i}.groupnorm",
+            conv_in=f"unet.bottleneck.{i}.conv_input",
+            conv_out=f"unet.bottleneck.{i}.conv_output",
+            self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
+                norm=f"unet.bottleneck.{i}.layernorm_1",
+                fused_qkv_proj=f"unet.bottleneck.{i}.attention_1.in_proj",
+                output_proj=f"unet.bottleneck.{i}.attention_1.out_proj",
+            ),
+            cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
+                norm=f"unet.bottleneck.{i}.layernorm_2",
+                q_proj=f"unet.bottleneck.{i}.attention_2.q_proj",
+                k_proj=f"unet.bottleneck.{i}.attention_2.k_proj",
+                v_proj=f"unet.bottleneck.{i}.attention_2.v_proj",
+                output_proj=f"unet.bottleneck.{i}.attention_2.out_proj",
+            ),
+            feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
+                norm=f"unet.bottleneck.{i}.layernorm_3",
+                ge_glu=f"unet.bottleneck.{i}.linear_geglu_1",
+                w2=f"unet.bottleneck.{i}.linear_geglu_2",
+            ),
+        )
+        for i in [1]
+    ],
+)
+
+_up_decoder_blocks_tensor_names = [
+    stable_diffusion_loader.SkipUpDecoderBlockTensorNames(
+        residual_block_tensor_names=[
+            stable_diffusion_loader.ResidualBlockTensorNames(
+                norm_1=f"unet.decoders.{i*3+j}.0.groupnorm_feature",
+                conv_1=f"unet.decoders.{i*3+j}.0.conv_feature",
+                norm_2=f"unet.decoders.{i*3+j}.0.groupnorm_merged",
+                conv_2=f"unet.decoders.{i*3+j}.0.conv_merged",
+                time_embedding=f"unet.decoders.{i*3+j}.0.linear_time",
+                residual_layer=f"unet.decoders.{i*3+j}.0.residual_layer",
+            )
+            for j in range(3)
+        ],
+        transformer_block_tensor_names=[
+            stable_diffusion_loader.TransformerBlockTensorNames(
+                pre_conv_norm=f"unet.decoders.{i*3+j}.1.groupnorm",
+                conv_in=f"unet.decoders.{i*3+j}.1.conv_input",
+                conv_out=f"unet.decoders.{i*3+j}.1.conv_output",
+                self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
+                    norm=f"unet.decoders.{i*3+j}.1.layernorm_1",
+                    fused_qkv_proj=f"unet.decoders.{i*3+j}.1.attention_1.in_proj",
+                    output_proj=f"unet.decoders.{i*3+j}.1.attention_1.out_proj",
+                ),
+                cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
+                    norm=f"unet.decoders.{i*3+j}.1.layernorm_2",
+                    q_proj=f"unet.decoders.{i*3+j}.1.attention_2.q_proj",
+                    k_proj=f"unet.decoders.{i*3+j}.1.attention_2.k_proj",
+                    v_proj=f"unet.decoders.{i*3+j}.1.attention_2.v_proj",
+                    output_proj=f"unet.decoders.{i*3+j}.1.attention_2.out_proj",
+                ),
+                feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
+                    norm=f"unet.decoders.{i*3+j}.1.layernorm_3",
+                    ge_glu=f"unet.decoders.{i*3+j}.1.linear_geglu_1",
+                    w2=f"unet.decoders.{i*3+j}.1.linear_geglu_2",
+                ),
+            )
+            for j in range(3)
+        ]
+        if i > 0
+        else None,
+        upsample_conv=f"unet.decoders.{i*3+2}.2.conv"
+        if 0 < i < 3
+        else (f"unet.decoders.2.1.conv" if i == 0 else None),
+    )
+    for i in range(4)
+]
 
-    merged = feature + time.unsqueeze(-1).unsqueeze(-1)
-    merged = self.groupnorm_merged(merged)
-    merged = F.silu(merged)
-    merged = self.conv_merged(merged)
 
-    return merged + self.residual_layer(residue)
+TENSORS_NAMES = stable_diffusion_loader.DiffusionModelLoader.TensorNames(
+    time_embedding=stable_diffusion_loader.TimeEmbeddingTensorNames(
+        w1="time_embedding.linear_1",
+        w2="time_embedding.linear_2",
+    ),
+    conv_in="unet.encoders.0.0",
+    conv_out="final.conv",
+    final_norm="final.groupnorm",
+    down_encoder_blocks_tensor_names=_down_encoder_blocks_tensor_names,
+    mid_block_tensor_names=_mid_block_tensor_names,
+    up_decoder_blocks_tensor_names=_up_decoder_blocks_tensor_names,
+)
 
 
-class AttentionBlock(nn.Module):
+class TimeEmbedding(nn.Module):
 
-  def __init__(self, n_head: int, n_embd: int, d_context=768):
+  def __init__(self, in_dim, out_dim):
     super().__init__()
-    channels = n_head * n_embd
-
-    self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
-    self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-
-    self.layernorm_1 = nn.LayerNorm(channels)
-    self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
-    self.layernorm_2 = nn.LayerNorm(channels)
-    self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
-    self.layernorm_3 = nn.LayerNorm(channels)
-    self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
-    self.linear_geglu_2 = nn.Linear(4 * channels, channels)
-
-    self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-
-  def forward(self, x, context):
-    residue_long = x
-
-    x = self.groupnorm(x)
-    x = self.conv_input(x)
-
-    n, c, h, w = x.shape
-    x = x.view((n, c, h * w))  # (n, c, hw)
-    x = x.transpose(-1, -2)  # (n, hw, c)
-
-    residue_short = x
-    x = self.layernorm_1(x)
-    x = self.attention_1(x)
-    x += residue_short
-
-    residue_short = x
-    x = self.layernorm_2(x)
-    x = self.attention_2(x, context)
-    x += residue_short
-
-    residue_short = x
-    x = self.layernorm_3(x)
-    x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
-    x = x * F.gelu(gate)
-    x = self.linear_geglu_2(x)
-    x += residue_short
-
-    x = x.transpose(-1, -2)  # (n, c, hw)
-    x = x.view((n, c, h, w))  # (n, c, h, w)
-
-    return self.conv_output(x) + residue_long
+    self.w1 = nn.Linear(in_dim, out_dim)
+    self.w2 = nn.Linear(out_dim, out_dim)
+    self.act = layers_builder.get_activation(
+        layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU)
+    )
 
+  def forward(self, x: torch.Tensor):
+    return self.w2(self.act(self.w1(x)))
 
-class Upsample(nn.Module):
 
-  def __init__(self, channels):
+class Diffusion(nn.Module):
+  """The Diffusion model used in Stable Diffusion.
+
+  For details, see https://arxiv.org/abs/2103.00020
+
+  Structure of the Diffusion model:
+
+                           latents        text context  time embed
+                              │                 │            │
+                              │                 │            │
+                    ┌─────────▼─────────┐       │  ┌─────────▼─────────┐
+                    │      ConvIn       │       │  │  Time Embedding   │
+                    └─────────┬─────────┘       │  └─────────┬─────────┘
+                              │                 │            │
+                    ┌─────────▼─────────┐       │            │
+             ┌──────┤   DownEncoder2D   │ ◄─────┼────────────┤
+             │      └─────────┬─────────┘ x 4   │            │
+             │                │                 │            │
+             │      ┌─────────▼─────────┐       │            │
+    skip connection │    MidBlock2D     │ ◄─────┼────────────┤
+             │      └─────────┬─────────┘       │            │
+             │                │                 │            │
+             │      ┌─────────▼─────────┐       │            │
+             └──────►  SkipUpDecoder2D  │ ◄─────┴────────────┘
+                    └─────────┬─────────┘ x 4
+
+                    ┌─────────▼─────────┐
+                    │     FinalNorm     │
+                    └─────────┬─────────┘
+
+                    ┌─────────▼─────────┐
+                    │    Activation     │
+                    └─────────┬─────────┘
+
+                    ┌─────────▼─────────┐
+                    │      ConvOut      │
+                    └─────────┬─────────┘
+
+
+                        output image
+  """
+
+  def __init__(self, config: unet_cfg.DiffusionModelConfig):
     super().__init__()
-    self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
-
-  def forward(self, x):
-    x = F.interpolate(x, scale_factor=2, mode="nearest")
-    return self.conv(x)
 
+    self.config = config
+    block_out_channels = config.block_out_channels
+    reversed_block_out_channels = list(reversed(block_out_channels))
 
-class SwitchSequential(nn.Sequential):
-
-  def forward(self, x, context, time):
-    for layer in self:
-      if isinstance(layer, AttentionBlock):
-        x = layer(x, context)
-      elif isinstance(layer, ResidualBlock):
-        x = layer(x, time)
-      else:
-        x = layer(x)
-    return x
-
-
-class UNet(nn.Module):
+    time_embedding_blocks_dim = config.time_embedding_blocks_dim
+    self.time_embedding = TimeEmbedding(
+        config.time_embedding_dim, config.time_embedding_blocks_dim
+    )
 
-  def __init__(self):
-    super().__init__()
-    self.encoders = nn.ModuleList(
-        [
-            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
-            SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
-            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)),
-            SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)),
-            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(ResidualBlock(1280, 1280)),
-            SwitchSequential(ResidualBlock(1280, 1280)),
-        ]
+    self.conv_in = nn.Conv2d(
+        config.in_channels, block_out_channels[0], kernel_size=3, padding=1
     )
-    self.bottleneck = SwitchSequential(
-        ResidualBlock(1280, 1280),
-        AttentionBlock(8, 160),
-        ResidualBlock(1280, 1280),
+
+    attention_config = layers_cfg.AttentionConfig(
+        num_heads=config.transformer_num_attention_heads,
+        num_query_groups=config.transformer_num_attention_heads,
+        rotary_percentage=0.0,
+        qkv_transpose_before_split=True,
+        qkv_use_bias=False,
+        output_proj_use_bias=True,
+        enable_kv_cache=False,
    )
 
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(
-                ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
+    # Down encoders.
+    down_encoders = []
+    output_channel = block_out_channels[0]
+    for i, block_out_channel in enumerate(block_out_channels):
+      input_channel = output_channel
+      output_channel = block_out_channel
+      not_final_block = i < len(block_out_channels) - 1
+      if not_final_block:
+        down_encoders.append(
+            blocks_2d.DownEncoderBlock2D(
+                unet_cfg.DownEncoderBlock2DConfig(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    normalization_config=config.residual_norm_config,
+                    activation_config=layers_cfg.ActivationConfig(
+                        config.residual_activation_type
+                    ),
+                    num_layers=config.layers_per_block,
+                    padding=config.downsample_padding,
+                    time_embedding_channels=time_embedding_blocks_dim,
+                    add_downsample=True,
+                    sampling_config=unet_cfg.DownSamplingConfig(
+                        mode=unet_cfg.SamplingType.CONVOLUTION,
+                        in_channels=output_channel,
+                        out_channels=output_channel,
+                        kernel_size=3,
+                        stride=2,
+                        padding=config.downsample_padding,
+                    ),
+                    transformer_block_config=unet_cfg.TransformerBlock2Dconfig(
+                        attention_block_config=unet_cfg.AttentionBlock2DConfig(
+                            dim=output_channel,
+                            attention_batch_size=config.transformer_batch_size,
+                            normalization_config=config.transformer_norm_config,
+                            attention_config=attention_config,
+                        ),
+                        cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
+                            query_dim=output_channel,
+                            cross_dim=config.transformer_cross_attention_dim,
+                            attention_batch_size=config.transformer_batch_size,
+                            normalization_config=config.transformer_norm_config,
+                            attention_config=attention_config,
+                        ),
+                        pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
+                        feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
+                            dim=output_channel,
+                            hidden_dim=output_channel * 4,
+                            normalization_config=config.transformer_norm_config,
+                            activation_config=layers_cfg.ActivationConfig(
+                                type=config.transformer_ff_activation_type,
+                                dim_in=output_channel,
+                                dim_out=output_channel * 4,
+                            ),
+                            use_bias=True,
+                        ),
+                    ),
+                )
+            )
+        )
+      else:
+        down_encoders.append(
+            blocks_2d.DownEncoderBlock2D(
+                unet_cfg.DownEncoderBlock2DConfig(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    normalization_config=config.residual_norm_config,
+                    activation_config=layers_cfg.ActivationConfig(
+                        config.residual_activation_type
+                    ),
+                    num_layers=config.layers_per_block,
+                    padding=config.downsample_padding,
+                    time_embedding_channels=time_embedding_blocks_dim,
+                    add_downsample=False,
+                )
+            )
+        )
+    self.down_encoders = nn.ModuleList(down_encoders)
+
+    # Mid block.
+    mid_block_channels = block_out_channels[-1]
+    self.mid_block = blocks_2d.MidBlock2D(
+        unet_cfg.MidBlock2DConfig(
+            in_channels=block_out_channels[-1],
+            normalization_config=config.residual_norm_config,
+            activation_config=layers_cfg.ActivationConfig(
+                config.residual_activation_type
             ),
-            SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
-            SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
-            SwitchSequential(
-                ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
+            num_layers=config.mid_block_layers,
+            time_embedding_channels=config.time_embedding_blocks_dim,
+            transformer_block_config=unet_cfg.TransformerBlock2Dconfig(
+                attention_block_config=unet_cfg.AttentionBlock2DConfig(
+                    dim=mid_block_channels,
+                    attention_batch_size=config.transformer_batch_size,
+                    normalization_config=config.transformer_norm_config,
+                    attention_config=attention_config,
+                ),
+                cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
+                    query_dim=mid_block_channels,
+                    cross_dim=config.transformer_cross_attention_dim,
+                    attention_batch_size=config.transformer_batch_size,
+                    normalization_config=config.transformer_norm_config,
+                    attention_config=attention_config,
+                ),
+                pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
+                feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
+                    dim=mid_block_channels,
+                    hidden_dim=mid_block_channels * 4,
+                    normalization_config=config.transformer_norm_config,
+                    activation_config=layers_cfg.ActivationConfig(
+                        type=config.transformer_ff_activation_type,
+                        dim_in=mid_block_channels,
+                        dim_out=mid_block_channels * 4,
+                    ),
+                    use_bias=True,
+                ),
             ),
-            SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-        ]
+        )
     )
 
-  def forward(self, x, context, time):
-    skip_connections = []
-    for layers in self.encoders:
-      x = layers(x, context, time)
-      skip_connections.append(x)
-
-    x = self.bottleneck(x, context, time)
-
-    for layers in self.decoders:
-      x = torch.cat((x, skip_connections.pop()), dim=1)
-      x = layers(x, context, time)
-
-    return x
-
-
-class FinalLayer(nn.Module):
-
-  def __init__(self, in_channels, out_channels):
-    super().__init__()
-    self.groupnorm = nn.GroupNorm(32, in_channels)
-    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
-
-  def forward(self, x):
-    x = self.groupnorm(x)
-    x = F.silu(x)
-    x = self.conv(x)
-    return x
-
-
-class Diffusion(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.time_embedding = TimeEmbedding(320)
-    self.unet = UNet()
-    self.final = FinalLayer(320, 4)
+    # Up decoders.
+    up_decoders = []
+    up_decoder_layers_per_block = config.layers_per_block + 1
+    output_channel = reversed_block_out_channels[0]
+    for i, block_out_channel in enumerate(reversed_block_out_channels):
+      prev_out_channel = output_channel
+      output_channel = block_out_channel
+      input_channel = reversed_block_out_channels[
+          min(i + 1, len(reversed_block_out_channels) - 1)
+      ]
+      not_final_block = i < len(reversed_block_out_channels) - 1
+      not_first_block = i != 0
+      if not_first_block:
+        up_decoders.append(
+            blocks_2d.SkipUpDecoderBlock2D(
+                unet_cfg.SkipUpDecoderBlock2DConfig(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    prev_out_channels=prev_out_channel,
+                    normalization_config=config.residual_norm_config,
+                    activation_config=layers_cfg.ActivationConfig(
+                        config.residual_activation_type
+                    ),
+                    num_layers=up_decoder_layers_per_block,
+                    time_embedding_channels=time_embedding_blocks_dim,
+                    add_upsample=not_final_block,
+                    upsample_conv=True,
+                    sampling_config=unet_cfg.UpSamplingConfig(
+                        mode=unet_cfg.SamplingType.NEAREST,
+                        scale_factor=2,
+                    ),
+                    transformer_block_config=unet_cfg.TransformerBlock2Dconfig(
+                        attention_block_config=unet_cfg.AttentionBlock2DConfig(
+                            dim=output_channel,
+                            attention_batch_size=config.transformer_batch_size,
+                            normalization_config=config.transformer_norm_config,
+                            attention_config=attention_config,
+                        ),
+                        cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
+                            query_dim=output_channel,
+                            cross_dim=config.transformer_cross_attention_dim,
+                            attention_batch_size=config.transformer_batch_size,
+                            normalization_config=config.transformer_norm_config,
+                            attention_config=attention_config,
+                        ),
+                        pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
+                        feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
+                            dim=output_channel,
+                            hidden_dim=output_channel * 4,
+                            normalization_config=config.transformer_norm_config,
+                            activation_config=layers_cfg.ActivationConfig(
+                                type=config.transformer_ff_activation_type,
+                                dim_in=output_channel,
+                                dim_out=output_channel * 4,
+                            ),
+                            use_bias=True,
+                        ),
+                    ),
+                )
+            )
+        )
+      else:
+        up_decoders.append(
+            blocks_2d.SkipUpDecoderBlock2D(
+                unet_cfg.SkipUpDecoderBlock2DConfig(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    prev_out_channels=prev_out_channel,
+                    normalization_config=config.residual_norm_config,
+                    activation_config=layers_cfg.ActivationConfig(
+                        config.residual_activation_type
+                    ),
+                    num_layers=up_decoder_layers_per_block,
+                    time_embedding_channels=time_embedding_blocks_dim,
+                    add_upsample=not_final_block,
+                    upsample_conv=True,
+                    sampling_config=unet_cfg.UpSamplingConfig(
+                        mode=unet_cfg.SamplingType.NEAREST, scale_factor=2
+                    ),
+                )
+            )
+        )
+    self.up_decoders = nn.ModuleList(up_decoders)
+
+    self.final_norm = layers_builder.build_norm(
+        reversed_block_out_channels[-1], config.final_norm_config
+    )
+    self.final_act = layers_builder.get_activation(
+        layers_cfg.ActivationConfig(config.final_activation_type)
+    )
+    self.conv_out = nn.Conv2d(
+        reversed_block_out_channels[-1], config.out_channels, kernel_size=3, padding=1
+    )
 
   @torch.inference_mode
-  def forward(self, latent, context, time):
-    time = self.time_embedding(time)
-    output = self.unet(latent, context, time)
-    output = self.final(output)
-    return output
+  def forward(
+      self, latents: torch.Tensor, context: torch.Tensor, time_emb: torch.Tensor
+  ) -> torch.Tensor:
+    """Forward function of diffusion model.
+
+    Args:
+      latents (torch.Tensor): latent space tensor.
+      context (torch.Tensor): context tensor from CLIP text encoder.
+      time_emb (torch.Tensor): the time embedding tensor.
+
+    Returns:
+      output latents from diffusion model.
+    """
+    time_emb = self.time_embedding(time_emb)
+    x = self.conv_in(latents)
+    skip_connection_tensors = [x]
+    for encoder in self.down_encoders:
+      x, hidden_states = encoder(x, time_emb, context, output_hidden_states=True)
+      skip_connection_tensors.extend(hidden_states)
+    x = self.mid_block(x, time_emb, context)
+    for decoder in self.up_decoders:
+      encoder_tensors = [
+          skip_connection_tensors.pop() for i in range(self.config.layers_per_block + 1)
+      ]
+      x = decoder(x, encoder_tensors, time_emb, context)
+    x = self.final_norm(x)
+    x = self.final_act(x)
+    x = self.conv_out(x)
+    return x
 
 
-if __name__ == "__main__":
-  diffusion = Diffusion()
-  print(diffusion.state_dict().keys())
+def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
+  """Get configs for the Diffusion model of Stable Diffusion v1.5.
+
+  Args:
+    batch_size (int): the batch size of input.
+
+  Returns:
+    The configuration of diffusion model of Stable Diffusion v1.5.
+
+  """
+  in_channels = 4
+  out_channels = 4
+  block_out_channels = [320, 640, 1280, 1280]
+  layers_per_block = 2
+  downsample_padding = 1
+
+  # Residual configs.
+  residual_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+  )
+  residual_activation_type = layers_cfg.ActivationType.SILU
+
+  # Transformer configs.
+  transformer_num_attention_heads = 8
+  transformer_batch_size = batch_size
+  transformer_cross_attention_dim = 768  # Embedding from CLIP model
+  transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
+  )
+  transformer_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.LAYER_NORM
+  )
+  transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU
+
+  # Time embedding configs.
+  time_embedding_dim = 320
+  time_embedding_blocks_dim = 1280
+
+  # Mid block configs.
+  mid_block_layers = 1
+
+  # Final layer configs.
+  final_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+  )
+  final_activation_type = layers_cfg.ActivationType.SILU
+
+  return unet_cfg.DiffusionModelConfig(
+      in_channels=in_channels,
+      out_channels=out_channels,
+      block_out_channels=block_out_channels,
+      layers_per_block=layers_per_block,
+      downsample_padding=downsample_padding,
+      residual_norm_config=residual_norm_config,
+      residual_activation_type=residual_activation_type,
+      transformer_batch_size=transformer_batch_size,
+      transformer_num_attention_heads=transformer_num_attention_heads,
+      transformer_cross_attention_dim=transformer_cross_attention_dim,
+      transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
+      transformer_norm_config=transformer_norm_config,
+      transformer_ff_activation_type=transformer_ff_activation_type,
+      mid_block_layers=mid_block_layers,
+      time_embedding_dim=time_embedding_dim,
+      time_embedding_blocks_dim=time_embedding_blocks_dim,
+      final_norm_config=final_norm_config,
+      final_activation_type=final_activation_type,
+  )
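
For orientation, the new get_model_config and Diffusion APIs shown above can be exercised with a short smoke test on a randomly initialized model. This is a minimal sketch based only on what is visible in this diff: the module path and function names come from the file list and hunk above, while the input shapes (64x64 latents, 77 CLIP tokens, a 320-wide time embedding) are the usual Stable Diffusion v1.5 conventions and are assumptions, not values asserted by the package.

import torch
from ai_edge_torch.generative.examples.stable_diffusion import diffusion

# Build the SD v1.5 UNet configuration and model defined in this diff.
config = diffusion.get_model_config(batch_size=1)
model = diffusion.Diffusion(config)

# Dummy inputs; shapes are assumed SD v1.5 conventions, not taken from this diff.
latents = torch.randn(1, 4, 64, 64)   # config.in_channels = 4
context = torch.randn(1, 77, 768)     # transformer_cross_attention_dim = 768
time_emb = torch.randn(1, 320)        # time_embedding_dim = 320

output = model(latents, context, time_emb)
print(output.shape)  # expected to match the latents shape: (1, 4, 64, 64)

A randomly initialized model is enough for a shape check; real checkpoint weights would be mapped in through TENSORS_NAMES and the stable_diffusion_loader utility that this release introduces in ai_edge_torch/generative/utilities/stable_diffusion_loader.py.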