ai-edge-torch-nightly 0.2.0.dev20240605__py3-none-any.whl → 0.2.0.dev20240608__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of ai-edge-torch-nightly has been flagged as a potentially problematic release.
- ai_edge_torch/convert/conversion.py +2 -2
- ai_edge_torch/convert/fx_passes/__init__.py +1 -1
- ai_edge_torch/convert/fx_passes/{build_upsample_bilinear2d_composite_pass.py → build_interpolate_composite_pass.py} +22 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +8 -4
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +275 -82
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +54 -3
- ai_edge_torch/generative/layers/attention.py +25 -0
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +287 -0
- ai_edge_torch/generative/layers/unet/builder.py +29 -0
- ai_edge_torch/generative/layers/unet/model_config.py +117 -0
- ai_edge_torch/generative/test/test_model_conversion.py +90 -80
- ai_edge_torch/generative/utilities/autoencoder_loader.py +298 -0
- ai_edge_torch/generative/utilities/loader.py +7 -5
- {ai_edge_torch_nightly-0.2.0.dev20240605.dist-info → ai_edge_torch_nightly-0.2.0.dev20240608.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240605.dist-info → ai_edge_torch_nightly-0.2.0.dev20240608.dist-info}/RECORD +21 -16
- {ai_edge_torch_nightly-0.2.0.dev20240605.dist-info → ai_edge_torch_nightly-0.2.0.dev20240608.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240605.dist-info → ai_edge_torch_nightly-0.2.0.dev20240608.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240605.dist-info → ai_edge_torch_nightly-0.2.0.dev20240608.dist-info}/top_level.txt +0 -0
ai_edge_torch/convert/conversion.py

@@ -25,7 +25,7 @@ from torch_xla import stablehlo
 from ai_edge_torch import model
 from ai_edge_torch.convert import conversion_utils as cutils
 from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
-from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass  # NOQA
+from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass  # NOQA
 from ai_edge_torch.convert.fx_passes import CanonicalizePass
 from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
 from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass

@@ -41,7 +41,7 @@ def _run_convert_passes(
   return run_passes(
       exported_program,
       [
-          BuildUpsampleBilinear2DCompositePass(),
+          BuildInterpolateCompositePass(),
           CanonicalizePass(),
           OptimizeLayoutTransposesPass(),
           CanonicalizePass(),
ai_edge_torch/convert/fx_passes/__init__.py

@@ -24,7 +24,7 @@ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult
 from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
 from ai_edge_torch.convert.fx_passes._pass_base import FxPassResult
 from ai_edge_torch.convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass  # NOQA
-from ai_edge_torch.convert.fx_passes.build_upsample_bilinear2d_composite_pass import BuildUpsampleBilinear2DCompositePass  # NOQA
+from ai_edge_torch.convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass  # NOQA
 from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
 from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass  # NOQA
 from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass  # NOQA
ai_edge_torch/convert/fx_passes/{build_upsample_bilinear2d_composite_pass.py → build_interpolate_composite_pass.py}

@@ -66,13 +66,34 @@ def _get_upsample_bilinear2d_align_corners_pattern():
   return pattern
 
 
-class BuildUpsampleBilinear2DCompositePass(FxPassBase):
+@functools.cache
+def _get_interpolate_nearest2d_pattern():
+  pattern = mark_pattern.Pattern(
+      "tfl.resize_nearest_neighbor",
+      lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"),
+      export_args=(torch.rand(1, 3, 100, 100),),
+  )
+
+  @pattern.register_attr_builder
+  def attr_builder(pattern, graph_module, internal_match):
+    output = internal_match.returning_nodes[0]
+    output_h, output_w = output.meta["val"].shape[-2:]
+    return {
+        "size": (int(output_h), int(output_w)),
+        "is_nchw_op": True,
+    }
+
+  return pattern
+
+
+class BuildInterpolateCompositePass(FxPassBase):
 
   def __init__(self):
     super().__init__()
     self._patterns = [
         _get_upsample_bilinear2d_pattern(),
         _get_upsample_bilinear2d_align_corners_pattern(),
+        _get_interpolate_nearest2d_pattern(),
     ]
 
   def call(self, graph_module: torch.fx.GraphModule):
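As a quick illustration (not part of the diff): with this pass registered, a module whose forward calls nearest-mode interpolate should be exported as a single tfl.resize_nearest_neighbor composite instead of a decomposed resize graph. The sketch below uses the package's public convert/export entry points, which also appear in convert_to_tflite.py later in this diff; the Upsampler module itself is hypothetical.

import torch
import ai_edge_torch


class Upsampler(torch.nn.Module):  # hypothetical module, for illustration only

  def forward(self, x):
    # Matches _get_interpolate_nearest2d_pattern above.
    return torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")


edge_model = ai_edge_torch.convert(Upsampler().eval(), (torch.randn(1, 3, 64, 64),))
edge_model.export("/tmp/upsampler.tflite")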
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

@@ -20,10 +20,11 @@ import torch
 
 import ai_edge_torch
 import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
-from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder
+import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
 from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion  # NOQA
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
+import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader
 import ai_edge_torch.generative.utilities.loader as loading_utils
 
 

@@ -47,8 +48,11 @@ def convert_stable_diffusion_to_tflite(
   diffusion = Diffusion()
   diffusion.load_state_dict(torch.load(diffusion_ckpt_path))
 
-  decoder = Decoder()
-  decoder.load_state_dict(torch.load(decoder_ckpt_path))
+  decoder_model = decoder.Decoder(decoder.get_model_config())
+  decoder_loader = autoencoder_loader.AutoEncoderModelLoader(
+      decoder_ckpt_path, decoder.TENSORS_NAMES
+  )
+  decoder_loader.load(decoder_model)
 
   # Tensors used to trace the model graph during conversion.
   n_tokens = 77

@@ -85,7 +89,7 @@ def convert_stable_diffusion_to_tflite(
   ).convert().export('/tmp/stable_diffusion/diffusion.tflite')
 
   # Image decoder
-  ai_edge_torch.signature('decode', decoder, (input_latents,)).convert().export(
+  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
       '/tmp/stable_diffusion/decoder.tflite'
   )
 
ai_edge_torch/generative/examples/stable_diffusion/decoder.py

@@ -15,99 +15,292 @@
 
 import torch
 from torch import nn
-from torch.nn import functional as F
 
-from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention
+import ai_edge_torch.generative.layers.builder as layers_builder
+import ai_edge_torch.generative.layers.model_config as layers_cfg
+import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
+import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
+import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader
 
+TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
+    post_quant_conv="0",
+    conv_in="1",
+    mid_block_tensor_names=autoencoder_loader.MidBlockTensorNames(
+        residual_block_tensor_names=[
+            autoencoder_loader.ResidualBlockTensorNames(
+                norm_1="2.groupnorm_1",
+                norm_2="2.groupnorm_2",
+                conv_1="2.conv_1",
+                conv_2="2.conv_2",
+            ),
+            autoencoder_loader.ResidualBlockTensorNames(
+                norm_1="4.groupnorm_1",
+                norm_2="4.groupnorm_2",
+                conv_1="4.conv_1",
+                conv_2="4.conv_2",
+            ),
+        ],
+        attention_block_tensor_names=[
+            autoencoder_loader.AttnetionBlockTensorNames(
+                norm="3.groupnorm",
+                fused_qkv_proj="3.attention.in_proj",
+                output_proj="3.attention.out_proj",
+            )
+        ],
+    ),
+    up_decoder_blocks_tensor_names=[
+        autoencoder_loader.UpDecoderBlockTensorNames(
+            residual_block_tensor_names=[
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="5.groupnorm_1",
+                    norm_2="5.groupnorm_2",
+                    conv_1="5.conv_1",
+                    conv_2="5.conv_2",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="6.groupnorm_1",
+                    norm_2="6.groupnorm_2",
+                    conv_1="6.conv_1",
+                    conv_2="6.conv_2",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="7.groupnorm_1",
+                    norm_2="7.groupnorm_2",
+                    conv_1="7.conv_1",
+                    conv_2="7.conv_2",
+                ),
+            ],
+            upsample_conv="9",
+        ),
+        autoencoder_loader.UpDecoderBlockTensorNames(
+            residual_block_tensor_names=[
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="10.groupnorm_1",
+                    norm_2="10.groupnorm_2",
+                    conv_1="10.conv_1",
+                    conv_2="10.conv_2",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="11.groupnorm_1",
+                    norm_2="11.groupnorm_2",
+                    conv_1="11.conv_1",
+                    conv_2="11.conv_2",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="12.groupnorm_1",
+                    norm_2="12.groupnorm_2",
+                    conv_1="12.conv_1",
+                    conv_2="12.conv_2",
+                ),
+            ],
+            upsample_conv="14",
+        ),
+        autoencoder_loader.UpDecoderBlockTensorNames(
+            residual_block_tensor_names=[
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="15.groupnorm_1",
+                    norm_2="15.groupnorm_2",
+                    conv_1="15.conv_1",
+                    conv_2="15.conv_2",
+                    residual_layer="15.residual_layer",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="16.groupnorm_1",
+                    norm_2="16.groupnorm_2",
+                    conv_1="16.conv_1",
+                    conv_2="16.conv_2",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="17.groupnorm_1",
+                    norm_2="17.groupnorm_2",
+                    conv_1="17.conv_1",
+                    conv_2="17.conv_2",
+                ),
+            ],
+            upsample_conv="19",
+        ),
+        autoencoder_loader.UpDecoderBlockTensorNames(
+            residual_block_tensor_names=[
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="20.groupnorm_1",
+                    norm_2="20.groupnorm_2",
+                    conv_1="20.conv_1",
+                    conv_2="20.conv_2",
+                    residual_layer="20.residual_layer",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="21.groupnorm_1",
+                    norm_2="21.groupnorm_2",
+                    conv_1="21.conv_1",
+                    conv_2="21.conv_2",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="22.groupnorm_1",
+                    norm_2="22.groupnorm_2",
+                    conv_1="22.conv_1",
+                    conv_2="22.conv_2",
+                ),
+            ],
+        ),
+    ],
+    final_norm="23",
+    conv_out="25",
+)
 
-class AttentionBlock(nn.Module):
 
-  def __init__(self, channels):
-    super().__init__()
-    self.groupnorm = nn.GroupNorm(32, channels)
-    self.attention = SelfAttention(1, channels)
+class Decoder(nn.Module):
+  """The Decoder model used in Stable Diffusion.
 
-  def forward(self, x):
-    residue = x
-    x = self.groupnorm(x)
+  For details, see https://arxiv.org/abs/2103.00020
 
-    n, c, h, w = x.shape
-    x = x.view((n, c, h * w))
-    x = x.transpose(-1, -2)
-    x = self.attention(x)
-    x = x.transpose(-1, -2)
-    x = x.view((n, c, h, w))
+  Sturcture of the Decoder:
 
-    x += residue
-    return x
+        latents tensor
+              │
+              ▼
+    ┌───────────────────┐
+    │  Post Quant Conv  │
+    └─────────┬─────────┘
+              │
+    ┌─────────▼─────────┐
+    │      ConvIn       │
+    └─────────┬─────────┘
+              │
+    ┌─────────▼─────────┐
+    │     MidBlock2D    │
+    └─────────┬─────────┘
+              │
+    ┌─────────▼─────────┐
+    │    UpDecoder2D    │ x 4
+    └─────────┬─────────┘
+              │
+    ┌─────────▼─────────┐
+    │     FinalNorm     │
+    └─────────┬─────────┘
+              │
+    ┌─────────▼─────────┐
+    │     Activation    │
+    └─────────┬─────────┘
+              │
+    ┌─────────▼─────────┐
+    │      ConvOut      │
+    └─────────┬─────────┘
+              │
+              ▼
+        Output Image
+  """
+
+  def __init__(self, config: unet_cfg.AutoEncoderConfig):
+    super().__init__()
+    self.config = config
+    self.post_quant_conv = nn.Conv2d(
+        config.latent_channels,
+        config.latent_channels,
+        kernel_size=1,
+        stride=1,
+        padding=0,
+    )
+    reversed_block_out_channels = list(reversed(config.block_out_channels))
+    self.conv_in = nn.Conv2d(
+        config.latent_channels,
+        reversed_block_out_channels[0],
+        kernel_size=3,
+        stride=1,
+        padding=1,
+    )
+    self.mid_block = blocks_2d.MidBlock2D(config.mid_block_config)
+    up_decoder_blocks = []
+    block_out_channels = reversed_block_out_channels[0]
+    for i, out_channels in enumerate(reversed_block_out_channels):
+      prev_output_channel = block_out_channels
+      block_out_channels = out_channels
+      not_final_block = i < len(reversed_block_out_channels) - 1
+      up_decoder_blocks.append(
+          blocks_2d.UpDecoderBlock2D(
+              unet_cfg.UpDecoderBlock2DConfig(
+                  in_channels=prev_output_channel,
+                  out_channels=block_out_channels,
+                  normalization_config=config.normalization_config,
+                  activation_type=config.activation_type,
+                  num_layers=config.layers_per_block,
+                  add_upsample=not_final_block,
+                  upsample_conv=True,
+                  sampling_config=unet_cfg.SamplingConfig(
+                      2, unet_cfg.SamplingType.NEAREST
+                  ),
+              )
+          )
+      )
+    self.up_decoder_blocks = nn.ModuleList(up_decoder_blocks)
+    self.final_norm = layers_builder.build_norm(
+        block_out_channels, config.normalization_config
+    )
+    self.act_fn = layers_builder.get_activation(config.activation_type)
+    self.conv_out = nn.Conv2d(
+        block_out_channels,
+        config.out_channels,
+        kernel_size=3,
+        stride=1,
+        padding=1,
+    )
 
+  def forward(self, latents_tensor: torch.Tensor) -> torch.Tensor:
+    x = latents_tensor / self.config.scaling_factor
+    x = self.post_quant_conv(x)
+    x = self.conv_in(x)
+    x = self.mid_block(x)
+    for up_decoder_block in self.up_decoder_blocks:
+      x = up_decoder_block(x)
+    x = self.final_norm(x)
+    x = self.act_fn(x)
+    x = self.conv_out(x)
+    return x
 
-class ResidualBlock(nn.Module):
 
-  def __init__(self, in_channels, out_channels):
-    super().__init__()
-    self.groupnorm_1 = nn.GroupNorm(32, in_channels)
-    self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+def get_model_config() -> unet_cfg.AutoEncoderConfig:
+  """Get configs for the Decoder of Stable Diffusion v1.5"""
+  in_channels = 3
+  latent_channels = 4
+  out_channels = 3
+  block_out_channels = [128, 256, 512, 512]
+  scaling_factor = 0.18215
+  layers_per_block = 3
 
-    self.groupnorm_2 = nn.GroupNorm(32, out_channels)
-    self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+  norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+  )
 
-    if in_channels == out_channels:
-      self.residual_layer = nn.Identity()
-    else:
-      self.residual_layer = nn.Conv2d(
-          in_channels, out_channels, kernel_size=1, padding=0
-      )
+  att_config = unet_cfg.AttentionBlock2DConfig(
+      dims=block_out_channels[-1],
+      normalization_config=norm_config,
+      attention_config=layers_cfg.AttentionConfig(
+          num_heads=1,
+          num_query_groups=1,
+          qkv_use_bias=True,
+          output_proj_use_bias=True,
+          enable_kv_cache=False,
+          qkv_transpose_before_split=True,
+          rotary_percentage=0.0,
+      ),
+  )
 
-  def forward(self, x):
-    residue = x
-
-    x = self.groupnorm_1(x)
-    x = F.silu(x)
-    x = self.conv_1(x)
-
-    x = self.groupnorm_2(x)
-    x = F.silu(x)
-    x = self.conv_2(x)
-
-    return x + self.residual_layer(residue)
-
-
-class Decoder(nn.Sequential):
-
-  def __init__(self):
-    super().__init__(
-        nn.Conv2d(4, 4, kernel_size=1, padding=0),
-        nn.Conv2d(4, 512, kernel_size=3, padding=1),
-        ResidualBlock(512, 512),
-        AttentionBlock(512),
-        ResidualBlock(512, 512),
-        ResidualBlock(512, 512),
-        ResidualBlock(512, 512),
-        ResidualBlock(512, 512),
-        nn.Upsample(scale_factor=2),
-        nn.Conv2d(512, 512, kernel_size=3, padding=1),
-        ResidualBlock(512, 512),
-        ResidualBlock(512, 512),
-        ResidualBlock(512, 512),
-        nn.Upsample(scale_factor=2),
-        nn.Conv2d(512, 512, kernel_size=3, padding=1),
-        ResidualBlock(512, 256),
-        ResidualBlock(256, 256),
-        ResidualBlock(256, 256),
-        nn.Upsample(scale_factor=2),
-        nn.Conv2d(256, 256, kernel_size=3, padding=1),
-        ResidualBlock(256, 128),
-        ResidualBlock(128, 128),
-        ResidualBlock(128, 128),
-        nn.GroupNorm(32, 128),
-        nn.SiLU(),
-        nn.Conv2d(128, 3, kernel_size=3, padding=1),
-    )
+  mid_block_config = unet_cfg.MidBlock2DConfig(
+      in_channels=block_out_channels[-1],
+      normalization_config=norm_config,
+      activation_type=layers_cfg.ActivationType.SILU,
+      num_layers=1,
+      attention_block_config=att_config,
+  )
 
-  def forward(self, x):
-    x = x / 0.18215
-
-    for module in self:
-      x = module(x)
-    return x
+  config = unet_cfg.AutoEncoderConfig(
+      in_channels=in_channels,
+      latent_channels=latent_channels,
+      out_channels=out_channels,
+      activation_type=layers_cfg.ActivationType.SILU,
+      block_out_channels=block_out_channels,
+      scaling_factor=scaling_factor,
+      layers_per_block=layers_per_block,
+      normalization_config=norm_config,
+      mid_block_config=mid_block_config,
+  )
+  return config
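For orientation (not part of the diff), a minimal shape check of the new Decoder under the v1.5 config; the weights are random here since no checkpoint is loaded, and the expected output shape follows from the config: only the first three of the four UpDecoder2D blocks upsample (add_upsample is False on the last), so spatial size should grow 8x.

import torch
from ai_edge_torch.generative.examples.stable_diffusion import decoder

model = decoder.Decoder(decoder.get_model_config())  # random weights, no checkpoint
latents = torch.randn(1, 4, 64, 64)  # (batch, latent_channels, height/8, width/8)
image = model(latents)
# Three 2x upsampling stages: 64 -> 128 -> 256 -> 512.
print(image.shape)  # expected: torch.Size([1, 3, 512, 512])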
ai_edge_torch/generative/examples/stable_diffusion/encoder.py

@@ -17,9 +17,60 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from ai_edge_torch.generative.examples.stable_diffusion.decoder import AttentionBlock  # NOQA
-from ai_edge_torch.generative.examples.stable_diffusion.decoder import ResidualBlock  # NOQA
-
+from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+
+
+class AttentionBlock(nn.Module):
+
+  def __init__(self, channels):
+    super().__init__()
+    self.groupnorm = nn.GroupNorm(32, channels)
+    self.attention = SelfAttention(1, channels)
+
+  def forward(self, x):
+    residue = x
+    x = self.groupnorm(x)
+
+    n, c, h, w = x.shape
+    x = x.view((n, c, h * w))
+    x = x.transpose(-1, -2)
+    x = self.attention(x)
+    x = x.transpose(-1, -2)
+    x = x.view((n, c, h, w))
+
+    x += residue
+    return x
+
+
+class ResidualBlock(nn.Module):
+
+  def __init__(self, in_channels, out_channels):
+    super().__init__()
+    self.groupnorm_1 = nn.GroupNorm(32, in_channels)
+    self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+    self.groupnorm_2 = nn.GroupNorm(32, out_channels)
+    self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+
+    if in_channels == out_channels:
+      self.residual_layer = nn.Identity()
+    else:
+      self.residual_layer = nn.Conv2d(
+          in_channels, out_channels, kernel_size=1, padding=0
+      )
+
+  def forward(self, x):
+    residue = x
+
+    x = self.groupnorm_1(x)
+    x = F.silu(x)
+    x = self.conv_1(x)
+
+    x = self.groupnorm_2(x)
+    x = F.silu(x)
+    x = self.conv_2(x)
+
+    return x + self.residual_layer(residue)
 
 
 class Encoder(nn.Sequential):
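A small sanity sketch (not part of the diff) of the relocated ResidualBlock; because the channel counts differ, the skip connection goes through the 1x1 residual conv.

import torch
from ai_edge_torch.generative.examples.stable_diffusion.encoder import ResidualBlock

block = ResidualBlock(128, 256)  # in_channels != out_channels -> 1x1 conv skip path
out = block(torch.randn(1, 128, 32, 32))
print(out.shape)  # torch.Size([1, 256, 32, 32])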
ai_edge_torch/generative/layers/attention.py

@@ -199,3 +199,28 @@ class CausalSelfAttention(nn.Module):
     # Compute the output projection.
     y = self.output_projection(y)
     return y
+
+
+class SelfAttention(CausalSelfAttention):
+  """Non-causal Self Attention module, which is equivalent to CausalSelfAttention without mask."""
+
+  def forward(
+      self,
+      x: torch.Tensor,
+      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+      input_pos: Optional[torch.Tensor] = None,
+  ) -> torch.Tensor:
+    """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.
+
+    Args:
+      x (torch.Tensor): the input tensor.
+      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+      input_pos (torch.Tensor): the optional input position tensor.
+
+    Returns:
+      output activation from this self attention layer.
+    """
+    B, T, _ = x.size()
+    return super().forward(
+        x, rope=rope, mask=torch.zeros((B, T), dtype=torch.float32), input_pos=input_pos
+    )
|
|
|
44
44
|
)
|
|
45
45
|
elif config.type == cfg.NormalizationType.LAYER_NORM:
|
|
46
46
|
return nn.LayerNorm(dim, eps=config.epsilon)
|
|
47
|
+
elif config.type == cfg.NormalizationType.GROUP_NORM:
|
|
48
|
+
return nn.GroupNorm(config.group_num, dim, config.epsilon)
|
|
47
49
|
else:
|
|
48
50
|
raise ValueError("Unsupported norm type.")
|
|
49
51
|
|
|
@@ -69,7 +71,7 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
|
|
|
69
71
|
else:
|
|
70
72
|
raise ValueError("Unsupported feedforward type.")
|
|
71
73
|
|
|
72
|
-
activation =
|
|
74
|
+
activation = get_activation(config.activation)
|
|
73
75
|
|
|
74
76
|
return ff_module(
|
|
75
77
|
dim=dim,
|
|
@@ -79,7 +81,7 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
|
|
|
79
81
|
)
|
|
80
82
|
|
|
81
83
|
|
|
82
|
-
def
|
|
84
|
+
def get_activation(type_: cfg.ActivationType):
|
|
83
85
|
"""Get pytorch callable activation from the name.
|
|
84
86
|
|
|
85
87
|
Args:
|
|
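The new GROUP_NORM branch can be exercised directly; a hedged sketch, assuming only the NormalizationConfig fields added in model_config.py below:

import torch
import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg

norm = builder.build_norm(
    128, cfg.NormalizationConfig(cfg.NormalizationType.GROUP_NORM, group_num=32)
)
print(norm)  # GroupNorm(32, 128, eps=1e-05, affine=True)
print(norm(torch.randn(1, 128, 8, 8)).shape)  # torch.Size([1, 128, 8, 8])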
ai_edge_torch/generative/layers/model_config.py

@@ -39,6 +39,7 @@ class NormalizationType(enum.Enum):
   NONE = enum.auto()
   RMS_NORM = enum.auto()
   LAYER_NORM = enum.auto()
+  GROUP_NORM = enum.auto()
 
 
 @enum.unique

@@ -90,6 +91,8 @@ class NormalizationConfig:
   type: NormalizationType = NormalizationType.NONE
   epsilon: float = 1e-5
   zero_centered: bool = False
+  # Number of groups used in group normalization.
+  group_num: Optional[float] = None
 
 
 @dataclass
ai_edge_torch/generative/layers/unet/__init__.py (new file)

@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================