ai-edge-torch-nightly 0.2.0.dev20240611__py3-none-any.whl → 0.2.0.dev20240617__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
- ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
- ai_edge_torch/generative/layers/attention.py +154 -26
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
- ai_edge_torch/generative/layers/unet/builder.py +20 -2
- ai_edge_torch/generative/layers/unet/model_config.py +157 -5
- ai_edge_torch/generative/test/test_model_conversion.py +24 -0
- ai_edge_torch/generative/test/test_quantize.py +1 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
- ai_edge_torch/generative/utilities/t5_loader.py +33 -17
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/RECORD +20 -20
- ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/unet/blocks_2d.py

@@ -13,13 +13,15 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import torch
 from torch import nn
 
+from ai_edge_torch.generative.layers.attention import CrossAttention
 from ai_edge_torch.generative.layers.attention import SelfAttention
 import ai_edge_torch.generative.layers.builder as layers_builder
+import ai_edge_torch.generative.layers.model_config as layers_cfg
 import ai_edge_torch.generative.layers.unet.builder as unet_builder
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
 
@@ -78,6 +80,7 @@ class ResidualBlock2D(nn.Module):
         x = self.act_fn(x)
         x = self.conv_1(x)
         if self.time_emb_proj is not None:
+            time_emb = self.act_fn(time_emb)
             time_emb = self.time_emb_proj(time_emb)[:, :, None, None]
             x = x + time_emb
         x = self.norm_2(x)
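The added line runs the time embedding through the activation before the linear projection, the usual ordering in diffusion UNet residual blocks. A minimal standalone sketch of the broadcast pattern (example sizes; plain PyTorch, not the package's classes):

    import torch
    from torch import nn

    act = nn.SiLU()
    proj = nn.Linear(1280, 640)            # time_embedding_channels -> out_channels
    x = torch.randn(2, 640, 32, 32)        # (B, C, H, W) feature map
    t = torch.randn(2, 1280)               # (B, time_embedding_channels)
    t = proj(act(t))[:, :, None, None]     # (B, C, 1, 1), broadcasts over H and W
    x = x + t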
@@ -90,7 +93,7 @@ class ResidualBlock2D(nn.Module):
 class AttentionBlock2D(nn.Module):
     """2D self attention block
 
-    x = SelfAttention(Norm(input_tensor))
+    x = SelfAttention(Norm(input_tensor)) + x
 
     """
 
@@ -101,8 +104,15 @@ class AttentionBlock2D(nn.Module):
             config (unet_cfg.AttentionBlock2DConfig): the configuration of this block.
         """
         super().__init__()
-        self.
-        self.
+        self.config = config
+        self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
+        self.attention = SelfAttention(
+            config.attention_batch_size,
+            config.dim,
+            config.attention_config,
+            0,
+            enable_hlfb=config.enable_hlfb,
+        )
 
     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
         """Forward function of the AttentionBlock2D.
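Per the updated docstring, the block is now residual: normalized self-attention is added back onto the input. A generic pre-norm sketch of the same pattern (plain PyTorch stand-ins, not the package's SelfAttention):

    import torch
    from torch import nn

    dim = 512
    norm = nn.GroupNorm(num_groups=32, num_channels=dim)
    attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)

    def attention_block(x: torch.Tensor) -> torch.Tensor:  # x: (B, C, H, W)
        B, C, H, W = x.shape
        h = norm(x).view(B, C, H * W).transpose(-1, -2)    # (B, HW, C) tokens
        h, _ = attn(h, h, h, need_weights=False)
        h = h.transpose(-1, -2).view(B, C, H, W)
        return x + h                                       # residual add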
@@ -114,10 +124,17 @@ class AttentionBlock2D(nn.Module):
             output activation tensor after self attention.
         """
         residual = input_tensor
-
-
-
-
+        B, C, H, W = input_tensor.shape
+        x = input_tensor
+        if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+            x = self.norm(x)
+            x = input_tensor.view(B, C, H * W)
+            x = x.transpose(-1, -2)
+        else:
+            x = input_tensor.view(B, C, H * W)
+            x = x.transpose(-1, -2)
+            x = self.norm(x)
+        x = x.contiguous()  # Prevent BATCH_MATMUL op in converted tflite.
         x = self.attention(x)
         x = x.transpose(-1, -2)
         x = x.view(B, C, H, W)
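The branch order matters: GroupNorm normalizes over (C, H, W) and must see the 4D tensor before flattening, while a channels-last norm expects (B, HW, C) tokens. The `.contiguous()` call materializes the transposed layout, which (per the inline comment) keeps a BATCH_MATMUL op out of the converted TFLite graph. A quick shape check of the two orderings (GroupNorm and LayerNorm as stand-ins for the configured norms):

    import torch
    from torch import nn

    x = torch.randn(1, 64, 8, 8)      # (B, C, H, W)
    gn = nn.GroupNorm(8, 64)          # channels-first, 4D input
    ln = nn.LayerNorm(64)             # channels-last tokens

    a = gn(x).view(1, 64, 64).transpose(-1, -2).contiguous()  # norm, then flatten
    b = ln(x.view(1, 64, 64).transpose(-1, -2))               # flatten, then norm
    assert a.shape == b.shape == (1, 64, 64)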
@@ -125,28 +142,306 @@ class AttentionBlock2D(nn.Module):
         return x
 
 
-class
-    """
+class CrossAttentionBlock2D(nn.Module):
+    """2D cross attention block
 
-
-
-
-
-
-
+    x = CrossAttention(Norm(input_tensor), context) + x
+
+    """
+
+    def __init__(self, config: unet_cfg.CrossAttentionBlock2DConfig):
+        """Initialize an instance of the CrossAttentionBlock2D.
+
+        Args:
+            config (unet_cfg.CrossAttentionBlock2DConfig): the configuration of this block.
+        """
+        super().__init__()
+        self.config = config
+        self.norm = layers_builder.build_norm(config.query_dim, config.normalization_config)
+        self.attention = CrossAttention(
+            config.attention_batch_size,
+            config.query_dim,
+            config.cross_dim,
+            config.attention_config,
+            0,
+            enable_hlfb=config.enable_hlfb,
+        )
+
+    def forward(
+        self, input_tensor: torch.Tensor, context_tensor: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward function of the CrossAttentionBlock2D.
+
+        Args:
+            input_tensor (torch.Tensor): the input tensor.
+            context_tensor (torch.Tensor): the context tensor to apply cross attention on.
+
+        Returns:
+            output activation tensor after cross attention.
+        """
+        residual = input_tensor
+        B, C, H, W = input_tensor.shape
+        x = input_tensor
+        if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+            x = self.norm(x)
+            x = input_tensor.view(B, C, H * W)
+            x = x.transpose(-1, -2)
+        else:
+            x = input_tensor.view(B, C, H * W)
+            x = x.transpose(-1, -2)
+            x = self.norm(x)
+        x = self.attention(x, context_tensor)
+        x = x.transpose(-1, -2)
+        x = x.view(B, C, H, W)
+        x = x + residual
+        return x
+
+
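The new class follows the same reshape-and-residual recipe, but keys and values come from the context tensor (e.g. text-encoder states) instead of the image tokens. A generic sketch of that query/key/value split (plain PyTorch, not the package's CrossAttention; sizes follow SD 1.x conventions):

    import torch
    from torch import nn

    attn = nn.MultiheadAttention(
        embed_dim=320, num_heads=8, batch_first=True,
        kdim=768, vdim=768,                      # query_dim=320, cross_dim=768
    )
    image_tokens = torch.randn(1, 64 * 64, 320)  # (B, HW, query_dim)
    context = torch.randn(1, 77, 768)            # (B, seq_len, cross_dim)
    out, _ = attn(image_tokens, context, context, need_weights=False)
    assert out.shape == image_tokens.shape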
+class FeedForwardBlock2D(nn.Module):
+    """2D feed forward block
+
+    x = w2(Activation(w1(Norm(x)))) + x
+
+    """
+
+    def __init__(
+        self,
+        config: unet_cfg.FeedForwardBlock2DConfig,
+    ):
+        super().__init__()
+        self.config = config
+        self.act = layers_builder.get_activation(config.activation_config)
+        self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
+        if config.activation_config.type == layers_cfg.ActivationType.GE_GLU:
+            self.w1 = nn.Identity()
+            self.w2 = nn.Linear(config.hidden_dim, config.dim)
+        else:
+            self.w1 = nn.Linear(config.dim, config.hidden_dim)
+            self.w2 = nn.Linear(config.hidden_dim, config.dim)
+
+    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        residual = input_tensor
+        B, C, H, W = input_tensor.shape
+        x = input_tensor
+        if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+            x = self.norm(x)
+            x = input_tensor.view(B, C, H * W)
+            x = x.transpose(-1, -2)
+        else:
+            x = input_tensor.view(B, C, H * W)
+            x = x.transpose(-1, -2)
+            x = self.norm(x)
+        x = self.w1(x)
+        x = self.act(x)
+        x = self.w2(x)
+
+        x = x.transpose(-1, -2)  # (B, C, HW)
+        x = x.view((B, C, H, W))
+
+        return x + residual
+
+
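For GE_GLU, `w1` is an Identity because a GEGLU activation conventionally carries its own input projection: it projects to twice the hidden width, splits, and gates one half with GELU of the other, leaving only `w2` as a plain Linear. A minimal GEGLU sketch (an assumption about what `get_activation` returns for GE_GLU, not confirmed by this diff):

    import torch
    from torch import nn
    import torch.nn.functional as F

    class GEGLU(nn.Module):
        def __init__(self, dim: int, hidden_dim: int):
            super().__init__()
            self.proj = nn.Linear(dim, hidden_dim * 2)  # fused value + gate projection

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            v, gate = self.proj(x).chunk(2, dim=-1)
            return v * F.gelu(gate)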
+class TransformerBlock2D(nn.Module):
+    """Basic transformer block used in UNet of diffusion model
+
+          input_tensor    context_tensor
+               |                 |
+     ┌─────────▼─────────┐       |
+     │      ConvIn       │       |
+     └─────────┬─────────┘       |
+               |                 |
+               ▼                 |
+     ┌───────────────────┐       |
+     │  Attention Block  │       |
+     └─────────┬─────────┘       |
+               │                 |
+    ┌────────────────────┐       |
+    │CrossAttention Block│◄──────┘
+    └─────────┬──────────┘
               │
     ┌─────────▼─────────┐
-    │
-    │     Upsampler     │
+    │  FeedForwardBlock │
     └─────────┬─────────┘
               │
     ┌─────────▼─────────┐
-    │
-    │      Conv2D       │
+    │      ConvOut      │
     └─────────┬─────────┘
-              │
               ▼
          hidden_states
+
+
+    """
+
+    def __init__(self, config: unet_cfg.TransformerBlock2Dconfig):
+        """Initialize an instance of the TransformerBlock2D.
+
+        Args:
+            config (unet_cfg.TransformerBlock2Dconfig): the configuration of this block.
+        """
+        super().__init__()
+        self.config = config
+        self.pre_conv_norm = layers_builder.build_norm(
+            config.attention_block_config.dim, config.pre_conv_normalization_config
+        )
+        self.conv_in = nn.Conv2d(
+            config.attention_block_config.dim,
+            config.attention_block_config.dim,
+            kernel_size=1,
+            padding=0,
+        )
+        self.self_attention = AttentionBlock2D(config.attention_block_config)
+        self.cross_attention = CrossAttentionBlock2D(config.cross_attention_block_config)
+        self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
+        self.conv_out = nn.Conv2d(
+            config.attention_block_config.dim,
+            config.attention_block_config.dim,
+            kernel_size=1,
+            padding=0,
+        )
+
+    def forward(self, x: torch.Tensor, context: torch.Tensor):
+        """Forward function of the TransformerBlock2D.
+
+        Args:
+            x (torch.Tensor): the input tensor.
+            context (torch.Tensor): the context tensor to apply cross attention on.
+
+        Returns:
+            output activation tensor after transformer block.
+        """
+        residual_long = x
+
+        x = self.pre_conv_norm(x)
+        x = self.conv_in(x)
+        x = self.self_attention(x)
+        x = self.cross_attention(x, context)
+        x = self.feed_forward(x)
+
+        x = self.conv_out(x)
+        x = x + residual_long
+
+        return x
+
+
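At call time the block runs: 1x1 conv in, self-attention, cross-attention against the context, feed-forward, 1x1 conv out, with one long residual around the whole stack, so the (B, C, H, W) shape is preserved. A hedged usage sketch (the config construction is omitted; shapes are illustrative of SD 1.x latents and a 77-token text context):

    import torch

    def run_transformer_block(block: "TransformerBlock2D") -> torch.Tensor:
        x = torch.randn(1, 320, 64, 64)     # (B, C, H, W) latent features
        context = torch.randn(1, 77, 768)   # (B, seq_len, cross_dim) text states
        return block(x, context)            # output shape matches x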
+class DownEncoderBlock2D(nn.Module):
+    """Encoder block containing several residual blocks with optional interleaved transformer blocks.
+
+          input_tensor
+                |
+ ┌──────────────▼─────────────┐
+ │   ┌────────────────────┐   │
+ │   │  ResidualBlock2D   │   │
+ │   └──────────┬─────────┘   │
+ │              │             │ num_layers
+ │   ┌────────────────────┐   │
+ │   │     (Optional)     │   │
+ │   │ TransformerBlock2D │   │
+ │   └──────────┬─────────┘   │
+ └──────────────┬─────────────┘
+                │
+     ┌──────────▼─────────┐
+     │     (Optional)     │
+     │    Downsampler     │
+     └──────────┬─────────┘
+                │
+                ▼
+          hidden_states
+    """
+
+    def __init__(self, config: unet_cfg.DownEncoderBlock2DConfig):
+        """Initialize an instance of the DownEncoderBlock2D.
+
+        Args:
+            config (unet_cfg.DownEncoderBlock2DConfig): the configuration of this block.
+        """
+        super().__init__()
+        self.config = config
+        resnets = []
+        transformers = []
+        for i in range(config.num_layers):
+            input_channels = config.in_channels if i == 0 else config.out_channels
+            resnets.append(
+                ResidualBlock2D(
+                    unet_cfg.ResidualBlock2DConfig(
+                        in_channels=input_channels,
+                        out_channels=config.out_channels,
+                        time_embedding_channels=config.time_embedding_channels,
+                        normalization_config=config.normalization_config,
+                        activation_config=config.activation_config,
+                    )
+                )
+            )
+            if config.transformer_block_config:
+                transformers.append(TransformerBlock2D(config.transformer_block_config))
+        self.resnets = nn.ModuleList(resnets)
+        self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+        if config.add_downsample:
+            self.downsampler = unet_builder.build_downsampling(config.sampling_config)
+        else:
+            self.downsampler = None
+
+    def forward(
+        self,
+        input_tensor: torch.Tensor,
+        time_emb: Optional[torch.Tensor] = None,
+        context_tensor: Optional[torch.Tensor] = None,
+        output_hidden_states: bool = False,
+    ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Forward function of the DownEncoderBlock2D.
+
+        Args:
+            input_tensor (torch.Tensor): the input tensor.
+            time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
+                time embedding.
+            context_tensor (torch.Tensor): optional context tensor, if the block is configured to use transformer block.
+            output_hidden_states (bool): whether to output hidden states, usually for skip connections.
+
+        Returns:
+            output hidden_states tensor after DownEncoderBlock2D.
+        """
+        hidden_states = input_tensor
+        output_states = []
+        for i, resnet in enumerate(self.resnets):
+            hidden_states = resnet(hidden_states, time_emb)
+            if self.transformers is not None:
+                hidden_states = self.transformers[i](hidden_states, context_tensor)
+            output_states.append(hidden_states)
+        if self.downsampler:
+            hidden_states = self.downsampler(hidden_states)
+            output_states.append(hidden_states)
+        if output_hidden_states:
+            return hidden_states, output_states
+        else:
+            return hidden_states
+
+
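When `output_hidden_states=True`, the block also returns one tensor per resnet layer plus one more after downsampling; a UNet driver keeps these for the decoder's skip connections. A sketch of that consumption side (hypothetical driver code, not part of this diff):

    from typing import List, Tuple
    import torch
    from torch import nn

    def encode(x, blocks: nn.ModuleList, time_emb, context) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        skips: List[torch.Tensor] = []
        for block in blocks:  # each block assumed to be a DownEncoderBlock2D
            x, states = block(x, time_emb, context, output_hidden_states=True)
            skips.extend(states)
        return x, skips       # skips later feed SkipUpDecoderBlock2D in reverse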
+class UpDecoderBlock2D(nn.Module):
+    """Decoder block containing several residual blocks with optional interleaved transformer blocks.
+
+          input_tensor
+                |
+ ┌──────────────▼─────────────┐
+ │   ┌────────────────────┐   │
+ │   │  ResidualBlock2D   │   │
+ │   └──────────┬─────────┘   │
+ │              │             │ num_layers
+ │   ┌────────────────────┐   │
+ │   │     (Optional)     │   │
+ │   │ TransformerBlock2D │   │
+ │   └──────────┬─────────┘   │
+ └──────────────┬─────────────┘
+                │
+     ┌──────────▼─────────┐
+     │     (Optional)     │
+     │     Upsampler      │
+     └──────────┬─────────┘
+                │
+     ┌──────────▼─────────┐
+     │     (Optional)     │
+     │       Conv2D       │
+     └──────────┬─────────┘
+                │
+                ▼
+          hidden_states
     """
 
     def __init__(self, config: unet_cfg.UpDecoderBlock2DConfig):
@@ -158,6 +453,7 @@ class UpDecoderBlock2D(nn.Module):
         super().__init__()
         self.config = config
         resnets = []
+        transformers = []
         for i in range(config.num_layers):
             input_channels = config.in_channels if i == 0 else config.out_channels
             resnets.append(
@@ -171,7 +467,10 @@ class UpDecoderBlock2D(nn.Module):
                     )
                 )
             )
+            if config.transformer_block_config:
+                transformers.append(TransformerBlock2D(config.transformer_block_config))
         self.resnets = nn.ModuleList(resnets)
+        self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
         if config.add_upsample:
             self.upsampler = unet_builder.build_upsampling(config.sampling_config)
             if config.upsample_conv:
@@ -182,21 +481,130 @@ class UpDecoderBlock2D(nn.Module):
             self.upsampler = None
 
     def forward(
-        self,
+        self,
+        input_tensor: torch.Tensor,
+        time_emb: Optional[torch.Tensor] = None,
+        context_tensor: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward function of the UpDecoderBlock2D.
 
         Args:
             input_tensor (torch.Tensor): the input tensor.
             time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
-                time embedding
+                time embedding.
+            context_tensor (torch.Tensor): optional context tensor, if the block is configured to use transformer block.
 
         Returns:
             output hidden_states tensor after UpDecoderBlock2D.
         """
         hidden_states = input_tensor
-        for resnet in self.resnets:
+        for i, resnet in enumerate(self.resnets):
             hidden_states = resnet(hidden_states, time_emb)
+            if self.transformers is not None:
+                hidden_states = self.transformers[i](hidden_states, context_tensor)
+        if self.upsampler:
+            hidden_states = self.upsampler(hidden_states)
+            if self.upsample_conv:
+                hidden_states = self.upsample_conv(hidden_states)
+        return hidden_states
+
+
+class SkipUpDecoderBlock2D(nn.Module):
+    """Decoder block containing skip connections and residual blocks with optional interleaved transformer blocks.
+
+    input_tensor, skip_connection_tensors
+                |
+ ┌──────────────▼─────────────┐
+ │   ┌────────────────────┐   │
+ │   │  ResidualBlock2D   │   │
+ │   └──────────┬─────────┘   │
+ │              │             │ num_layers
+ │   ┌────────────────────┐   │
+ │   │     (Optional)     │   │
+ │   │ TransformerBlock2D │   │
+ │   └──────────┬─────────┘   │
+ └──────────────┬─────────────┘
+                │
+     ┌──────────▼─────────┐
+     │     (Optional)     │
+     │     Upsampler      │
+     └──────────┬─────────┘
+                │
+     ┌──────────▼─────────┐
+     │     (Optional)     │
+     │       Conv2D       │
+     └──────────┬─────────┘
+                │
+                ▼
+          hidden_states
+    """
+
+    def __init__(self, config: unet_cfg.SkipUpDecoderBlock2DConfig):
+        """Initialize an instance of the SkipUpDecoderBlock2D.
+
+        Args:
+            config (unet_cfg.SkipUpDecoderBlock2DConfig): the configuration of this block.
+        """
+        super().__init__()
+        self.config = config
+        resnets = []
+        transformers = []
+        for i in range(config.num_layers):
+            res_skip_channels = (
+                config.in_channels if (i == config.num_layers - 1) else config.out_channels
+            )
+            resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels
+            resnets.append(
+                ResidualBlock2D(
+                    unet_cfg.ResidualBlock2DConfig(
+                        in_channels=resnet_in_channels + res_skip_channels,
+                        out_channels=config.out_channels,
+                        time_embedding_channels=config.time_embedding_channels,
+                        normalization_config=config.normalization_config,
+                        activation_config=config.activation_config,
+                    )
+                )
+            )
+            if config.transformer_block_config:
+                transformers.append(TransformerBlock2D(config.transformer_block_config))
+        self.resnets = nn.ModuleList(resnets)
+        self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+        if config.add_upsample:
+            self.upsampler = unet_builder.build_upsampling(config.sampling_config)
+            if config.upsample_conv:
+                self.upsample_conv = nn.Conv2d(
+                    config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+                )
+        else:
+            self.upsampler = None
+
+    def forward(
+        self,
+        input_tensor: torch.Tensor,
+        skip_connection_tensors: List[torch.Tensor],
+        time_emb: Optional[torch.Tensor] = None,
+        context_tensor: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Forward function of the SkipUpDecoderBlock2D.
+
+        Args:
+            input_tensor (torch.Tensor): the input tensor.
+            skip_connection_tensors (List[torch.Tensor]): the skip connection tensors from encoder blocks.
+            time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
+                time embedding.
+            context_tensor (torch.Tensor): optional context tensor, if the block is configured to use transformer block.
+
+        Returns:
+            output hidden_states tensor after SkipUpDecoderBlock2D.
+        """
+        hidden_states = input_tensor
+        for i, (resnet, skip_connection_tensor) in enumerate(
+            zip(self.resnets, skip_connection_tensors)
+        ):
+            hidden_states = torch.cat([hidden_states, skip_connection_tensor], dim=1)
+            hidden_states = resnet(hidden_states, time_emb)
+            if self.transformers is not None:
+                hidden_states = self.transformers[i](hidden_states, context_tensor)
         if self.upsampler:
             hidden_states = self.upsampler(hidden_states)
             if self.upsample_conv:
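Each loop iteration concatenates one encoder skip tensor on the channel axis before the resnet, which is why the resnet is built with `resnet_in_channels + res_skip_channels` input channels. A toy check of that channel arithmetic (standalone, example sizes):

    import torch

    hidden = torch.randn(1, 640, 16, 16)   # prev_out_channels = 640
    skip = torch.randn(1, 320, 16, 16)     # encoder skip at the same resolution
    merged = torch.cat([hidden, skip], dim=1)
    assert merged.shape[1] == 640 + 320    # resnet consumes the concatenated width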
@@ -207,25 +615,30 @@ class UpDecoderBlock2D(nn.Module):
 class MidBlock2D(nn.Module):
     """Middle block containing at least one residual block with optional interleaved attention blocks.
 
-
-
-
-
-
-
-
-
-              │
-    │         │
-    │         │  AttentionBlock2D
-              │
-              │
-              │
-    │         │
-              │
-
-
-
+          input_tensor
+                |
+                ▼
+     ┌───────────────────┐
+     │  ResidualBlock2D  │
+     └─────────┬─────────┘
+               │
+ ┌──────────────▼─────────────┐
+ │   ┌────────────────────┐   │
+ │   │     (Optional)     │   │
+ │   │  AttentionBlock2D  │   │
+ │   └──────────┬─────────┘   │
+ │              │             │
+ │   ┌──────────▼─────────┐   │
+ │   │     (Optional)     │   │  num_layers
+ │   │ TransformerBlock2D │   │
+ │   └──────────┬─────────┘   │
+ │              │             │
+ │   ┌──────────▼─────────┐   │
+ │   │  ResidualBlock2D   │   │
+ │   └────────────────────┘   │
+ └──────────────┬─────────────┘
+                │
+                ▼
          hidden_states
     """
 
@@ -249,9 +662,12 @@ class MidBlock2D(nn.Module):
             )
         ]
         attentions = []
+        transformers = []
         for i in range(config.num_layers):
             if self.config.attention_block_config:
                 attentions.append(AttentionBlock2D(config.attention_block_config))
+            if self.config.transformer_block_config:
+                transformers.append(TransformerBlock2D(config.transformer_block_config))
             resnets.append(
                 ResidualBlock2D(
                     unet_cfg.ResidualBlock2DConfig(
@@ -264,24 +680,32 @@ class MidBlock2D(nn.Module):
                 )
             )
         self.resnets = nn.ModuleList(resnets)
-        self.attentions = nn.ModuleList(attentions)
+        self.attentions = nn.ModuleList(attentions) if len(attentions) > 0 else None
+        self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
 
     def forward(
-        self,
+        self,
+        input_tensor: torch.Tensor,
+        time_emb: Optional[torch.Tensor] = None,
+        context_tensor: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward function of the MidBlock2D.
 
         Args:
             input_tensor (torch.Tensor): the input tensor.
             time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
-                time embedding
+                time embedding.
+            context_tensor (torch.Tensor): optional context tensor, if the block is configured to use
+                transformer block.
 
         Returns:
             output hidden_states tensor after MidBlock2D.
         """
         hidden_states = self.resnets[0](input_tensor, time_emb)
-        for
-        if
-            hidden_states =
+        for i, resnet in enumerate(self.resnets[1:]):
+            if self.attentions is not None:
+                hidden_states = self.attentions[i](hidden_states)
+            if self.transformers is not None:
+                hidden_states = self.transformers[i](hidden_states, context_tensor)
             hidden_states = resnet(hidden_states, time_emb)
         return hidden_states
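The `nn.ModuleList(xs) if len(xs) > 0 else None` idiom now used for attentions and transformers keeps empty optional stacks out of the module tree and the checkpoint, and lets forward branch on `is not None`. A minimal sketch of the pattern:

    import torch
    from torch import nn

    class Stack(nn.Module):
        def __init__(self, layers):
            super().__init__()
            # Register the list only when non-empty so state_dict stays clean.
            self.layers = nn.ModuleList(layers) if len(layers) > 0 else None

        def forward(self, x):
            if self.layers is not None:
                for layer in self.layers:
                    x = layer(x)
            return x

    assert len(Stack([]).state_dict()) == 0  # no phantom entries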
ai_edge_torch/generative/layers/unet/builder.py

@@ -15,15 +15,33 @@
 # Builder utils for individual components.
 
 from torch import nn
-import torch.nn.functional as F
 
 import ai_edge_torch.generative.layers.unet.model_config as unet_config
 
 
-def build_upsampling(config: unet_config.
+def build_upsampling(config: unet_config.UpSamplingConfig):
     if config.mode == unet_config.SamplingType.NEAREST:
         return nn.UpsamplingNearest2d(scale_factor=config.scale_factor)
     elif config.mode == unet_config.SamplingType.BILINEAR:
         return nn.UpsamplingBilinear2d(scale_factor=config.scale_factor)
     else:
         raise ValueError("Unsupported upsampling type.")
+
+
+def build_downsampling(config: unet_config.DownSamplingConfig):
+    if config.mode == unet_config.SamplingType.AVERAGE:
+        return nn.AvgPool2d(config.kernel_size, config.stride, padding=config.padding)
+    elif config.mode == unet_config.SamplingType.CONVOLUTION:
+        out_channels = (
+            config.in_channels if config.out_channels is None else config.out_channels
+        )
+        padding = (0, 1, 0, 1) if config.padding == 0 else config.padding
+        return nn.Conv2d(
+            config.in_channels,
+            out_channels=out_channels,
+            kernel_size=config.kernel_size,
+            stride=config.stride,
+            padding=padding,
+        )
+    else:
+        raise ValueError("Unsupported downsampling type.")