ai-edge-torch-nightly 0.2.0.dev20240606__py3-none-any.whl → 0.2.0.dev20240609__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly has been flagged as potentially problematic.

Files changed (20)
  1. ai_edge_torch/convert/conversion.py +2 -2
  2. ai_edge_torch/convert/fx_passes/__init__.py +1 -1
  3. ai_edge_torch/convert/fx_passes/{build_upsample_bilinear2d_composite_pass.py → build_interpolate_composite_pass.py} +22 -1
  4. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +8 -4
  5. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +275 -82
  6. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +54 -3
  7. ai_edge_torch/generative/layers/attention.py +25 -0
  8. ai_edge_torch/generative/layers/builder.py +4 -2
  9. ai_edge_torch/generative/layers/model_config.py +3 -0
  10. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  11. ai_edge_torch/generative/layers/unet/blocks_2d.py +287 -0
  12. ai_edge_torch/generative/layers/unet/builder.py +29 -0
  13. ai_edge_torch/generative/layers/unet/model_config.py +117 -0
  14. ai_edge_torch/generative/utilities/autoencoder_loader.py +298 -0
  15. ai_edge_torch/generative/utilities/loader.py +7 -5
  16. {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/METADATA +1 -1
  17. {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/RECORD +20 -15
  18. {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/LICENSE +0 -0
  19. {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/WHEEL +0 -0
  20. {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/top_level.txt +0 -0
ai_edge_torch/convert/conversion.py

@@ -25,7 +25,7 @@ from torch_xla import stablehlo
 from ai_edge_torch import model
 from ai_edge_torch.convert import conversion_utils as cutils
 from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
-from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass  # NOQA
+from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass  # NOQA
 from ai_edge_torch.convert.fx_passes import CanonicalizePass
 from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
 from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
@@ -41,7 +41,7 @@ def _run_convert_passes(
   return run_passes(
       exported_program,
       [
-          BuildUpsampleBilinear2DCompositePass(),
+          BuildInterpolateCompositePass(),
           CanonicalizePass(),
           OptimizeLayoutTransposesPass(),
           CanonicalizePass(),
ai_edge_torch/convert/fx_passes/__init__.py

@@ -24,7 +24,7 @@ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult
 from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
 from ai_edge_torch.convert.fx_passes._pass_base import FxPassResult
 from ai_edge_torch.convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass  # NOQA
-from ai_edge_torch.convert.fx_passes.build_upsample_bilinear2d_composite_pass import BuildUpsampleBilinear2DCompositePass  # NOQA
+from ai_edge_torch.convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass  # NOQA
 from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
 from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass  # NOQA
 from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass  # NOQA
ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py (renamed from build_upsample_bilinear2d_composite_pass.py)

@@ -66,13 +66,34 @@ def _get_upsample_bilinear2d_align_corners_pattern():
   return pattern


-class BuildUpsampleBilinear2DCompositePass(FxPassBase):
+@functools.cache
+def _get_interpolate_nearest2d_pattern():
+  pattern = mark_pattern.Pattern(
+      "tfl.resize_nearest_neighbor",
+      lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"),
+      export_args=(torch.rand(1, 3, 100, 100),),
+  )
+
+  @pattern.register_attr_builder
+  def attr_builder(pattern, graph_module, internal_match):
+    output = internal_match.returning_nodes[0]
+    output_h, output_w = output.meta["val"].shape[-2:]
+    return {
+        "size": (int(output_h), int(output_w)),
+        "is_nchw_op": True,
+    }
+
+  return pattern
+
+
+class BuildInterpolateCompositePass(FxPassBase):

   def __init__(self):
     super().__init__()
     self._patterns = [
         _get_upsample_bilinear2d_pattern(),
         _get_upsample_bilinear2d_align_corners_pattern(),
+        _get_interpolate_nearest2d_pattern(),
     ]

   def call(self, graph_module: torch.fx.GraphModule):
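Note: with this pass registered, a 2x nearest-neighbor interpolate call should be captured as a single tfl.resize_nearest_neighbor composite during conversion. A minimal sketch, assuming the public ai_edge_torch.convert API; the module and shapes below are illustrative, not from the package:

import torch
import ai_edge_torch

class Upscaler(torch.nn.Module):
  def forward(self, x):
    # Matches the new pattern: scale_factor=2, mode="nearest".
    return torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")

# The pass should mark this call as a tfl.resize_nearest_neighbor composite
# instead of letting it decompose into lower-level ops.
edge_model = ai_edge_torch.convert(Upscaler().eval(), (torch.rand(1, 3, 100, 100),))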
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

@@ -20,10 +20,11 @@ import torch

 import ai_edge_torch
 import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
-from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder
+import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
 from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion  # NOQA
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
+import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader
 import ai_edge_torch.generative.utilities.loader as loading_utils


@@ -47,8 +48,11 @@ def convert_stable_diffusion_to_tflite(
   diffusion = Diffusion()
   diffusion.load_state_dict(torch.load(diffusion_ckpt_path))

-  decoder = Decoder()
-  decoder.load_state_dict(torch.load(decoder_ckpt_path))
+  decoder_model = decoder.Decoder(decoder.get_model_config())
+  decoder_loader = autoencoder_loader.AutoEncoderModelLoader(
+      decoder_ckpt_path, decoder.TENSORS_NAMES
+  )
+  decoder_loader.load(decoder_model)

   # Tensors used to trace the model graph during conversion.
   n_tokens = 77
@@ -85,7 +89,7 @@ def convert_stable_diffusion_to_tflite(
   ).convert().export('/tmp/stable_diffusion/diffusion.tflite')

   # Image decoder
-  ai_edge_torch.signature('decode', decoder, (input_latents,)).convert().export(
+  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
       '/tmp/stable_diffusion/decoder.tflite'
   )
ai_edge_torch/generative/examples/stable_diffusion/decoder.py

@@ -15,99 +15,292 @@

 import torch
 from torch import nn
-from torch.nn import functional as F

-from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+import ai_edge_torch.generative.layers.builder as layers_builder
+import ai_edge_torch.generative.layers.model_config as layers_cfg
+import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
+import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
+import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader


-class AttentionBlock(nn.Module):
-
-  def __init__(self, channels):
-    super().__init__()
-    self.groupnorm = nn.GroupNorm(32, channels)
-    self.attention = SelfAttention(1, channels)
-
-  def forward(self, x):
-    residue = x
-    x = self.groupnorm(x)
-
-    n, c, h, w = x.shape
-    x = x.view((n, c, h * w))
-    x = x.transpose(-1, -2)
-    x = self.attention(x)
-    x = x.transpose(-1, -2)
-    x = x.view((n, c, h, w))
-
-    x += residue
-    return x
-
-
-class ResidualBlock(nn.Module):
-
-  def __init__(self, in_channels, out_channels):
-    super().__init__()
-    self.groupnorm_1 = nn.GroupNorm(32, in_channels)
-    self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
-
-    self.groupnorm_2 = nn.GroupNorm(32, out_channels)
-    self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
-
-    if in_channels == out_channels:
-      self.residual_layer = nn.Identity()
-    else:
-      self.residual_layer = nn.Conv2d(
-          in_channels, out_channels, kernel_size=1, padding=0
-      )
-
-  def forward(self, x):
-    residue = x
-
-    x = self.groupnorm_1(x)
-    x = F.silu(x)
-    x = self.conv_1(x)
-
-    x = self.groupnorm_2(x)
-    x = F.silu(x)
-    x = self.conv_2(x)
-
-    return x + self.residual_layer(residue)
-
-
-class Decoder(nn.Sequential):
-
-  def __init__(self):
-    super().__init__(
-        nn.Conv2d(4, 4, kernel_size=1, padding=0),
-        nn.Conv2d(4, 512, kernel_size=3, padding=1),
-        ResidualBlock(512, 512),
-        AttentionBlock(512),
-        ResidualBlock(512, 512),
-        ResidualBlock(512, 512),
-        ResidualBlock(512, 512),
-        ResidualBlock(512, 512),
-        nn.Upsample(scale_factor=2),
-        nn.Conv2d(512, 512, kernel_size=3, padding=1),
-        ResidualBlock(512, 512),
-        ResidualBlock(512, 512),
-        ResidualBlock(512, 512),
-        nn.Upsample(scale_factor=2),
-        nn.Conv2d(512, 512, kernel_size=3, padding=1),
-        ResidualBlock(512, 256),
-        ResidualBlock(256, 256),
-        ResidualBlock(256, 256),
-        nn.Upsample(scale_factor=2),
-        nn.Conv2d(256, 256, kernel_size=3, padding=1),
-        ResidualBlock(256, 128),
-        ResidualBlock(128, 128),
-        ResidualBlock(128, 128),
-        nn.GroupNorm(32, 128),
-        nn.SiLU(),
-        nn.Conv2d(128, 3, kernel_size=3, padding=1),
-    )
-
-  @torch.inference_mode
-  def forward(self, x):
-    x = x / 0.18215
-    for module in self:
-      x = module(x)
-    return x
+TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
+    post_quant_conv="0",
+    conv_in="1",
+    mid_block_tensor_names=autoencoder_loader.MidBlockTensorNames(
+        residual_block_tensor_names=[
+            autoencoder_loader.ResidualBlockTensorNames(
+                norm_1="2.groupnorm_1",
+                norm_2="2.groupnorm_2",
+                conv_1="2.conv_1",
+                conv_2="2.conv_2",
+            ),
+            autoencoder_loader.ResidualBlockTensorNames(
+                norm_1="4.groupnorm_1",
+                norm_2="4.groupnorm_2",
+                conv_1="4.conv_1",
+                conv_2="4.conv_2",
+            ),
+        ],
+        attention_block_tensor_names=[
+            autoencoder_loader.AttnetionBlockTensorNames(
+                norm="3.groupnorm",
+                fused_qkv_proj="3.attention.in_proj",
+                output_proj="3.attention.out_proj",
+            )
+        ],
+    ),
+    up_decoder_blocks_tensor_names=[
+        autoencoder_loader.UpDecoderBlockTensorNames(
+            residual_block_tensor_names=[
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="5.groupnorm_1",
+                    norm_2="5.groupnorm_2",
+                    conv_1="5.conv_1",
+                    conv_2="5.conv_2",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="6.groupnorm_1",
+                    norm_2="6.groupnorm_2",
+                    conv_1="6.conv_1",
+                    conv_2="6.conv_2",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="7.groupnorm_1",
+                    norm_2="7.groupnorm_2",
+                    conv_1="7.conv_1",
+                    conv_2="7.conv_2",
+                ),
+            ],
+            upsample_conv="9",
+        ),
+        autoencoder_loader.UpDecoderBlockTensorNames(
+            residual_block_tensor_names=[
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="10.groupnorm_1",
+                    norm_2="10.groupnorm_2",
+                    conv_1="10.conv_1",
+                    conv_2="10.conv_2",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="11.groupnorm_1",
+                    norm_2="11.groupnorm_2",
+                    conv_1="11.conv_1",
+                    conv_2="11.conv_2",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="12.groupnorm_1",
+                    norm_2="12.groupnorm_2",
+                    conv_1="12.conv_1",
+                    conv_2="12.conv_2",
+                ),
+            ],
+            upsample_conv="14",
+        ),
+        autoencoder_loader.UpDecoderBlockTensorNames(
+            residual_block_tensor_names=[
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="15.groupnorm_1",
+                    norm_2="15.groupnorm_2",
+                    conv_1="15.conv_1",
+                    conv_2="15.conv_2",
+                    residual_layer="15.residual_layer",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="16.groupnorm_1",
+                    norm_2="16.groupnorm_2",
+                    conv_1="16.conv_1",
+                    conv_2="16.conv_2",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="17.groupnorm_1",
+                    norm_2="17.groupnorm_2",
+                    conv_1="17.conv_1",
+                    conv_2="17.conv_2",
+                ),
+            ],
+            upsample_conv="19",
+        ),
+        autoencoder_loader.UpDecoderBlockTensorNames(
+            residual_block_tensor_names=[
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="20.groupnorm_1",
+                    norm_2="20.groupnorm_2",
+                    conv_1="20.conv_1",
+                    conv_2="20.conv_2",
+                    residual_layer="20.residual_layer",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="21.groupnorm_1",
+                    norm_2="21.groupnorm_2",
+                    conv_1="21.conv_1",
+                    conv_2="21.conv_2",
+                ),
+                autoencoder_loader.ResidualBlockTensorNames(
+                    norm_1="22.groupnorm_1",
+                    norm_2="22.groupnorm_2",
+                    conv_1="22.conv_1",
+                    conv_2="22.conv_2",
+                ),
+            ],
+        ),
+    ],
+    final_norm="23",
+    conv_out="25",
+)
+
+
+class Decoder(nn.Module):
+  """The Decoder model used in Stable Diffusion.
+
+  For details, see https://arxiv.org/abs/2103.00020
+
+  Sturcture of the Decoder:
+
+        latents tensor
+              |
+
+  ┌───────────────────┐
+  │  Post Quant Conv  │
+  └─────────┬─────────┘
+
+  ┌─────────▼─────────┐
+  │       ConvIn      │
+  └─────────┬─────────┘
+
+  ┌─────────▼─────────┐
+  │     MidBlock2D    │
+  └─────────┬─────────┘
+
+  ┌─────────▼─────────┐
+  │    UpDecoder2D    │ x 4
+  └─────────┬─────────┘
+            |
+  ┌─────────▼─────────┐
+  │     FinalNorm     │
+  └─────────┬─────────┘
+            |
+  ┌─────────▼─────────┐
+  │     Activation    │
+  └─────────┬─────────┘
+            |
+  ┌─────────▼─────────┐
+  │      ConvOut      │
+  └─────────┬─────────┘
+            |
+
+        Output Image
+  """
+
+  def __init__(self, config: unet_cfg.AutoEncoderConfig):
+    super().__init__()
+    self.config = config
+    self.post_quant_conv = nn.Conv2d(
+        config.latent_channels,
+        config.latent_channels,
+        kernel_size=1,
+        stride=1,
+        padding=0,
+    )
+    reversed_block_out_channels = list(reversed(config.block_out_channels))
+    self.conv_in = nn.Conv2d(
+        config.latent_channels,
+        reversed_block_out_channels[0],
+        kernel_size=3,
+        stride=1,
+        padding=1,
+    )
+    self.mid_block = blocks_2d.MidBlock2D(config.mid_block_config)
+    up_decoder_blocks = []
+    block_out_channels = reversed_block_out_channels[0]
+    for i, out_channels in enumerate(reversed_block_out_channels):
+      prev_output_channel = block_out_channels
+      block_out_channels = out_channels
+      not_final_block = i < len(reversed_block_out_channels) - 1
+      up_decoder_blocks.append(
+          blocks_2d.UpDecoderBlock2D(
+              unet_cfg.UpDecoderBlock2DConfig(
+                  in_channels=prev_output_channel,
+                  out_channels=block_out_channels,
+                  normalization_config=config.normalization_config,
+                  activation_type=config.activation_type,
+                  num_layers=config.layers_per_block,
+                  add_upsample=not_final_block,
+                  upsample_conv=True,
+                  sampling_config=unet_cfg.SamplingConfig(
+                      2, unet_cfg.SamplingType.NEAREST
+                  ),
+              )
+          )
+      )
+    self.up_decoder_blocks = nn.ModuleList(up_decoder_blocks)
+    self.final_norm = layers_builder.build_norm(
+        block_out_channels, config.normalization_config
+    )
+    self.act_fn = layers_builder.get_activation(config.activation_type)
+    self.conv_out = nn.Conv2d(
+        block_out_channels,
+        config.out_channels,
+        kernel_size=3,
+        stride=1,
+        padding=1,
+    )
+
+  def forward(self, latents_tensor: torch.Tensor) -> torch.Tensor:
+    x = latents_tensor / self.config.scaling_factor
+    x = self.post_quant_conv(x)
+    x = self.conv_in(x)
+    x = self.mid_block(x)
+    for up_decoder_block in self.up_decoder_blocks:
+      x = up_decoder_block(x)
+    x = self.final_norm(x)
+    x = self.act_fn(x)
+    x = self.conv_out(x)
+    return x
+
+
+def get_model_config() -> unet_cfg.AutoEncoderConfig:
+  """Get configs for the Decoder of Stable Diffusion v1.5"""
+  in_channels = 3
+  latent_channels = 4
+  out_channels = 3
+  block_out_channels = [128, 256, 512, 512]
+  scaling_factor = 0.18215
+  layers_per_block = 3
+
+  norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+  )
+
+  att_config = unet_cfg.AttentionBlock2DConfig(
+      dims=block_out_channels[-1],
+      normalization_config=norm_config,
+      attention_config=layers_cfg.AttentionConfig(
+          num_heads=1,
+          num_query_groups=1,
+          qkv_use_bias=True,
+          output_proj_use_bias=True,
+          enable_kv_cache=False,
+          qkv_transpose_before_split=True,
+          rotary_percentage=0.0,
+      ),
+  )
+
+  mid_block_config = unet_cfg.MidBlock2DConfig(
+      in_channels=block_out_channels[-1],
+      normalization_config=norm_config,
+      activation_type=layers_cfg.ActivationType.SILU,
+      num_layers=1,
+      attention_block_config=att_config,
+  )
+
+  config = unet_cfg.AutoEncoderConfig(
+      in_channels=in_channels,
+      latent_channels=latent_channels,
+      out_channels=out_channels,
+      activation_type=layers_cfg.ActivationType.SILU,
+      block_out_channels=block_out_channels,
+      scaling_factor=scaling_factor,
+      layers_per_block=layers_per_block,
+      normalization_config=norm_config,
+      mid_block_config=mid_block_config,
+  )
+  return config
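A minimal usage sketch of the rewritten Decoder (input shape illustrative): with block_out_channels [128, 256, 512, 512], three of the four up-decoder blocks upsample by 2x, so 64x64 latents should decode to a 512x512 image.

import torch
import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder

model = decoder.Decoder(decoder.get_model_config())
latents = torch.rand(1, 4, 64, 64)   # (batch, latent_channels, h, w)
image = model(latents)               # expected shape: (1, 3, 512, 512)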
ai_edge_torch/generative/examples/stable_diffusion/encoder.py

@@ -17,9 +17,60 @@ import torch
 from torch import nn
 from torch.nn import functional as F

-from ai_edge_torch.generative.examples.stable_diffusion.decoder import AttentionBlock  # NOQA
-from ai_edge_torch.generative.examples.stable_diffusion.decoder import ResidualBlock  # NOQA
-import ai_edge_torch.generative.utilities.loader as loading_utils
+from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+
+
+class AttentionBlock(nn.Module):
+
+  def __init__(self, channels):
+    super().__init__()
+    self.groupnorm = nn.GroupNorm(32, channels)
+    self.attention = SelfAttention(1, channels)
+
+  def forward(self, x):
+    residue = x
+    x = self.groupnorm(x)
+
+    n, c, h, w = x.shape
+    x = x.view((n, c, h * w))
+    x = x.transpose(-1, -2)
+    x = self.attention(x)
+    x = x.transpose(-1, -2)
+    x = x.view((n, c, h, w))
+
+    x += residue
+    return x
+
+
+class ResidualBlock(nn.Module):
+
+  def __init__(self, in_channels, out_channels):
+    super().__init__()
+    self.groupnorm_1 = nn.GroupNorm(32, in_channels)
+    self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+    self.groupnorm_2 = nn.GroupNorm(32, out_channels)
+    self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+
+    if in_channels == out_channels:
+      self.residual_layer = nn.Identity()
+    else:
+      self.residual_layer = nn.Conv2d(
+          in_channels, out_channels, kernel_size=1, padding=0
+      )
+
+  def forward(self, x):
+    residue = x
+
+    x = self.groupnorm_1(x)
+    x = F.silu(x)
+    x = self.conv_1(x)
+
+    x = self.groupnorm_2(x)
+    x = F.silu(x)
+    x = self.conv_2(x)
+
+    return x + self.residual_layer(residue)


 class Encoder(nn.Sequential):
ai_edge_torch/generative/layers/attention.py

@@ -199,3 +199,28 @@ class CausalSelfAttention(nn.Module):
     # Compute the output projection.
     y = self.output_projection(y)
     return y
+
+
+class SelfAttention(CausalSelfAttention):
+  """Non-causal Self Attention module, which is equivalent to CausalSelfAttention without mask."""
+
+  def forward(
+      self,
+      x: torch.Tensor,
+      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+      input_pos: Optional[torch.Tensor] = None,
+  ) -> torch.Tensor:
+    """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.
+
+    Args:
+      x (torch.Tensor): the input tensor.
+      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+      input_pos (torch.Tensor): the optional input position tensor.
+
+    Returns:
+      output activation from this self attention layer.
+    """
+    B, T, _ = x.size()
+    return super().forward(
+        x, rope=rope, mask=torch.zeros((B, T), dtype=torch.float32), input_pos=input_pos
+    )
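The all-zeros mask works because scaled-dot-product attention adds the mask to the attention logits before the softmax: a zero additive mask is a no-op, whereas a causal mask holds -inf above the diagonal. A standalone sketch in plain torch (illustrating the principle, not the library's internals):

import torch
import torch.nn.functional as F

B, T, E = 1, 4, 8
q = k = v = torch.rand(B, T, E)
no_mask = F.scaled_dot_product_attention(q, k, v)
zero_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=torch.zeros(T, T))
assert torch.allclose(no_mask, zero_mask)  # zero additive mask == no mask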
ai_edge_torch/generative/layers/builder.py

@@ -44,6 +44,8 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
     return nn.LayerNorm(dim, eps=config.epsilon)
+  elif config.type == cfg.NormalizationType.GROUP_NORM:
+    return nn.GroupNorm(config.group_num, dim, config.epsilon)
   else:
     raise ValueError("Unsupported norm type.")

@@ -69,7 +71,7 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
   else:
     raise ValueError("Unsupported feedforward type.")

-  activation = _get_activation(config.activation)
+  activation = get_activation(config.activation)

   return ff_module(
       dim=dim,
@@ -79,7 +81,7 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
   )


-def _get_activation(type_: cfg.ActivationType):
+def get_activation(type_: cfg.ActivationType):
   """Get pytorch callable activation from the name.

   Args:
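With the new branch, a GROUP_NORM config yields a plain nn.GroupNorm. A minimal sketch based only on the signatures shown in this diff:

import torch
import ai_edge_torch.generative.layers.builder as layers_builder
import ai_edge_torch.generative.layers.model_config as cfg

norm = layers_builder.build_norm(
    128,  # dim, i.e. number of channels
    cfg.NormalizationConfig(cfg.NormalizationType.GROUP_NORM, group_num=32),
)
y = norm(torch.rand(1, 128, 8, 8))  # equivalent to nn.GroupNorm(32, 128)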
ai_edge_torch/generative/layers/model_config.py

@@ -39,6 +39,7 @@ class NormalizationType(enum.Enum):
   NONE = enum.auto()
   RMS_NORM = enum.auto()
   LAYER_NORM = enum.auto()
+  GROUP_NORM = enum.auto()


 @enum.unique
@@ -90,6 +91,8 @@ class NormalizationConfig:
   type: NormalizationType = NormalizationType.NONE
   epsilon: float = 1e-5
   zero_centered: bool = False
+  # Number of groups used in group normalization.
+  group_num: Optional[float] = None


 @dataclass
ai_edge_torch/generative/layers/unet/__init__.py (new file)

@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================