ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (121)
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
@@ -0,0 +1,711 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import List, Optional, Tuple
+
+ import torch
+ from torch import nn
+
+ from ai_edge_torch.generative.layers.attention import CrossAttention
+ from ai_edge_torch.generative.layers.attention import SelfAttention
+ import ai_edge_torch.generative.layers.builder as layers_builder
+ import ai_edge_torch.generative.layers.model_config as layers_cfg
+ import ai_edge_torch.generative.layers.unet.builder as unet_builder
+ import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
+
+
+ class ResidualBlock2D(nn.Module):
+   """2D residual block containing two Conv2d layers, with an optional time embedding input."""
+
+   def __init__(self, config: unet_cfg.ResidualBlock2DConfig):
+     """Initialize an instance of the ResidualBlock2D.
+
+     Args:
+       config (unet_cfg.ResidualBlock2DConfig): the configuration of this block.
+     """
+     super().__init__()
+     self.config = config
+     self.norm_1 = layers_builder.build_norm(
+         config.in_channels, config.normalization_config
+     )
+     self.conv_1 = nn.Conv2d(
+         config.in_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+     )
+     if config.time_embedding_channels is not None:
+       self.time_emb_proj = nn.Linear(
+           config.time_embedding_channels, config.out_channels
+       )
+     else:
+       self.time_emb_proj = None
+     self.norm_2 = layers_builder.build_norm(
+         config.out_channels, config.normalization_config
+     )
+     self.conv_2 = nn.Conv2d(
+         config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+     )
+     self.act_fn = layers_builder.get_activation(config.activation_config)
+     if config.in_channels == config.out_channels:
+       self.residual_layer = nn.Identity()
+     else:
+       # 1x1 convolution to match channel counts on the skip path.
+       self.residual_layer = nn.Conv2d(
+           config.in_channels, config.out_channels, kernel_size=1, stride=1, padding=0
+       )
+
+   def forward(
+       self, input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None
+   ) -> torch.Tensor:
+     """Forward function of the ResidualBlock2D.
+
+     Args:
+       input_tensor (torch.Tensor): the input tensor.
+       time_emb (Optional[torch.Tensor]): optional time embedding tensor.
+
+     Returns:
+       output hidden_states tensor after ResidualBlock2D.
+     """
+     residual = input_tensor
+     x = self.norm_1(input_tensor)
+     x = self.act_fn(x)
+     x = self.conv_1(x)
+     if self.time_emb_proj is not None:
+       time_emb = self.act_fn(time_emb)
+       # Broadcast the projected time embedding over the spatial dimensions.
+       time_emb = self.time_emb_proj(time_emb)[:, :, None, None]
+       x = x + time_emb
+     x = self.norm_2(x)
+     x = self.act_fn(x)
+     x = self.conv_2(x)
+     x = x + self.residual_layer(residual)
+     return x
+
+
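Editor's note: a minimal usage sketch of ResidualBlock2D. This is not part of the package diff; the config field names follow their usage inside this file, but the NormalizationConfig/ActivationConfig constructor calls are assumptions about ai_edge_torch.generative.layers.model_config, not confirmed API.

import torch
import ai_edge_torch.generative.layers.model_config as layers_cfg
import ai_edge_torch.generative.layers.unet.model_config as unet_cfg

# Hypothetical config values; the norm/activation constructors below are assumptions.
block = ResidualBlock2D(
    unet_cfg.ResidualBlock2DConfig(
        in_channels=128,
        out_channels=256,
        time_embedding_channels=512,
        normalization_config=layers_cfg.NormalizationConfig(
            layers_cfg.NormalizationType.GROUP_NORM
        ),
        activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
    )
)
x = torch.randn(2, 128, 32, 32)  # (B, C, H, W)
t = torch.randn(2, 512)          # per-sample time embedding
y = block(x, t)                  # -> (2, 256, 32, 32); the skip path is a 1x1 conv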
+ class AttentionBlock2D(nn.Module):
+   """2D self-attention block.
+
+   x = SelfAttention(Norm(x)) + x
+   """
+
+   def __init__(self, config: unet_cfg.AttentionBlock2DConfig):
+     """Initialize an instance of the AttentionBlock2D.
+
+     Args:
+       config (unet_cfg.AttentionBlock2DConfig): the configuration of this block.
+     """
+     super().__init__()
+     self.config = config
+     self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
+     self.attention = SelfAttention(
+         config.attention_batch_size,
+         config.dim,
+         config.attention_config,
+         0,
+         enable_hlfb=config.enable_hlfb,
+     )
+
+   def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+     """Forward function of the AttentionBlock2D.
+
+     Args:
+       input_tensor (torch.Tensor): the input tensor.
+
+     Returns:
+       output activation tensor after self attention.
+     """
+     residual = input_tensor
+     B, C, H, W = input_tensor.shape
+     x = input_tensor
+     if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+       # Group norm operates on the (B, C, H, W) layout; normalize first,
+       # then flatten the spatial grid into a sequence of H*W tokens.
+       x = self.norm(x)
+       x = x.view(B, C, H * W)
+       x = x.transpose(-1, -2)
+     else:
+       # Other norms operate on the channel-last sequence layout.
+       x = x.view(B, C, H * W)
+       x = x.transpose(-1, -2)
+       x = self.norm(x)
+     x = x.contiguous()  # Prevent BATCH_MATMUL op in converted tflite.
+     x = self.attention(x)
+     x = x.transpose(-1, -2)
+     x = x.view(B, C, H, W)
+     x = x + residual
+     return x
+
+
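Editor's note: the reshape round-trip above (feature map to token sequence and back) can be checked with plain torch ops. This standalone snippet mirrors the view/transpose logic without any package dependencies:

import torch

B, C, H, W = 1, 64, 8, 8
x = torch.randn(B, C, H, W)
# Flatten the spatial grid into a sequence of H*W tokens of width C.
seq = x.view(B, C, H * W).transpose(-1, -2)
assert seq.shape == (B, H * W, C)
# Undo it: back to the (B, C, H, W) feature-map layout.
out = seq.transpose(-1, -2).view(B, C, H, W)
assert torch.equal(out, x)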
+ class CrossAttentionBlock2D(nn.Module):
+   """2D cross-attention block.
+
+   x = CrossAttention(Norm(x), context) + x
+   """
+
+   def __init__(self, config: unet_cfg.CrossAttentionBlock2DConfig):
+     """Initialize an instance of the CrossAttentionBlock2D.
+
+     Args:
+       config (unet_cfg.CrossAttentionBlock2DConfig): the configuration of this block.
+     """
+     super().__init__()
+     self.config = config
+     self.norm = layers_builder.build_norm(config.query_dim, config.normalization_config)
+     self.attention = CrossAttention(
+         config.attention_batch_size,
+         config.query_dim,
+         config.cross_dim,
+         config.attention_config,
+         0,
+         enable_hlfb=config.enable_hlfb,
+     )
+
+   def forward(
+       self, input_tensor: torch.Tensor, context_tensor: torch.Tensor
+   ) -> torch.Tensor:
+     """Forward function of the CrossAttentionBlock2D.
+
+     Args:
+       input_tensor (torch.Tensor): the input tensor.
+       context_tensor (torch.Tensor): the context tensor to apply cross attention on.
+
+     Returns:
+       output activation tensor after cross attention.
+     """
+     residual = input_tensor
+     B, C, H, W = input_tensor.shape
+     x = input_tensor
+     if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+       x = self.norm(x)
+       x = x.view(B, C, H * W)
+       x = x.transpose(-1, -2)
+     else:
+       x = x.view(B, C, H * W)
+       x = x.transpose(-1, -2)
+       x = self.norm(x)
+     x = self.attention(x, context_tensor)
+     x = x.transpose(-1, -2)
+     x = x.view(B, C, H, W)
+     x = x + residual
+     return x
+
+
+ class FeedForwardBlock2D(nn.Module):
+   """2D feed-forward block.
+
+   x = w2(Activation(w1(Norm(x)))) + x
+   """
+
+   def __init__(
+       self,
+       config: unet_cfg.FeedForwardBlock2DConfig,
+   ):
+     super().__init__()
+     self.config = config
+     self.act = layers_builder.get_activation(config.activation_config)
+     self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
+     if config.activation_config.type == layers_cfg.ActivationType.GE_GLU:
+       # With GE_GLU the activation itself performs the dim -> hidden_dim
+       # projection, so w1 is an identity and only w2 maps back to dim.
+       self.w1 = nn.Identity()
+       self.w2 = nn.Linear(config.hidden_dim, config.dim)
+     else:
+       self.w1 = nn.Linear(config.dim, config.hidden_dim)
+       self.w2 = nn.Linear(config.hidden_dim, config.dim)
+
+   def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+     residual = input_tensor
+     B, C, H, W = input_tensor.shape
+     x = input_tensor
+     if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+       x = self.norm(x)
+       x = x.view(B, C, H * W)
+       x = x.transpose(-1, -2)
+     else:
+       x = x.view(B, C, H * W)
+       x = x.transpose(-1, -2)
+       x = self.norm(x)
+     x = self.w1(x)
+     x = self.act(x)
+     x = self.w2(x)
+
+     x = x.transpose(-1, -2)  # (B, C, H*W)
+     x = x.view((B, C, H, W))
+
+     return x + residual
+
+
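Editor's note: in the GE_GLU branch above, w1 is an identity because the activation is expected to carry its own dim -> hidden_dim projection. A common GeGLU formulation looks like the sketch below; whether layers_builder.get_activation returns exactly this is an assumption, not something this diff confirms.

import torch
from torch import nn
import torch.nn.functional as F

class GeGLUSketch(nn.Module):
  """Gated GELU sketch: project to 2*hidden_dim, gate one half with GELU of the other."""

  def __init__(self, dim: int, hidden_dim: int):
    super().__init__()
    self.proj = nn.Linear(dim, 2 * hidden_dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    value, gate = self.proj(x).chunk(2, dim=-1)
    return value * F.gelu(gate)  # output width is hidden_dim, matching w2's input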
+ class TransformerBlock2D(nn.Module):
+   """Basic transformer block used in the UNet of a diffusion model.
+
+     input_tensor        context_tensor
+          |                     |
+   ┌──────▼──────────┐         |
+   │     ConvIn      │         |
+   └──────┬──────────┘         |
+          │                    |
+   ┌──────▼──────────┐         |
+   │ Attention Block │         |
+   └──────┬──────────┘         |
+          │                    |
+   ┌──────▼─────────────┐      |
+   │CrossAttention Block│◄─────┘
+   └──────┬─────────────┘
+          │
+   ┌──────▼──────────┐
+   │FeedForwardBlock │
+   └──────┬──────────┘
+          │
+   ┌──────▼──────────┐
+   │     ConvOut     │
+   └──────┬──────────┘
+          ▼
+     hidden_states
+   """
+
+   def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
+     """Initialize an instance of the TransformerBlock2D.
+
+     Args:
+       config (unet_cfg.TransformerBlock2DConfig): the configuration of this block.
+     """
+     super().__init__()
+     self.config = config
+     self.pre_conv_norm = layers_builder.build_norm(
+         config.attention_block_config.dim, config.pre_conv_normalization_config
+     )
+     self.conv_in = nn.Conv2d(
+         config.attention_block_config.dim,
+         config.attention_block_config.dim,
+         kernel_size=1,
+         padding=0,
+     )
+     self.self_attention = AttentionBlock2D(config.attention_block_config)
+     self.cross_attention = CrossAttentionBlock2D(config.cross_attention_block_config)
+     self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
+     self.conv_out = nn.Conv2d(
+         config.attention_block_config.dim,
+         config.attention_block_config.dim,
+         kernel_size=1,
+         padding=0,
+     )
+
+   def forward(self, x: torch.Tensor, context: torch.Tensor):
+     """Forward function of the TransformerBlock2D.
+
+     Args:
+       x (torch.Tensor): the input tensor.
+       context (torch.Tensor): the context tensor to apply cross attention on.
+
+     Returns:
+       output activation tensor after transformer block.
+     """
+     residual_long = x
+
+     x = self.pre_conv_norm(x)
+     x = self.conv_in(x)
+     x = self.self_attention(x)
+     x = self.cross_attention(x, context)
+     x = self.feed_forward(x)
+
+     x = self.conv_out(x)
+     x = x + residual_long
+
+     return x
+
+
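Editor's note: illustrative tensor shapes for a TransformerBlock2D forward pass. The 77x768 context is the typical CLIP text-embedding shape in Stable Diffusion; the 320 channel count and the `transformer_block` instance are hypothetical.

import torch

# Assumes `transformer_block` is a constructed TransformerBlock2D whose
# attention dim matches the input's 320 channels (illustrative values only).
x = torch.randn(1, 320, 64, 64)    # latent feature map (B, C, H, W)
context = torch.randn(1, 77, 768)  # text-conditioning tokens (B, T, D)
y = transformer_block(x, context)  # same shape as x: (1, 320, 64, 64)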
+ class DownEncoderBlock2D(nn.Module):
+   """Encoder block containing several residual blocks with optional interleaved transformer blocks.
+
+           input_tensor
+                |
+   ┌────────────▼─────────────┐
+   │  ┌────────────────────┐  │
+   │  │  ResidualBlock2D   │  │
+   │  └──────────┬─────────┘  │
+   │             │            │  num_layers
+   │  ┌──────────▼─────────┐  │
+   │  │     (Optional)     │  │
+   │  │ TransformerBlock2D │  │
+   │  └──────────┬─────────┘  │
+   └─────────────┬────────────┘
+                 │
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │    Downsampler     │
+      └──────────┬─────────┘
+                 ▼
+           hidden_states
+   """
+
+   def __init__(self, config: unet_cfg.DownEncoderBlock2DConfig):
+     """Initialize an instance of the DownEncoderBlock2D.
+
+     Args:
+       config (unet_cfg.DownEncoderBlock2DConfig): the configuration of this block.
+     """
+     super().__init__()
+     self.config = config
+     resnets = []
+     transformers = []
+     for i in range(config.num_layers):
+       input_channels = config.in_channels if i == 0 else config.out_channels
+       resnets.append(
+           ResidualBlock2D(
+               unet_cfg.ResidualBlock2DConfig(
+                   in_channels=input_channels,
+                   out_channels=config.out_channels,
+                   time_embedding_channels=config.time_embedding_channels,
+                   normalization_config=config.normalization_config,
+                   activation_config=config.activation_config,
+               )
+           )
+       )
+       if config.transformer_block_config:
+         transformers.append(TransformerBlock2D(config.transformer_block_config))
+     self.resnets = nn.ModuleList(resnets)
+     self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+     if config.add_downsample:
+       self.downsampler = unet_builder.build_downsampling(config.sampling_config)
+     else:
+       self.downsampler = None
+
+   def forward(
+       self,
+       input_tensor: torch.Tensor,
+       time_emb: Optional[torch.Tensor] = None,
+       context_tensor: Optional[torch.Tensor] = None,
+       output_hidden_states: bool = False,
+   ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]:
+     """Forward function of the DownEncoderBlock2D.
+
+     Args:
+       input_tensor (torch.Tensor): the input tensor.
+       time_emb (Optional[torch.Tensor]): optional time embedding tensor, if the block is
+         configured to accept time embedding.
+       context_tensor (Optional[torch.Tensor]): optional context tensor, if the block is
+         configured to use a transformer block.
+       output_hidden_states (bool): whether to output hidden states, usually for skip connections.
+
+     Returns:
+       output hidden_states tensor after DownEncoderBlock2D.
+     """
+     hidden_states = input_tensor
+     output_states = []
+     for i, resnet in enumerate(self.resnets):
+       hidden_states = resnet(hidden_states, time_emb)
+       if self.transformers is not None:
+         hidden_states = self.transformers[i](hidden_states, context_tensor)
+       output_states.append(hidden_states)
+     if self.downsampler:
+       hidden_states = self.downsampler(hidden_states)
+       output_states.append(hidden_states)
+     if output_hidden_states:
+       return hidden_states, output_states
+     else:
+       return hidden_states
+
+
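Editor's note: a sketch of how the skip states from output_hidden_states=True are typically threaded into a decoder. Here `down_block` and `skip_up_block` are hypothetical, pre-constructed instances, and the reversal mirrors standard UNet wiring rather than anything this file enforces.

# Encoder side: one hidden state is collected per resnet layer (plus the
# downsampled output, when a downsampler is configured).
hidden, skips = down_block(x, time_emb=t, output_hidden_states=True)

# Decoder side: skip tensors are usually consumed innermost-first, so a
# caller typically reverses them before handing them to SkipUpDecoderBlock2D.
y = skip_up_block(hidden, skip_connection_tensors=skips[::-1], time_emb=t)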
+ class UpDecoderBlock2D(nn.Module):
+   """Decoder block containing several residual blocks with optional interleaved transformer blocks.
+
+           input_tensor
+                |
+   ┌────────────▼─────────────┐
+   │  ┌────────────────────┐  │
+   │  │  ResidualBlock2D   │  │
+   │  └──────────┬─────────┘  │
+   │             │            │  num_layers
+   │  ┌──────────▼─────────┐  │
+   │  │     (Optional)     │  │
+   │  │ TransformerBlock2D │  │
+   │  └──────────┬─────────┘  │
+   └─────────────┬────────────┘
+                 │
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │     Upsampler      │
+      └──────────┬─────────┘
+                 │
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │       Conv2D       │
+      └──────────┬─────────┘
+                 ▼
+           hidden_states
+   """
+
+   def __init__(self, config: unet_cfg.UpDecoderBlock2DConfig):
+     """Initialize an instance of the UpDecoderBlock2D.
+
+     Args:
+       config (unet_cfg.UpDecoderBlock2DConfig): the configuration of this block.
+     """
+     super().__init__()
+     self.config = config
+     resnets = []
+     transformers = []
+     for i in range(config.num_layers):
+       input_channels = config.in_channels if i == 0 else config.out_channels
+       resnets.append(
+           ResidualBlock2D(
+               unet_cfg.ResidualBlock2DConfig(
+                   in_channels=input_channels,
+                   out_channels=config.out_channels,
+                   time_embedding_channels=config.time_embedding_channels,
+                   normalization_config=config.normalization_config,
+                   activation_config=config.activation_config,
+               )
+           )
+       )
+       if config.transformer_block_config:
+         transformers.append(TransformerBlock2D(config.transformer_block_config))
+     self.resnets = nn.ModuleList(resnets)
+     self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+     if config.add_upsample:
+       self.upsampler = unet_builder.build_upsampling(config.sampling_config)
+       if config.upsample_conv:
+         self.upsample_conv = nn.Conv2d(
+             config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+         )
+       else:
+         self.upsample_conv = None
+     else:
+       # Always define both attributes so forward() can test them safely.
+       self.upsampler = None
+       self.upsample_conv = None
+
+   def forward(
+       self,
+       input_tensor: torch.Tensor,
+       time_emb: Optional[torch.Tensor] = None,
+       context_tensor: Optional[torch.Tensor] = None,
+   ) -> torch.Tensor:
+     """Forward function of the UpDecoderBlock2D.
+
+     Args:
+       input_tensor (torch.Tensor): the input tensor.
+       time_emb (Optional[torch.Tensor]): optional time embedding tensor, if the block is
+         configured to accept time embedding.
+       context_tensor (Optional[torch.Tensor]): optional context tensor, if the block is
+         configured to use a transformer block.
+
+     Returns:
+       output hidden_states tensor after UpDecoderBlock2D.
+     """
+     hidden_states = input_tensor
+     for i, resnet in enumerate(self.resnets):
+       hidden_states = resnet(hidden_states, time_emb)
+       if self.transformers is not None:
+         hidden_states = self.transformers[i](hidden_states, context_tensor)
+     if self.upsampler:
+       hidden_states = self.upsampler(hidden_states)
+     if self.upsample_conv:
+       hidden_states = self.upsample_conv(hidden_states)
+     return hidden_states
+
+
+ class SkipUpDecoderBlock2D(nn.Module):
+   """Decoder block containing skip connections and residual blocks, with optional interleaved transformer blocks.
+
+     input_tensor, skip_connection_tensors
+                |
+   ┌────────────▼─────────────┐
+   │  ┌────────────────────┐  │
+   │  │  ResidualBlock2D   │  │
+   │  └──────────┬─────────┘  │
+   │             │            │  num_layers
+   │  ┌──────────▼─────────┐  │
+   │  │     (Optional)     │  │
+   │  │ TransformerBlock2D │  │
+   │  └──────────┬─────────┘  │
+   └─────────────┬────────────┘
+                 │
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │     Upsampler      │
+      └──────────┬─────────┘
+                 │
+      ┌──────────▼─────────┐
+      │     (Optional)     │
+      │       Conv2D       │
+      └──────────┬─────────┘
+                 ▼
+           hidden_states
+   """
+
+   def __init__(self, config: unet_cfg.SkipUpDecoderBlock2DConfig):
+     """Initialize an instance of the SkipUpDecoderBlock2D.
+
+     Args:
+       config (unet_cfg.SkipUpDecoderBlock2DConfig): the configuration of this block.
+     """
+     super().__init__()
+     self.config = config
+     resnets = []
+     transformers = []
+     for i in range(config.num_layers):
+       # The innermost layer concatenates the encoder's input-width skip
+       # tensor; earlier layers concatenate out_channels-width skips.
+       res_skip_channels = (
+           config.in_channels if (i == config.num_layers - 1) else config.out_channels
+       )
+       resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels
+       resnets.append(
+           ResidualBlock2D(
+               unet_cfg.ResidualBlock2DConfig(
+                   in_channels=resnet_in_channels + res_skip_channels,
+                   out_channels=config.out_channels,
+                   time_embedding_channels=config.time_embedding_channels,
+                   normalization_config=config.normalization_config,
+                   activation_config=config.activation_config,
+               )
+           )
+       )
+       if config.transformer_block_config:
+         transformers.append(TransformerBlock2D(config.transformer_block_config))
+     self.resnets = nn.ModuleList(resnets)
+     self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+     if config.add_upsample:
+       self.upsampler = unet_builder.build_upsampling(config.sampling_config)
+       if config.upsample_conv:
+         self.upsample_conv = nn.Conv2d(
+             config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+         )
+       else:
+         self.upsample_conv = None
+     else:
+       self.upsampler = None
+       self.upsample_conv = None
+
+   def forward(
+       self,
+       input_tensor: torch.Tensor,
+       skip_connection_tensors: List[torch.Tensor],
+       time_emb: Optional[torch.Tensor] = None,
+       context_tensor: Optional[torch.Tensor] = None,
+   ) -> torch.Tensor:
+     """Forward function of the SkipUpDecoderBlock2D.
+
+     Args:
+       input_tensor (torch.Tensor): the input tensor.
+       skip_connection_tensors (List[torch.Tensor]): the skip connection tensors from encoder blocks.
+       time_emb (Optional[torch.Tensor]): optional time embedding tensor, if the block is
+         configured to accept time embedding.
+       context_tensor (Optional[torch.Tensor]): optional context tensor, if the block is
+         configured to use a transformer block.
+
+     Returns:
+       output hidden_states tensor after SkipUpDecoderBlock2D.
+     """
+     hidden_states = input_tensor
+     for i, (resnet, skip_connection_tensor) in enumerate(
+         zip(self.resnets, skip_connection_tensors)
+     ):
+       # Concatenate the encoder skip tensor along the channel dimension.
+       hidden_states = torch.cat([hidden_states, skip_connection_tensor], dim=1)
+       hidden_states = resnet(hidden_states, time_emb)
+       if self.transformers is not None:
+         hidden_states = self.transformers[i](hidden_states, context_tensor)
+     if self.upsampler:
+       hidden_states = self.upsampler(hidden_states)
+     if self.upsample_conv:
+       hidden_states = self.upsample_conv(hidden_states)
+     return hidden_states
+
+
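Editor's note: a worked example of the skip-concat channel arithmetic in __init__ above, using illustrative numbers rather than package defaults. With num_layers=3, in_channels=128, prev_out_channels=512, and out_channels=256, each resnet sees these input widths:

num_layers, in_ch, prev_out_ch, out_ch = 3, 128, 512, 256
for i in range(num_layers):
  res_skip = in_ch if i == num_layers - 1 else out_ch
  resnet_in = prev_out_ch if i == 0 else out_ch
  print(f"layer {i}: cat({resnet_in} + {res_skip}) = {resnet_in + res_skip} -> {out_ch}")
# layer 0: cat(512 + 256) = 768 -> 256
# layer 1: cat(256 + 256) = 512 -> 256
# layer 2: cat(256 + 128) = 384 -> 256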
+ class MidBlock2D(nn.Module):
+   """Middle block containing at least one residual block, with optional interleaved attention blocks.
+
+           input_tensor
+                |
+                ▼
+      ┌───────────────────┐
+      │  ResidualBlock2D  │
+      └─────────┬─────────┘
+                │
+   ┌────────────▼─────────────┐
+   │  ┌────────────────────┐  │
+   │  │     (Optional)     │  │
+   │  │  AttentionBlock2D  │  │
+   │  └──────────┬─────────┘  │
+   │             │            │
+   │  ┌──────────▼─────────┐  │
+   │  │     (Optional)     │  │  num_layers
+   │  │ TransformerBlock2D │  │
+   │  └──────────┬─────────┘  │
+   │             │            │
+   │  ┌──────────▼─────────┐  │
+   │  │  ResidualBlock2D   │  │
+   │  └────────────────────┘  │
+   └─────────────┬────────────┘
+                 ▼
+           hidden_states
+   """
+
+   def __init__(self, config: unet_cfg.MidBlock2DConfig):
+     """Initialize an instance of the MidBlock2D.
+
+     Args:
+       config (unet_cfg.MidBlock2DConfig): the configuration of this block.
+     """
+     super().__init__()
+     self.config = config
+     resnets = [
+         ResidualBlock2D(
+             unet_cfg.ResidualBlock2DConfig(
+                 in_channels=config.in_channels,
+                 out_channels=config.in_channels,
+                 time_embedding_channels=config.time_embedding_channels,
+                 normalization_config=config.normalization_config,
+                 activation_config=config.activation_config,
+             )
+         )
+     ]
+     attentions = []
+     transformers = []
+     for i in range(config.num_layers):
+       if self.config.attention_block_config:
+         attentions.append(AttentionBlock2D(config.attention_block_config))
+       if self.config.transformer_block_config:
+         transformers.append(TransformerBlock2D(config.transformer_block_config))
+       resnets.append(
+           ResidualBlock2D(
+               unet_cfg.ResidualBlock2DConfig(
+                   in_channels=config.in_channels,
+                   out_channels=config.in_channels,
+                   time_embedding_channels=config.time_embedding_channels,
+                   normalization_config=config.normalization_config,
+                   activation_config=config.activation_config,
+               )
+           )
+       )
+     self.resnets = nn.ModuleList(resnets)
+     self.attentions = nn.ModuleList(attentions) if len(attentions) > 0 else None
+     self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+
+   def forward(
+       self,
+       input_tensor: torch.Tensor,
+       time_emb: Optional[torch.Tensor] = None,
+       context_tensor: Optional[torch.Tensor] = None,
+   ) -> torch.Tensor:
+     """Forward function of the MidBlock2D.
+
+     Args:
+       input_tensor (torch.Tensor): the input tensor.
+       time_emb (Optional[torch.Tensor]): optional time embedding tensor, if the block is
+         configured to accept time embedding.
+       context_tensor (Optional[torch.Tensor]): optional context tensor, if the block is
+         configured to use a transformer block.
+
+     Returns:
+       output hidden_states tensor after MidBlock2D.
+     """
+     hidden_states = self.resnets[0](input_tensor, time_emb)
+     for i, resnet in enumerate(self.resnets[1:]):
+       if self.attentions is not None:
+         hidden_states = self.attentions[i](hidden_states)
+       if self.transformers is not None:
+         hidden_states = self.transformers[i](hidden_states, context_tensor)
+       hidden_states = resnet(hidden_states, time_emb)
+     return hidden_states
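Editor's note: finally, a minimal MidBlock2D usage sketch. The config field names follow their usage in __init__ above, while the placeholder norm_cfg/act_cfg/attn_cfg values are hypothetical and would come from the package's model_config definitions.

import torch

mid = MidBlock2D(
    unet_cfg.MidBlock2DConfig(
        in_channels=512,
        time_embedding_channels=1280,
        num_layers=1,
        normalization_config=norm_cfg,    # e.g. a group-norm config (placeholder)
        activation_config=act_cfg,        # e.g. a SiLU config (placeholder)
        attention_block_config=attn_cfg,  # or None to skip self-attention
    )
)
out = mid(torch.randn(1, 512, 8, 8), time_emb=torch.randn(1, 1280))
# out.shape == (1, 512, 8, 8): the mid block preserves the channel count.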