ai_edge_torch_nightly-0.2.0.dev20240714-py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- ai_edge_torch/__init__.py +31 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +400 -0
- ai_edge_torch/convert/converter.py +202 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +311 -0
- ai_edge_torch/convert/test/test_convert_composites.py +192 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +464 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/fx_passes/__init__.py +31 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +354 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +131 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +158 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
- ai_edge_torch/generative/layers/unet/builder.py +47 -0
- ai_edge_torch/generative/layers/unet/model_config.py +269 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/test/test_model_conversion.py +235 -0
- ai_edge_torch/generative/test/test_quantize.py +162 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +328 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
- ai_edge_torch/generative/utilities/t5_loader.py +483 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +142 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +81 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -0,0 +1,711 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import List, Optional, Tuple

import torch
from torch import nn

from ai_edge_torch.generative.layers.attention import CrossAttention
from ai_edge_torch.generative.layers.attention import SelfAttention
import ai_edge_torch.generative.layers.builder as layers_builder
import ai_edge_torch.generative.layers.model_config as layers_cfg
import ai_edge_torch.generative.layers.unet.builder as unet_builder
import ai_edge_torch.generative.layers.unet.model_config as unet_cfg


class ResidualBlock2D(nn.Module):
  """2D Residual block containing two Conv2D with optional time embedding as input."""

  def __init__(self, config: unet_cfg.ResidualBlock2DConfig):
    """Initialize an instance of the ResidualBlock2D.

    Args:
      config (unet_cfg.ResidualBlock2DConfig): the configuration of this block.
    """
    super().__init__()
    self.config = config
    self.norm_1 = layers_builder.build_norm(
        config.in_channels, config.normalization_config
    )
    self.conv_1 = nn.Conv2d(
        config.in_channels, config.out_channels, kernel_size=3, stride=1, padding=1
    )
    if config.time_embedding_channels is not None:
      self.time_emb_proj = nn.Linear(
          config.time_embedding_channels, config.out_channels
      )
    else:
      self.time_emb_proj = None
    self.norm_2 = layers_builder.build_norm(
        config.out_channels, config.normalization_config
    )
    self.conv_2 = nn.Conv2d(
        config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
    )
    self.act_fn = layers_builder.get_activation(config.activation_config)
    if config.in_channels == config.out_channels:
      self.residual_layer = nn.Identity()
    else:
      self.residual_layer = nn.Conv2d(
          config.in_channels, config.out_channels, kernel_size=1, stride=1, padding=0
      )

  def forward(
      self, input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None
  ) -> torch.Tensor:
    """Forward function of the ResidualBlock2D.

    Args:
      input_tensor (torch.Tensor): the input tensor.
      time_emb (Optional[torch.Tensor]): optional time embedding tensor.

    Returns:
      output hidden_states tensor after ResidualBlock2D.
    """
    residual = input_tensor
    x = self.norm_1(input_tensor)
    x = self.act_fn(x)
    x = self.conv_1(x)
    if self.time_emb_proj is not None:
      time_emb = self.act_fn(time_emb)
      time_emb = self.time_emb_proj(time_emb)[:, :, None, None]
      x = x + time_emb
    x = self.norm_2(x)
    x = self.act_fn(x)
    x = self.conv_2(x)
    x = x + self.residual_layer(residual)
    return x


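# NOTE (illustrative sketch, not part of this file): time_emb enters the block
# above as a (B, out_channels) vector and is broadcast over space via the
# [:, :, None, None] indexing. A torch-only check of that broadcast:
#
#   import torch
#
#   emb = torch.randn(2, 128)
#   print(emb[:, :, None, None].shape)  # torch.Size([2, 128, 1, 1]), which
#                                       # broadcasts against (2, 128, H, W).
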
class AttentionBlock2D(nn.Module):
  """2D self attention block.

  x = SelfAttention(Norm(input_tensor)) + x
  """

  def __init__(self, config: unet_cfg.AttentionBlock2DConfig):
    """Initialize an instance of the AttentionBlock2D.

    Args:
      config (unet_cfg.AttentionBlock2DConfig): the configuration of this block.
    """
    super().__init__()
    self.config = config
    self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
    self.attention = SelfAttention(
        config.attention_batch_size,
        config.dim,
        config.attention_config,
        0,
        enable_hlfb=config.enable_hlfb,
    )

  def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
    """Forward function of the AttentionBlock2D.

    Args:
      input_tensor (torch.Tensor): the input tensor.

    Returns:
      output activation tensor after self attention.
    """
    residual = input_tensor
    B, C, H, W = input_tensor.shape
    x = input_tensor
    if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
      x = self.norm(x)
      x = x.view(B, C, H * W)
      x = x.transpose(-1, -2)
    else:
      x = x.view(B, C, H * W)
      x = x.transpose(-1, -2)
      x = self.norm(x)
    x = x.contiguous()  # Prevent BATCH_MATMUL op in converted tflite.
    x = self.attention(x)
    x = x.transpose(-1, -2)
    x = x.view(B, C, H, W)
    x = x + residual
    return x


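# NOTE (illustrative sketch, not part of this file): AttentionBlock2D flattens
# (B, C, H, W) features into (B, H*W, C) sequences because the attention layers
# expect (batch, sequence, dim) inputs, then restores the image layout after
# attention. A torch-only sanity check of that round trip:
#
#   import torch
#
#   B, C, H, W = 1, 8, 4, 4
#   t = torch.randn(B, C, H, W)
#   seq = t.view(B, C, H * W).transpose(-1, -2)    # (B, H*W, C), one row per pixel
#   back = seq.transpose(-1, -2).view(B, C, H, W)  # undo the flattening
#   assert torch.equal(t, back)
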
class CrossAttentionBlock2D(nn.Module):
  """2D cross attention block.

  x = CrossAttention(Norm(input_tensor), context) + x
  """

  def __init__(self, config: unet_cfg.CrossAttentionBlock2DConfig):
    """Initialize an instance of the CrossAttentionBlock2D.

    Args:
      config (unet_cfg.CrossAttentionBlock2DConfig): the configuration of this block.
    """
    super().__init__()
    self.config = config
    self.norm = layers_builder.build_norm(config.query_dim, config.normalization_config)
    self.attention = CrossAttention(
        config.attention_batch_size,
        config.query_dim,
        config.cross_dim,
        config.attention_config,
        0,
        enable_hlfb=config.enable_hlfb,
    )

  def forward(
      self, input_tensor: torch.Tensor, context_tensor: torch.Tensor
  ) -> torch.Tensor:
    """Forward function of the CrossAttentionBlock2D.

    Args:
      input_tensor (torch.Tensor): the input tensor.
      context_tensor (torch.Tensor): the context tensor to apply cross attention on.

    Returns:
      output activation tensor after cross attention.
    """
    residual = input_tensor
    B, C, H, W = input_tensor.shape
    x = input_tensor
    if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
      x = self.norm(x)
      x = x.view(B, C, H * W)
      x = x.transpose(-1, -2)
    else:
      x = x.view(B, C, H * W)
      x = x.transpose(-1, -2)
      x = self.norm(x)
    x = self.attention(x, context_tensor)
    x = x.transpose(-1, -2)
    x = x.view(B, C, H, W)
    x = x + residual
    return x


class FeedForwardBlock2D(nn.Module):
  """2D feed forward block.

  x = w2(Activation(w1(Norm(x)))) + x
  """

  def __init__(
      self,
      config: unet_cfg.FeedForwardBlock2DConfig,
  ):
    super().__init__()
    self.config = config
    self.act = layers_builder.get_activation(config.activation_config)
    self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
    if config.activation_config.type == layers_cfg.ActivationType.GE_GLU:
      self.w1 = nn.Identity()
      self.w2 = nn.Linear(config.hidden_dim, config.dim)
    else:
      self.w1 = nn.Linear(config.dim, config.hidden_dim)
      self.w2 = nn.Linear(config.hidden_dim, config.dim)

  def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
    residual = input_tensor
    B, C, H, W = input_tensor.shape
    x = input_tensor
    if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
      x = self.norm(x)
      x = x.view(B, C, H * W)
      x = x.transpose(-1, -2)
    else:
      x = x.view(B, C, H * W)
      x = x.transpose(-1, -2)
      x = self.norm(x)
    x = self.w1(x)
    x = self.act(x)
    x = self.w2(x)

    x = x.transpose(-1, -2)  # (B, C, HW)
    x = x.view((B, C, H, W))

    return x + residual


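# NOTE (illustrative, not part of this file): in the GE_GLU branch above, w1 is
# Identity while w2 still maps hidden_dim -> dim. For the shapes to line up, the
# activation returned by get_activation must itself project dim -> hidden_dim,
# i.e. the GE_GLU activation presumably carries its own weight, so only one
# explicit linear layer is needed:
#
#   x: (B, HW, dim) --w1=Identity--> (B, HW, dim) --GE_GLU--> (B, HW, hidden_dim)
#                                                --w2------->  (B, HW, dim)
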
class TransformerBlock2D(nn.Module):
  """Basic transformer block used in the UNet of a diffusion model.

        input_tensor  context_tensor
             │               │
   ┌─────────▼─────────┐     │
   │       ConvIn      │     │
   └─────────┬─────────┘     │
             │               │
             ▼               │
   ┌───────────────────┐     │
   │  Attention Block  │     │
   └─────────┬─────────┘     │
             │               │
  ┌──────────▼─────────┐     │
  │CrossAttention Block│◄────┘
  └──────────┬─────────┘
             │
   ┌─────────▼─────────┐
   │  FeedForwardBlock │
   └─────────┬─────────┘
             │
   ┌─────────▼─────────┐
   │      ConvOut      │
   └─────────┬─────────┘
             ▼
       hidden_states
  """

  def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
    """Initialize an instance of the TransformerBlock2D.

    Args:
      config (unet_cfg.TransformerBlock2DConfig): the configuration of this block.
    """
    super().__init__()
    self.config = config
    self.pre_conv_norm = layers_builder.build_norm(
        config.attention_block_config.dim, config.pre_conv_normalization_config
    )
    self.conv_in = nn.Conv2d(
        config.attention_block_config.dim,
        config.attention_block_config.dim,
        kernel_size=1,
        padding=0,
    )
    self.self_attention = AttentionBlock2D(config.attention_block_config)
    self.cross_attention = CrossAttentionBlock2D(config.cross_attention_block_config)
    self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
    self.conv_out = nn.Conv2d(
        config.attention_block_config.dim,
        config.attention_block_config.dim,
        kernel_size=1,
        padding=0,
    )

  def forward(self, x: torch.Tensor, context: torch.Tensor):
    """Forward function of the TransformerBlock2D.

    Args:
      x (torch.Tensor): the input tensor.
      context (torch.Tensor): the context tensor to apply cross attention on.

    Returns:
      output activation tensor after transformer block.
    """
    residual_long = x

    x = self.pre_conv_norm(x)
    x = self.conv_in(x)
    x = self.self_attention(x)
    x = self.cross_attention(x, context)
    x = self.feed_forward(x)

    x = self.conv_out(x)
    x = x + residual_long

    return x


class DownEncoderBlock2D(nn.Module):
  """Encoder block containing several residual blocks with optional interleaved transformer blocks.

           input_tensor
                 │
  ┌──────────────▼─────────────┐
  │   ┌────────────────────┐   │
  │   │   ResidualBlock2D  │   │
  │   └──────────┬─────────┘   │
  │              │             │  num_layers
  │   ┌──────────▼─────────┐   │
  │   │     (Optional)     │   │
  │   │ TransformerBlock2D │   │
  │   └──────────┬─────────┘   │
  └──────────────┬─────────────┘
                 │
      ┌──────────▼─────────┐
      │     (Optional)     │
      │     Downsampler    │
      └──────────┬─────────┘
                 │
                 ▼
           hidden_states
  """

  def __init__(self, config: unet_cfg.DownEncoderBlock2DConfig):
    """Initialize an instance of the DownEncoderBlock2D.

    Args:
      config (unet_cfg.DownEncoderBlock2DConfig): the configuration of this block.
    """
    super().__init__()
    self.config = config
    resnets = []
    transformers = []
    for i in range(config.num_layers):
      input_channels = config.in_channels if i == 0 else config.out_channels
      resnets.append(
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=input_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
                  activation_config=config.activation_config,
              )
          )
      )
      if config.transformer_block_config:
        transformers.append(TransformerBlock2D(config.transformer_block_config))
    self.resnets = nn.ModuleList(resnets)
    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
    if config.add_downsample:
      self.downsampler = unet_builder.build_downsampling(config.sampling_config)
    else:
      self.downsampler = None

  def forward(
      self,
      input_tensor: torch.Tensor,
      time_emb: Optional[torch.Tensor] = None,
      context_tensor: Optional[torch.Tensor] = None,
      output_hidden_states: bool = False,
  ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]:
    """Forward function of the DownEncoderBlock2D.

    Args:
      input_tensor (torch.Tensor): the input tensor.
      time_emb (torch.Tensor): optional time embedding tensor, if the block is
        configured to accept time embedding.
      context_tensor (torch.Tensor): optional context tensor, if the block is
        configured to use a transformer block.
      output_hidden_states (bool): whether to output hidden states, usually for
        skip connections.

    Returns:
      output hidden_states tensor after DownEncoderBlock2D.
    """
    hidden_states = input_tensor
    output_states = []
    for i, resnet in enumerate(self.resnets):
      hidden_states = resnet(hidden_states, time_emb)
      if self.transformers is not None:
        hidden_states = self.transformers[i](hidden_states, context_tensor)
      output_states.append(hidden_states)
    if self.downsampler:
      hidden_states = self.downsampler(hidden_states)
      output_states.append(hidden_states)
    if output_hidden_states:
      return hidden_states, output_states
    else:
      return hidden_states


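# NOTE (illustrative, not part of this file): with output_hidden_states=True the
# encoder block above returns every intermediate state so a decoder block can
# later consume them as skip connections. A hypothetical call:
#
#   h, skips = down_block(x, time_emb, output_hidden_states=True)
#   # len(skips) == num_layers, plus one more entry when add_downsample is set;
#   # each entry is one skip tensor for a SkipUpDecoderBlock2D (defined below).
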
class UpDecoderBlock2D(nn.Module):
  """Decoder block containing several residual blocks with optional interleaved transformer blocks.

           input_tensor
                 │
  ┌──────────────▼─────────────┐
  │   ┌────────────────────┐   │
  │   │   ResidualBlock2D  │   │
  │   └──────────┬─────────┘   │
  │              │             │  num_layers
  │   ┌──────────▼─────────┐   │
  │   │     (Optional)     │   │
  │   │ TransformerBlock2D │   │
  │   └──────────┬─────────┘   │
  └──────────────┬─────────────┘
                 │
      ┌──────────▼─────────┐
      │     (Optional)     │
      │      Upsampler     │
      └──────────┬─────────┘
                 │
      ┌──────────▼─────────┐
      │     (Optional)     │
      │       Conv2D       │
      └──────────┬─────────┘
                 │
                 ▼
           hidden_states
  """

  def __init__(self, config: unet_cfg.UpDecoderBlock2DConfig):
    """Initialize an instance of the UpDecoderBlock2D.

    Args:
      config (unet_cfg.UpDecoderBlock2DConfig): the configuration of this block.
    """
    super().__init__()
    self.config = config
    resnets = []
    transformers = []
    for i in range(config.num_layers):
      input_channels = config.in_channels if i == 0 else config.out_channels
      resnets.append(
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=input_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
                  activation_config=config.activation_config,
              )
          )
      )
      if config.transformer_block_config:
        transformers.append(TransformerBlock2D(config.transformer_block_config))
    self.resnets = nn.ModuleList(resnets)
    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
    if config.add_upsample:
      self.upsampler = unet_builder.build_upsampling(config.sampling_config)
      if config.upsample_conv:
        self.upsample_conv = nn.Conv2d(
            config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
        )
    else:
      self.upsampler = None

  def forward(
      self,
      input_tensor: torch.Tensor,
      time_emb: Optional[torch.Tensor] = None,
      context_tensor: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
    """Forward function of the UpDecoderBlock2D.

    Args:
      input_tensor (torch.Tensor): the input tensor.
      time_emb (torch.Tensor): optional time embedding tensor, if the block is
        configured to accept time embedding.
      context_tensor (torch.Tensor): optional context tensor, if the block is
        configured to use a transformer block.

    Returns:
      output hidden_states tensor after UpDecoderBlock2D.
    """
    hidden_states = input_tensor
    for i, resnet in enumerate(self.resnets):
      hidden_states = resnet(hidden_states, time_emb)
      if self.transformers is not None:
        hidden_states = self.transformers[i](hidden_states, context_tensor)
    if self.upsampler:
      hidden_states = self.upsampler(hidden_states)
      if self.upsample_conv:
        hidden_states = self.upsample_conv(hidden_states)
    return hidden_states


class SkipUpDecoderBlock2D(nn.Module):
  """Decoder block containing skip connections and residual blocks with optional interleaved transformer blocks.

    input_tensor, skip_connection_tensors
                 │
  ┌──────────────▼─────────────┐
  │   ┌────────────────────┐   │
  │   │   ResidualBlock2D  │   │
  │   └──────────┬─────────┘   │
  │              │             │  num_layers
  │   ┌──────────▼─────────┐   │
  │   │     (Optional)     │   │
  │   │ TransformerBlock2D │   │
  │   └──────────┬─────────┘   │
  └──────────────┬─────────────┘
                 │
      ┌──────────▼─────────┐
      │     (Optional)     │
      │      Upsampler     │
      └──────────┬─────────┘
                 │
      ┌──────────▼─────────┐
      │     (Optional)     │
      │       Conv2D       │
      └──────────┬─────────┘
                 │
                 ▼
           hidden_states
  """

  def __init__(self, config: unet_cfg.SkipUpDecoderBlock2DConfig):
    """Initialize an instance of the SkipUpDecoderBlock2D.

    Args:
      config (unet_cfg.SkipUpDecoderBlock2DConfig): the configuration of this block.
    """
    super().__init__()
    self.config = config
    resnets = []
    transformers = []
    for i in range(config.num_layers):
      res_skip_channels = (
          config.in_channels if (i == config.num_layers - 1) else config.out_channels
      )
      resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels
      resnets.append(
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=resnet_in_channels + res_skip_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
                  activation_config=config.activation_config,
              )
          )
      )
      if config.transformer_block_config:
        transformers.append(TransformerBlock2D(config.transformer_block_config))
    self.resnets = nn.ModuleList(resnets)
    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
    if config.add_upsample:
      self.upsampler = unet_builder.build_upsampling(config.sampling_config)
      if config.upsample_conv:
        self.upsample_conv = nn.Conv2d(
            config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
        )
    else:
      self.upsampler = None

  def forward(
      self,
      input_tensor: torch.Tensor,
      skip_connection_tensors: List[torch.Tensor],
      time_emb: Optional[torch.Tensor] = None,
      context_tensor: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
    """Forward function of the SkipUpDecoderBlock2D.

    Args:
      input_tensor (torch.Tensor): the input tensor.
      skip_connection_tensors (List[torch.Tensor]): the skip connection tensors
        from encoder blocks.
      time_emb (torch.Tensor): optional time embedding tensor, if the block is
        configured to accept time embedding.
      context_tensor (torch.Tensor): optional context tensor, if the block is
        configured to use a transformer block.

    Returns:
      output hidden_states tensor after SkipUpDecoderBlock2D.
    """
    hidden_states = input_tensor
    for i, (resnet, skip_connection_tensor) in enumerate(
        zip(self.resnets, skip_connection_tensors)
    ):
      hidden_states = torch.cat([hidden_states, skip_connection_tensor], dim=1)
      hidden_states = resnet(hidden_states, time_emb)
      if self.transformers is not None:
        hidden_states = self.transformers[i](hidden_states, context_tensor)
    if self.upsampler:
      hidden_states = self.upsampler(hidden_states)
      if self.upsample_conv:
        hidden_states = self.upsample_conv(hidden_states)
    return hidden_states


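# NOTE (illustrative, not part of this file): the channel arithmetic in
# SkipUpDecoderBlock2D above is the subtle part. Layer i concatenates one skip
# tensor onto the running hidden state before its resnet, so with hypothetical
# values num_layers=3, in_channels=320, out_channels=640, prev_out_channels=1280
# the per-layer resnet input channels work out to:
#
#   i=0: resnet_in = prev_out_channels = 1280, skip = out_channels = 640 -> 1920
#   i=1: resnet_in = out_channels      =  640, skip = out_channels = 640 -> 1280
#   i=2: resnet_in = out_channels      =  640, skip = in_channels  = 320 ->  960
#
#   for i in range(3):
#     skip = 320 if i == 2 else 640
#     res_in = 1280 if i == 0 else 640
#     print(i, res_in + skip)  # 1920, 1280, 960
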
class MidBlock2D(nn.Module):
  """Middle block containing at least one residual block, with optional interleaved attention and transformer blocks.

           input_tensor
                 │
                 ▼
       ┌───────────────────┐
       │  ResidualBlock2D  │
       └─────────┬─────────┘
                 │
  ┌──────────────▼─────────────┐
  │   ┌────────────────────┐   │
  │   │     (Optional)     │   │
  │   │  AttentionBlock2D  │   │
  │   └──────────┬─────────┘   │
  │              │             │
  │   ┌──────────▼─────────┐   │
  │   │     (Optional)     │   │  num_layers
  │   │ TransformerBlock2D │   │
  │   └──────────┬─────────┘   │
  │              │             │
  │   ┌──────────▼─────────┐   │
  │   │   ResidualBlock2D  │   │
  │   └────────────────────┘   │
  └──────────────┬─────────────┘
                 │
                 ▼
           hidden_states
  """

  def __init__(self, config: unet_cfg.MidBlock2DConfig):
    """Initialize an instance of the MidBlock2D.

    Args:
      config (unet_cfg.MidBlock2DConfig): the configuration of this block.
    """
    super().__init__()
    self.config = config
    resnets = [
        ResidualBlock2D(
            unet_cfg.ResidualBlock2DConfig(
                in_channels=config.in_channels,
                out_channels=config.in_channels,
                time_embedding_channels=config.time_embedding_channels,
                normalization_config=config.normalization_config,
                activation_config=config.activation_config,
            )
        )
    ]
    attentions = []
    transformers = []
    for i in range(config.num_layers):
      if self.config.attention_block_config:
        attentions.append(AttentionBlock2D(config.attention_block_config))
      if self.config.transformer_block_config:
        transformers.append(TransformerBlock2D(config.transformer_block_config))
      resnets.append(
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=config.in_channels,
                  out_channels=config.in_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
                  activation_config=config.activation_config,
              )
          )
      )
    self.resnets = nn.ModuleList(resnets)
    self.attentions = nn.ModuleList(attentions) if len(attentions) > 0 else None
    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None

  def forward(
      self,
      input_tensor: torch.Tensor,
      time_emb: Optional[torch.Tensor] = None,
      context_tensor: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
    """Forward function of the MidBlock2D.

    Args:
      input_tensor (torch.Tensor): the input tensor.
      time_emb (torch.Tensor): optional time embedding tensor, if the block is
        configured to accept time embedding.
      context_tensor (torch.Tensor): optional context tensor, if the block is
        configured to use a transformer block.

    Returns:
      output hidden_states tensor after MidBlock2D.
    """
    hidden_states = self.resnets[0](input_tensor, time_emb)
    for i, resnet in enumerate(self.resnets[1:]):
      if self.attentions is not None:
        hidden_states = self.attentions[i](hidden_states)
      if self.transformers is not None:
        hidden_states = self.transformers[i](hidden_states, context_tensor)
      hidden_states = resnet(hidden_states, time_emb)
    return hidden_states
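
For orientation, a minimal usage sketch of the ResidualBlock2D defined in this file (not part of the package diff). The ResidualBlock2DConfig fields match the constructor calls above; the NormalizationConfig/ActivationConfig arguments and the GROUP_NORM group_num and SILU choices are assumptions inferred from how this file reads those configs (config.normalization_config.type, config.activation_config.type), so exact field names may differ in the released package.

import torch

import ai_edge_torch.generative.layers.model_config as layers_cfg
import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
from ai_edge_torch.generative.layers.unet.blocks_2d import ResidualBlock2D

# Hypothetical configuration; the sub-config constructors are assumptions.
config = unet_cfg.ResidualBlock2DConfig(
    in_channels=64,
    out_channels=128,
    time_embedding_channels=320,
    normalization_config=layers_cfg.NormalizationConfig(
        layers_cfg.NormalizationType.GROUP_NORM, group_num=32
    ),
    activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
)
block = ResidualBlock2D(config)

x = torch.randn(1, 64, 32, 32)  # (B, C, H, W) feature map
t = torch.randn(1, 320)         # time embedding; projected inside the block
y = block(x, t)
print(y.shape)                  # expected: torch.Size([1, 128, 32, 32])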