ai-edge-torch-nightly 0.2.0.dev20240611__py3-none-any.whl → 0.2.0.dev20240617__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (21)
  1. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
  2. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
  3. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
  4. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
  5. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
  6. ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
  7. ai_edge_torch/generative/layers/attention.py +154 -26
  8. ai_edge_torch/generative/layers/model_config.py +3 -0
  9. ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
  10. ai_edge_torch/generative/layers/unet/builder.py +20 -2
  11. ai_edge_torch/generative/layers/unet/model_config.py +157 -5
  12. ai_edge_torch/generative/test/test_model_conversion.py +24 -0
  13. ai_edge_torch/generative/test/test_quantize.py +1 -0
  14. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
  15. ai_edge_torch/generative/utilities/t5_loader.py +33 -17
  16. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/METADATA +1 -1
  17. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/RECORD +20 -20
  18. ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
  19. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/LICENSE +0 -0
  20. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/WHEEL +0 -0
  21. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -13,13 +13,15 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import torch
 from torch import nn
 
+from ai_edge_torch.generative.layers.attention import CrossAttention
 from ai_edge_torch.generative.layers.attention import SelfAttention
 import ai_edge_torch.generative.layers.builder as layers_builder
+import ai_edge_torch.generative.layers.model_config as layers_cfg
 import ai_edge_torch.generative.layers.unet.builder as unet_builder
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
 
@@ -78,6 +80,7 @@ class ResidualBlock2D(nn.Module):
     x = self.act_fn(x)
     x = self.conv_1(x)
     if self.time_emb_proj is not None:
+      time_emb = self.act_fn(time_emb)
       time_emb = self.time_emb_proj(time_emb)[:, :, None, None]
       x = x + time_emb
     x = self.norm_2(x)
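
The added line runs the time embedding through the activation before the linear projection, so the block follows the usual diffusion-UNet recipe: norm, activation, conv, then add the projected time embedding as a per-channel bias. A minimal sketch of that path, with illustrative shapes (not the package's API):

    import torch
    from torch import nn

    act = nn.SiLU()                         # stand-in for the block's act_fn
    time_emb_proj = nn.Linear(1280, 640)    # time channels -> feature channels
    x = torch.randn(2, 640, 32, 32)         # (B, C, H, W) feature map
    time_emb = torch.randn(2, 1280)         # (B, time_embedding_channels)
    t = time_emb_proj(act(time_emb))[:, :, None, None]  # activate, project, broadcast
    x = x + t                               # per-channel shift at every spatial location
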
@@ -90,7 +93,7 @@ class ResidualBlock2D(nn.Module):
 class AttentionBlock2D(nn.Module):
   """2D self attention block
 
-  x = SelfAttention(Norm(input_tensor))
+  x = SelfAttention(Norm(input_tensor)) + x
 
   """
 
@@ -101,8 +104,15 @@ class AttentionBlock2D(nn.Module):
       config (unet_cfg.AttentionBlock2DConfig): the configuration of this block.
     """
     super().__init__()
-    self.norm = layers_builder.build_norm(config.dims, config.normalization_config)
-    self.attention = SelfAttention(config.dims, config.attention_config, 0, True)
+    self.config = config
+    self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
+    self.attention = SelfAttention(
+        config.attention_batch_size,
+        config.dim,
+        config.attention_config,
+        0,
+        enable_hlfb=config.enable_hlfb,
+    )
 
   def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
     """Forward function of the AttentionBlock2D.
  """Forward function of the AttentionBlock2D.
@@ -114,10 +124,17 @@ class AttentionBlock2D(nn.Module):
114
124
  output activation tensor after self attention.
115
125
  """
116
126
  residual = input_tensor
117
- x = self.norm(input_tensor)
118
- B, C, H, W = x.shape
119
- x = x.view(B, C, H * W)
120
- x = x.transpose(-1, -2)
127
+ B, C, H, W = input_tensor.shape
128
+ x = input_tensor
129
+ if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
130
+ x = self.norm(x)
131
+ x = input_tensor.view(B, C, H * W)
132
+ x = x.transpose(-1, -2)
133
+ else:
134
+ x = input_tensor.view(B, C, H * W)
135
+ x = x.transpose(-1, -2)
136
+ x = self.norm(x)
137
+ x = x.contiguous() # Prevent BATCH_MATMUL op in converted tflite.
121
138
  x = self.attention(x)
122
139
  x = x.transpose(-1, -2)
123
140
  x = x.view(B, C, H, W)
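
The reworked forward now branches on where normalization happens: group norm operates on the (B, C, H, W) layout, while other norm types apply after flattening to a (B, H*W, C) token sequence; either way the result is reshaped back for the residual add. A minimal sketch of the flatten/attend/unflatten round trip (illustrative shapes, attention elided):

    import torch

    B, C, H, W = 1, 512, 16, 16
    x = torch.randn(B, C, H, W)
    tokens = x.view(B, C, H * W).transpose(-1, -2).contiguous()  # (B, H*W, C) tokens
    # ... self attention over `tokens` happens here ...
    y = tokens.transpose(-1, -2).view(B, C, H, W)  # back to (B, C, H, W)
    assert torch.equal(x, y)  # the reshape round trip is lossless
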
@@ -125,28 +142,306 @@ class AttentionBlock2D(nn.Module):
     return x
 
 
-class UpDecoderBlock2D(nn.Module):
-  """Decoder block containing several residual blocks followed by an optional upsampler.
+class CrossAttentionBlock2D(nn.Module):
+  """2D cross attention block
 
-  input_tensor
-  |
-  ▼
-  ┌───────────────────┐
-  │  ResidualBlock2D  │ num_layers
-  └─────────┬─────────┘
+  x = CrossAttention(Norm(input_tensor), context) + x
+
+  """
+
+  def __init__(self, config: unet_cfg.CrossAttentionBlock2DConfig):
+    """Initialize an instance of the CrossAttentionBlock2D.
+
+    Args:
+      config (unet_cfg.CrossAttentionBlock2DConfig): the configuration of this block.
+    """
+    super().__init__()
+    self.config = config
+    self.norm = layers_builder.build_norm(config.query_dim, config.normalization_config)
+    self.attention = CrossAttention(
+        config.attention_batch_size,
+        config.query_dim,
+        config.cross_dim,
+        config.attention_config,
+        0,
+        enable_hlfb=config.enable_hlfb,
+    )
+
+  def forward(
+      self, input_tensor: torch.Tensor, context_tensor: torch.Tensor
+  ) -> torch.Tensor:
+    """Forward function of the CrossAttentionBlock2D.
+
+    Args:
+      input_tensor (torch.Tensor): the input tensor.
+      context_tensor (torch.Tensor): the context tensor to apply cross attention on.
+
+    Returns:
+      output activation tensor after cross attention.
+    """
+    residual = input_tensor
+    B, C, H, W = input_tensor.shape
+    x = input_tensor
+    if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+      x = self.norm(x)
+      x = input_tensor.view(B, C, H * W)
+      x = x.transpose(-1, -2)
+    else:
+      x = input_tensor.view(B, C, H * W)
+      x = x.transpose(-1, -2)
+      x = self.norm(x)
+    x = self.attention(x, context_tensor)
+    x = x.transpose(-1, -2)
+    x = x.view(B, C, H, W)
+    x = x + residual
+    return x
+
+
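
CrossAttentionBlock2D follows the same flatten/attend/unflatten pattern, but keys and values come from a separate context sequence (e.g. text-encoder embeddings), which is how the diffusion UNet is conditioned on the prompt. A rough functional sketch using torch.nn.MultiheadAttention as a stand-in for the package's CrossAttention layer (dims illustrative):

    import torch
    from torch import nn

    B, C, H, W, cross_dim = 1, 320, 64, 64, 768
    attn = nn.MultiheadAttention(C, num_heads=8, kdim=cross_dim, vdim=cross_dim,
                                 batch_first=True)
    x = torch.randn(B, C, H, W)
    context = torch.randn(B, 77, cross_dim)    # e.g. CLIP text embeddings
    q = x.view(B, C, H * W).transpose(-1, -2)  # image tokens as queries
    out, _ = attn(q, context, context)         # keys/values from the context
    out = out.transpose(-1, -2).view(B, C, H, W)
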
+class FeedForwardBlock2D(nn.Module):
+  """2D feed forward block
+
+  x = w2(Activation(w1(Norm(x)))) + x
+
+  """
+
+  def __init__(
+      self,
+      config: unet_cfg.FeedForwardBlock2DConfig,
+  ):
+    super().__init__()
+    self.config = config
+    self.act = layers_builder.get_activation(config.activation_config)
+    self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
+    if config.activation_config.type == layers_cfg.ActivationType.GE_GLU:
+      self.w1 = nn.Identity()
+      self.w2 = nn.Linear(config.hidden_dim, config.dim)
+    else:
+      self.w1 = nn.Linear(config.dim, config.hidden_dim)
+      self.w2 = nn.Linear(config.hidden_dim, config.dim)
+
+  def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+    residual = input_tensor
+    B, C, H, W = input_tensor.shape
+    x = input_tensor
+    if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+      x = self.norm(x)
+      x = input_tensor.view(B, C, H * W)
+      x = x.transpose(-1, -2)
+    else:
+      x = input_tensor.view(B, C, H * W)
+      x = x.transpose(-1, -2)
+      x = self.norm(x)
+    x = self.w1(x)
+    x = self.act(x)
+    x = self.w2(x)
+
+    x = x.transpose(-1, -2)  # (B, C, HW)
+    x = x.view((B, C, H, W))
+
+    return x + residual
+
+
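
In the GE_GLU branch w1 collapses to nn.Identity because, in the usual GEGLU formulation from the GLU-variants literature, the activation itself owns the input projection: it projects to 2 * hidden_dim and gates one half with GELU of the other, so only the output projection w2 remains a plain Linear. A sketch of that convention (assumed here, not read from the package's builder):

    import torch
    from torch import nn

    class GEGLU(nn.Module):
      """x -> Linear(dim, 2*hidden); split; a * gelu(b)."""

      def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, 2 * hidden_dim)

      def forward(self, x: torch.Tensor) -> torch.Tensor:
        a, b = self.proj(x).chunk(2, dim=-1)  # each half is hidden_dim wide
        return a * nn.functional.gelu(b)      # gate one half with GELU of the other
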
+class TransformerBlock2D(nn.Module):
+  """Basic transformer block used in UNet of diffusion model
+
+    input_tensor    context_tensor
+         |                |
+  ┌─────────▼─────────┐   |
+  │       ConvIn      │   |
+  └─────────┬─────────┘   |
+            |             |
+            ▼             |
+  ┌───────────────────┐   |
+  │  Attention Block  │   |
+  └─────────┬─────────┘   |
+            │             |
+  ┌─────────▼──────────┐  |
+  │CrossAttention Block│◄─┘
+  └─────────┬──────────┘
 
   ┌─────────▼─────────┐
-  │    (Optional)     │
-  │     Upsampler     │
+  │  FeedForwardBlock │
   └─────────┬─────────┘
 
   ┌─────────▼─────────┐
-  │    (Optional)     │
-  │      Conv2D       │
+  │      ConvOut      │
   └─────────┬─────────┘
-            ▼
 
   hidden_states
+
+
+  """
+
+  def __init__(self, config: unet_cfg.TransformerBlock2Dconfig):
+    """Initialize an instance of the TransformerBlock2D.
+
+    Args:
+      config (unet_cfg.TransformerBlock2Dconfig): the configuration of this block.
+    """
+    super().__init__()
+    self.config = config
+    self.pre_conv_norm = layers_builder.build_norm(
+        config.attention_block_config.dim, config.pre_conv_normalization_config
+    )
+    self.conv_in = nn.Conv2d(
+        config.attention_block_config.dim,
+        config.attention_block_config.dim,
+        kernel_size=1,
+        padding=0,
+    )
+    self.self_attention = AttentionBlock2D(config.attention_block_config)
+    self.cross_attention = CrossAttentionBlock2D(config.cross_attention_block_config)
+    self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
+    self.conv_out = nn.Conv2d(
+        config.attention_block_config.dim,
+        config.attention_block_config.dim,
+        kernel_size=1,
+        padding=0,
+    )
+
+  def forward(self, x: torch.Tensor, context: torch.Tensor):
+    """Forward function of the TransformerBlock2D.
+
+    Args:
+      x (torch.Tensor): the input tensor.
+      context (torch.Tensor): the context tensor to apply cross attention on.
+
+    Returns:
+      output activation tensor after transformer block.
+    """
+    residual_long = x
+
+    x = self.pre_conv_norm(x)
+    x = self.conv_in(x)
+    x = self.self_attention(x)
+    x = self.cross_attention(x, context)
+    x = self.feed_forward(x)
+
+    x = self.conv_out(x)
+    x = x + residual_long
+
+    return x
+
+
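
TransformerBlock2D is a thin residual wrapper around the three blocks above: 1x1 conv in, self attention, cross attention, feed forward, 1x1 conv out, plus one long skip from the input. The control flow, reduced to stand-in modules (a sketch only, not the package's config plumbing):

    import torch
    from torch import nn

    class TinyTransformerBlock2D(nn.Module):
      def __init__(self, dim: int):
        super().__init__()
        self.pre_norm = nn.Identity()        # stand-in for pre_conv_norm
        self.conv_in = nn.Conv2d(dim, dim, kernel_size=1)
        self.self_attn = nn.Identity()       # stand-in for AttentionBlock2D
        self.cross_attn = lambda x, ctx: x   # stand-in for CrossAttentionBlock2D
        self.ff = nn.Identity()              # stand-in for FeedForwardBlock2D
        self.conv_out = nn.Conv2d(dim, dim, kernel_size=1)

      def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        residual_long = x
        x = self.conv_in(self.pre_norm(x))
        x = self.self_attn(x)
        x = self.cross_attn(x, context)
        x = self.ff(x)
        return self.conv_out(x) + residual_long

    y = TinyTransformerBlock2D(320)(torch.randn(1, 320, 64, 64),
                                    torch.randn(1, 77, 768))
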
+class DownEncoderBlock2D(nn.Module):
+  """Encoder block containing several residual blocks with optional interleaved transformer blocks.
+
+  input_tensor
+       |
+  ┌──────────────▼─────────────┐
+  │  ┌────────────────────┐    │
+  │  │   ResidualBlock2D  │    │
+  │  └──────────┬─────────┘    │
+  │             │              │ num_layers
+  │  ┌──────────▼─────────┐    │
+  │  │     (Optional)     │    │
+  │  │ TransformerBlock2D │    │
+  │  └──────────┬─────────┘    │
+  └──────────────┬─────────────┘
+
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │     Downsampler    │
+      └──────────┬─────────┘
+
+
+  hidden_states
+  """
+
+  def __init__(self, config: unet_cfg.DownEncoderBlock2DConfig):
+    """Initialize an instance of the DownEncoderBlock2D.
+
+    Args:
+      config (unet_cfg.DownEncoderBlock2DConfig): the configuration of this block.
+    """
+    super().__init__()
+    self.config = config
+    resnets = []
+    transformers = []
+    for i in range(config.num_layers):
+      input_channels = config.in_channels if i == 0 else config.out_channels
+      resnets.append(
+          ResidualBlock2D(
+              unet_cfg.ResidualBlock2DConfig(
+                  in_channels=input_channels,
+                  out_channels=config.out_channels,
+                  time_embedding_channels=config.time_embedding_channels,
+                  normalization_config=config.normalization_config,
+                  activation_config=config.activation_config,
+              )
+          )
+      )
+      if config.transformer_block_config:
+        transformers.append(TransformerBlock2D(config.transformer_block_config))
+    self.resnets = nn.ModuleList(resnets)
+    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+    if config.add_downsample:
+      self.downsampler = unet_builder.build_downsampling(config.sampling_config)
+    else:
+      self.downsampler = None
+
+  def forward(
+      self,
+      input_tensor: torch.Tensor,
+      time_emb: Optional[torch.Tensor] = None,
+      context_tensor: Optional[torch.Tensor] = None,
+      output_hidden_states: bool = False,
+  ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]:
+    """Forward function of the DownEncoderBlock2D.
+
+    Args:
+      input_tensor (torch.Tensor): the input tensor.
+      time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
+        time embedding.
+      context_tensor (torch.Tensor): optional context tensor, if the block is configured to use
+        transformer blocks.
+      output_hidden_states (bool): whether to output hidden states, usually for skip connections.
+
+    Returns:
+      output hidden_states tensor after DownEncoderBlock2D.
+    """
+    hidden_states = input_tensor
+    output_states = []
+    for i, resnet in enumerate(self.resnets):
+      hidden_states = resnet(hidden_states, time_emb)
+      if self.transformers is not None:
+        hidden_states = self.transformers[i](hidden_states, context_tensor)
+      output_states.append(hidden_states)
+    if self.downsampler:
+      hidden_states = self.downsampler(hidden_states)
+      output_states.append(hidden_states)
+    if output_hidden_states:
+      return hidden_states, output_states
+    else:
+      return hidden_states
+
+
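
With output_hidden_states=True the encoder block also returns the activation after every resnet/transformer pair, plus the downsampled output: exactly the list a UNet stores for its decoder skip connections. The shape of that contract, reduced to plain Python (stand-in modules, illustrative only):

    import torch
    from torch import nn

    layers = nn.ModuleList([nn.Identity(), nn.Identity()])  # "resnets"
    downsampler = nn.AvgPool2d(2)

    hidden = torch.randn(1, 320, 64, 64)
    skips = []
    for layer in layers:
      hidden = layer(hidden)
      skips.append(hidden)   # one skip per layer
    hidden = downsampler(hidden)
    skips.append(hidden)     # and one after downsampling
    # `skips` later feeds SkipUpDecoderBlock2D, consumed in reverse order.
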
+class UpDecoderBlock2D(nn.Module):
+  """Decoder block containing several residual blocks with optional interleaved transformer blocks.
+
+  input_tensor
+       |
+  ┌──────────────▼─────────────┐
+  │  ┌────────────────────┐    │
+  │  │   ResidualBlock2D  │    │
+  │  └──────────┬─────────┘    │
+  │             │              │ num_layers
+  │  ┌──────────▼─────────┐    │
+  │  │     (Optional)     │    │
+  │  │ TransformerBlock2D │    │
+  │  └──────────┬─────────┘    │
+  └──────────────┬─────────────┘
+
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │      Upsampler     │
+      └──────────┬─────────┘
+
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │       Conv2D       │
+      └──────────┬─────────┘
+
+
+  hidden_states
   """
 
   def __init__(self, config: unet_cfg.UpDecoderBlock2DConfig):
@@ -158,6 +453,7 @@ class UpDecoderBlock2D(nn.Module):
     super().__init__()
     self.config = config
     resnets = []
+    transformers = []
    for i in range(config.num_layers):
       input_channels = config.in_channels if i == 0 else config.out_channels
       resnets.append(
@@ -171,7 +467,10 @@ class UpDecoderBlock2D(nn.Module):
           )
         )
       )
+      if config.transformer_block_config:
+        transformers.append(TransformerBlock2D(config.transformer_block_config))
     self.resnets = nn.ModuleList(resnets)
+    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
     if config.add_upsample:
       self.upsampler = unet_builder.build_upsampling(config.sampling_config)
       if config.upsample_conv:
@@ -182,21 +481,130 @@ class UpDecoderBlock2D(nn.Module):
       self.upsampler = None
 
   def forward(
-      self, input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None
+      self,
+      input_tensor: torch.Tensor,
+      time_emb: Optional[torch.Tensor] = None,
+      context_tensor: Optional[torch.Tensor] = None,
   ) -> torch.Tensor:
     """Forward function of the UpDecoderBlock2D.
 
     Args:
       input_tensor (torch.Tensor): the input tensor.
       time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
-        time embedding context.
+        time embedding.
+      context_tensor (torch.Tensor): optional context tensor, if the block is configured to use
+        transformer blocks.
 
     Returns:
       output hidden_states tensor after UpDecoderBlock2D.
     """
     hidden_states = input_tensor
-    for resnet in self.resnets:
+    for i, resnet in enumerate(self.resnets):
       hidden_states = resnet(hidden_states, time_emb)
+      if self.transformers is not None:
+        hidden_states = self.transformers[i](hidden_states, context_tensor)
+    if self.upsampler:
+      hidden_states = self.upsampler(hidden_states)
+      if self.upsample_conv:
+        hidden_states = self.upsample_conv(hidden_states)
+    return hidden_states
+
+
+class SkipUpDecoderBlock2D(nn.Module):
+  """Decoder block containing skip connections and residual blocks with optional interleaved transformer blocks.
+
+  input_tensor, skip_connection_tensors
+       |
+  ┌──────────────▼─────────────┐
+  │  ┌────────────────────┐    │
+  │  │   ResidualBlock2D  │    │
+  │  └──────────┬─────────┘    │
+  │             │              │ num_layers
+  │  ┌──────────▼─────────┐    │
+  │  │     (Optional)     │    │
+  │  │ TransformerBlock2D │    │
+  │  └──────────┬─────────┘    │
+  └──────────────┬─────────────┘
+
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │      Upsampler     │
+      └──────────┬─────────┘
+
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │       Conv2D       │
+      └──────────┬─────────┘
+
+
+  hidden_states
+  """
+
+  def __init__(self, config: unet_cfg.SkipUpDecoderBlock2DConfig):
+    """Initialize an instance of the SkipUpDecoderBlock2D.
+
+    Args:
+      config (unet_cfg.SkipUpDecoderBlock2DConfig): the configuration of this block.
+    """
+    super().__init__()
+    self.config = config
+    resnets = []
+    transformers = []
+    for i in range(config.num_layers):
+      res_skip_channels = (
+          config.in_channels if (i == config.num_layers - 1) else config.out_channels
+      )
+      resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels
+      resnets.append(
+          ResidualBlock2D(
+              unet_cfg.ResidualBlock2DConfig(
+                  in_channels=resnet_in_channels + res_skip_channels,
+                  out_channels=config.out_channels,
+                  time_embedding_channels=config.time_embedding_channels,
+                  normalization_config=config.normalization_config,
+                  activation_config=config.activation_config,
+              )
+          )
+      )
+      if config.transformer_block_config:
+        transformers.append(TransformerBlock2D(config.transformer_block_config))
+    self.resnets = nn.ModuleList(resnets)
+    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+    if config.add_upsample:
+      self.upsampler = unet_builder.build_upsampling(config.sampling_config)
+      if config.upsample_conv:
+        self.upsample_conv = nn.Conv2d(
+            config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+        )
+    else:
+      self.upsampler = None
+
+  def forward(
+      self,
+      input_tensor: torch.Tensor,
+      skip_connection_tensors: List[torch.Tensor],
+      time_emb: Optional[torch.Tensor] = None,
+      context_tensor: Optional[torch.Tensor] = None,
+  ) -> torch.Tensor:
+    """Forward function of the SkipUpDecoderBlock2D.
+
+    Args:
+      input_tensor (torch.Tensor): the input tensor.
+      skip_connection_tensors (List[torch.Tensor]): the skip connection tensors from encoder blocks.
+      time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
+        time embedding.
+      context_tensor (torch.Tensor): optional context tensor, if the block is configured to use
+        transformer blocks.
+
+    Returns:
+      output hidden_states tensor after SkipUpDecoderBlock2D.
+    """
+    hidden_states = input_tensor
+    for i, (resnet, skip_connection_tensor) in enumerate(
+        zip(self.resnets, skip_connection_tensors)
+    ):
+      hidden_states = torch.cat([hidden_states, skip_connection_tensor], dim=1)
+      hidden_states = resnet(hidden_states, time_emb)
+      if self.transformers is not None:
+        hidden_states = self.transformers[i](hidden_states, context_tensor)
     if self.upsampler:
       hidden_states = self.upsampler(hidden_states)
       if self.upsample_conv:
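
The channel bookkeeping in SkipUpDecoderBlock2D is the subtle part: at layer i the running hidden state is concatenated on dim=1 with one encoder skip, so the resnet's in_channels is resnet_in_channels + res_skip_channels. Worked through with assumed values prev_out_channels=1280, out_channels=640, in_channels=320, num_layers=3 (illustrative, not taken from a real config):

    i = 0: 1280 (prev block) + 640 (skip) -> resnet in_channels 1920
    i = 1:  640              + 640 (skip) -> resnet in_channels 1280
    i = 2:  640              + 320 (skip; the last layer uses config.in_channels)
                                          -> resnet in_channels  960

Every resnet still emits out_channels (640 here), so the block consumes exactly num_layers skip tensors and hands a 640-channel feature map to the next decoder stage.
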
@@ -207,25 +615,30 @@ class UpDecoderBlock2D(nn.Module):
 class MidBlock2D(nn.Module):
   """Middle block containing at least one residual block with optional interleaved attention blocks.
 
-  input_tensor
-  |
-  ▼
-  ┌───────────────────┐
-  │  ResidualBlock2D  │
-  └─────────┬─────────┘
-
-  ┌─────────────▼─────────────┐
-  │  ┌───────────────────┐    │
-  │  │    (Optional)     │    │
-  │  │ AttentionBlock2D  │    │
-  │  └─────────┬─────────┘    │ num_layers
-  │            │              │
-  │  ┌─────────▼─────────┐    │
-  │  │  ResidualBlock2D  │    │
-  │  └───────────────────┘    │
-  └─────────────┬─────────────┘
-
-
+  input_tensor
+       |
+  ┌───────────────────┐
+  │  ResidualBlock2D  │
+  └─────────┬─────────┘
+
+  ┌──────────────▼─────────────┐
+  │  ┌────────────────────┐    │
+  │  │     (Optional)     │    │
+  │  │  AttentionBlock2D  │    │
+  │  └──────────┬─────────┘    │
+  │             │              │
+  │  ┌──────────▼─────────┐    │
+  │  │     (Optional)     │    │ num_layers
+  │  │ TransformerBlock2D │    │
+  │  └──────────┬─────────┘    │
+  │             │              │
+  │  ┌──────────▼─────────┐    │
+  │  │  ResidualBlock2D   │    │
+  │  └────────────────────┘    │
+  └──────────────┬─────────────┘
+
+
   hidden_states
   """
 
@@ -249,9 +662,12 @@ class MidBlock2D(nn.Module):
         )
     ]
     attentions = []
+    transformers = []
     for i in range(config.num_layers):
       if self.config.attention_block_config:
         attentions.append(AttentionBlock2D(config.attention_block_config))
+      if self.config.transformer_block_config:
+        transformers.append(TransformerBlock2D(config.transformer_block_config))
       resnets.append(
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
@@ -264,24 +680,32 @@ class MidBlock2D(nn.Module):
               )
           )
       )
     self.resnets = nn.ModuleList(resnets)
-    self.attentions = nn.ModuleList(attentions)
+    self.attentions = nn.ModuleList(attentions) if len(attentions) > 0 else None
+    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
 
   def forward(
-      self, input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None
+      self,
+      input_tensor: torch.Tensor,
+      time_emb: Optional[torch.Tensor] = None,
+      context_tensor: Optional[torch.Tensor] = None,
   ) -> torch.Tensor:
     """Forward function of the MidBlock2D.
 
     Args:
       input_tensor (torch.Tensor): the input tensor.
       time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
-        time embedding context.
+        time embedding.
+      context_tensor (torch.Tensor): optional context tensor, if the block is configured to use
+        transformer blocks.
 
     Returns:
       output hidden_states tensor after MidBlock2D.
     """
     hidden_states = self.resnets[0](input_tensor, time_emb)
-    for attn, resnet in zip(self.attentions, self.resnets[1:]):
-      if attn is not None:
-        hidden_states = attn(hidden_states)
+    for i, resnet in enumerate(self.resnets[1:]):
+      if self.attentions is not None:
+        hidden_states = self.attentions[i](hidden_states)
+      if self.transformers is not None:
+        hidden_states = self.transformers[i](hidden_states, context_tensor)
       hidden_states = resnet(hidden_states, time_emb)
     return hidden_states
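
MidBlock2D keeps num_layers + 1 resnets: resnets[0] runs once up front, and each of the num_layers iterations then applies optional attention, optional transformer, and the next resnet, which is why the loop walks self.resnets[1:]. The off-by-one, reduced to stand-ins (illustrative only):

    import torch
    from torch import nn

    num_layers, dim = 2, 320
    resnets = nn.ModuleList(nn.Identity() for _ in range(num_layers + 1))
    attentions = nn.ModuleList(nn.Identity() for _ in range(num_layers))

    hidden = resnets[0](torch.randn(1, dim, 8, 8))  # first resnet, outside the loop
    for i in range(num_layers):
      hidden = attentions[i](hidden)   # optional per-layer attention
      hidden = resnets[i + 1](hidden)  # resnet i + 1 closes layer i
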
ai_edge_torch/generative/layers/unet/builder.py
@@ -15,15 +15,33 @@
 # Builder utils for individual components.
 
 from torch import nn
-import torch.nn.functional as F
 
 import ai_edge_torch.generative.layers.unet.model_config as unet_config
 
 
-def build_upsampling(config: unet_config.SamplingConfig):
+def build_upsampling(config: unet_config.UpSamplingConfig):
   if config.mode == unet_config.SamplingType.NEAREST:
     return nn.UpsamplingNearest2d(scale_factor=config.scale_factor)
   elif config.mode == unet_config.SamplingType.BILINEAR:
     return nn.UpsamplingBilinear2d(scale_factor=config.scale_factor)
   else:
     raise ValueError("Unsupported upsampling type.")
+
+
+def build_downsampling(config: unet_config.DownSamplingConfig):
+  if config.mode == unet_config.SamplingType.AVERAGE:
+    return nn.AvgPool2d(config.kernel_size, config.stride, padding=config.padding)
+  elif config.mode == unet_config.SamplingType.CONVOLUTION:
+    out_channels = (
+        config.in_channels if config.out_channels is None else config.out_channels
+    )
+    padding = (0, 1, 0, 1) if config.padding == 0 else config.padding
+    return nn.Conv2d(
+        config.in_channels,
+        out_channels=out_channels,
+        kernel_size=config.kernel_size,
+        stride=config.stride,
+        padding=padding,
+    )
+  else:
+    raise ValueError("Unsupported downsampling type.")
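
A quick usage sketch of the new build_downsampling. The real unet_config.DownSamplingConfig dataclass lives in model_config.py, so a SimpleNamespace with field names inferred from this diff stands in for it here (illustrative, not the package's documented API):

    import torch
    from types import SimpleNamespace

    import ai_edge_torch.generative.layers.unet.builder as unet_builder
    import ai_edge_torch.generative.layers.unet.model_config as unet_config

    cfg = SimpleNamespace(
        mode=unet_config.SamplingType.AVERAGE, kernel_size=2, stride=2, padding=0
    )
    pool = unet_builder.build_downsampling(cfg)  # nn.AvgPool2d(2, 2, padding=0)
    y = pool(torch.randn(1, 320, 64, 64))        # halves H and W to (1, 320, 32, 32)

In the CONVOLUTION branch, padding == 0 is promoted to the asymmetric tuple (0, 1, 0, 1), i.e. pad right and bottom only, mirroring TensorFlow's "SAME" convention for stride-2 convolutions.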