ai-edge-torch-nightly 0.2.0.dev20240606__py3-none-any.whl → 0.2.0.dev20240609__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Files changed (20)
  1. ai_edge_torch/convert/conversion.py +2 -2
  2. ai_edge_torch/convert/fx_passes/__init__.py +1 -1
  3. ai_edge_torch/convert/fx_passes/{build_upsample_bilinear2d_composite_pass.py → build_interpolate_composite_pass.py} +22 -1
  4. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +8 -4
  5. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +275 -82
  6. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +54 -3
  7. ai_edge_torch/generative/layers/attention.py +25 -0
  8. ai_edge_torch/generative/layers/builder.py +4 -2
  9. ai_edge_torch/generative/layers/model_config.py +3 -0
  10. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  11. ai_edge_torch/generative/layers/unet/blocks_2d.py +287 -0
  12. ai_edge_torch/generative/layers/unet/builder.py +29 -0
  13. ai_edge_torch/generative/layers/unet/model_config.py +117 -0
  14. ai_edge_torch/generative/utilities/autoencoder_loader.py +298 -0
  15. ai_edge_torch/generative/utilities/loader.py +7 -5
  16. {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/METADATA +1 -1
  17. {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/RECORD +20 -15
  18. {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/LICENSE +0 -0
  19. {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/WHEEL +0 -0
  20. {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -0,0 +1,287 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import Optional
+
+ import torch
+ from torch import nn
+
+ from ai_edge_torch.generative.layers.attention import SelfAttention
+ import ai_edge_torch.generative.layers.builder as layers_builder
+ import ai_edge_torch.generative.layers.unet.builder as unet_builder
+ import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
+
+
+ class ResidualBlock2D(nn.Module):
+   """2D residual block containing two Conv2d layers, with an optional time embedding input."""
+
+   def __init__(self, config: unet_cfg.ResidualBlock2DConfig):
+     """Initialize an instance of the ResidualBlock2D.
+
+     Args:
+       config (unet_cfg.ResidualBlock2DConfig): the configuration of this block.
+     """
+     super().__init__()
+     self.config = config
+     self.norm_1 = layers_builder.build_norm(
+         config.in_channels, config.normalization_config
+     )
+     self.conv_1 = nn.Conv2d(
+         config.in_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+     )
+     if config.time_embedding_channels is not None:
+       self.time_emb_proj = nn.Linear(
+           config.time_embedding_channels, config.out_channels
+       )
+     else:
+       self.time_emb_proj = None
+     self.norm_2 = layers_builder.build_norm(
+         config.out_channels, config.normalization_config
+     )
+     self.conv_2 = nn.Conv2d(
+         config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+     )
+     self.act_fn = layers_builder.get_activation(config.activation_type)
+     if config.in_channels == config.out_channels:
+       self.residual_layer = nn.Identity()
+     else:
+       self.residual_layer = nn.Conv2d(
+           config.in_channels, config.out_channels, kernel_size=1, stride=1, padding=0
+       )
+
+   def forward(
+       self, input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None
+   ) -> torch.Tensor:
+     """Forward function of the ResidualBlock2D.
+
+     Args:
+       input_tensor (torch.Tensor): the input tensor.
+       time_emb (Optional[torch.Tensor]): optional time embedding tensor.
+
+     Returns:
+       output hidden_states tensor after ResidualBlock2D.
+     """
+     residual = input_tensor
+     x = self.norm_1(input_tensor)
+     x = self.act_fn(x)
+     x = self.conv_1(x)
+     if self.time_emb_proj is not None:
+       time_emb = self.time_emb_proj(time_emb)[:, :, None, None]
+       x = x + time_emb
+     x = self.norm_2(x)
+     x = self.act_fn(x)
+     x = self.conv_2(x)
+     x = x + self.residual_layer(residual)
+     return x
+
+
+ class AttentionBlock2D(nn.Module):
+   """2D self attention block.
+
+   x = SelfAttention(Norm(input_tensor)) + input_tensor
+   """
+
+   def __init__(self, config: unet_cfg.AttentionBlock2DConfig):
+     """Initialize an instance of the AttentionBlock2D.
+
+     Args:
+       config (unet_cfg.AttentionBlock2DConfig): the configuration of this block.
+     """
+     super().__init__()
+     self.norm = layers_builder.build_norm(config.dims, config.normalization_config)
+     self.attention = SelfAttention(config.dims, config.attention_config, 0, True)
+
+   def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+     """Forward function of the AttentionBlock2D.
+
+     Args:
+       input_tensor (torch.Tensor): the input tensor.
+
+     Returns:
+       output activation tensor after self attention.
+     """
+     residual = input_tensor
+     x = self.norm(input_tensor)
+     B, C, H, W = x.shape
+     # Flatten the spatial dims so attention runs over H * W tokens of dim C.
+     x = x.view(B, C, H * W)
+     x = x.transpose(-1, -2)
+     x = self.attention(x)
+     x = x.transpose(-1, -2)
+     x = x.view(B, C, H, W)
+     x = x + residual
+     return x
+
+
+ class UpDecoderBlock2D(nn.Module):
+   """Decoder block containing several residual blocks followed by an optional upsampler.
+
+         input_tensor
+              │
+    ┌─────────▼─────────┐
+    │  ResidualBlock2D  │  num_layers
+    └─────────┬─────────┘
+              │
+    ┌─────────▼─────────┐
+    │    (Optional)     │
+    │     Upsampler     │
+    └─────────┬─────────┘
+              │
+    ┌─────────▼─────────┐
+    │    (Optional)     │
+    │      Conv2D       │
+    └─────────┬─────────┘
+              │
+              ▼
+        hidden_states
+   """
+
+   def __init__(self, config: unet_cfg.UpDecoderBlock2DConfig):
+     """Initialize an instance of the UpDecoderBlock2D.
+
+     Args:
+       config (unet_cfg.UpDecoderBlock2DConfig): the configuration of this block.
+     """
+     super().__init__()
+     self.config = config
+     resnets = []
+     for i in range(config.num_layers):
+       input_channels = config.in_channels if i == 0 else config.out_channels
+       resnets.append(
+           ResidualBlock2D(
+               unet_cfg.ResidualBlock2DConfig(
+                   in_channels=input_channels,
+                   out_channels=config.out_channels,
+                   time_embedding_channels=config.time_embedding_channels,
+                   normalization_config=config.normalization_config,
+                   activation_type=config.activation_type,
+               )
+           )
+       )
+     self.resnets = nn.ModuleList(resnets)
+     if config.add_upsample:
+       self.upsampler = unet_builder.build_upsampling(config.sampling_config)
+       if config.upsample_conv:
+         self.upsample_conv = nn.Conv2d(
+             config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+         )
+       else:
+         self.upsample_conv = None
+     else:
+       self.upsampler = None
+       self.upsample_conv = None
+
+   def forward(
+       self, input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None
+   ) -> torch.Tensor:
+     """Forward function of the UpDecoderBlock2D.
+
+     Args:
+       input_tensor (torch.Tensor): the input tensor.
+       time_emb (Optional[torch.Tensor]): optional time embedding tensor, if the block
+         is configured to accept time embedding context.
+
+     Returns:
+       output hidden_states tensor after UpDecoderBlock2D.
+     """
+     hidden_states = input_tensor
+     for resnet in self.resnets:
+       hidden_states = resnet(hidden_states, time_emb)
+     if self.upsampler:
+       hidden_states = self.upsampler(hidden_states)
+       if self.upsample_conv:
+         hidden_states = self.upsample_conv(hidden_states)
+     return hidden_states
+
+
+ class MidBlock2D(nn.Module):
+   """Middle block containing at least one residual block, with optional interleaved attention blocks.
+
+            input_tensor
+                 │
+       ┌─────────▼─────────┐
+       │  ResidualBlock2D  │
+       └─────────┬─────────┘
+                 │
+   ┌─────────────▼─────────────┐
+   │   ┌───────────────────┐   │
+   │   │    (Optional)     │   │
+   │   │  AttentionBlock2D │   │
+   │   └─────────┬─────────┘   │  num_layers
+   │             │             │
+   │   ┌─────────▼─────────┐   │
+   │   │  ResidualBlock2D  │   │
+   │   └───────────────────┘   │
+   └─────────────┬─────────────┘
+                 │
+                 ▼
+           hidden_states
+   """
+
+   def __init__(self, config: unet_cfg.MidBlock2DConfig):
+     """Initialize an instance of the MidBlock2D.
+
+     Args:
+       config (unet_cfg.MidBlock2DConfig): the configuration of this block.
+     """
+     super().__init__()
+     self.config = config
+     resnets = [
+         ResidualBlock2D(
+             unet_cfg.ResidualBlock2DConfig(
+                 in_channels=config.in_channels,
+                 out_channels=config.in_channels,
+                 time_embedding_channels=config.time_embedding_channels,
+                 normalization_config=config.normalization_config,
+                 activation_type=config.activation_type,
+             )
+         )
+     ]
+     attentions = []
+     for i in range(config.num_layers):
+       if self.config.attention_block_config:
+         attentions.append(AttentionBlock2D(config.attention_block_config))
+       else:
+         # Keep a None placeholder so every resnet after the first stays
+         # paired with an attention slot in forward().
+         attentions.append(None)
+       resnets.append(
+           ResidualBlock2D(
+               unet_cfg.ResidualBlock2DConfig(
+                   in_channels=config.in_channels,
+                   out_channels=config.in_channels,
+                   time_embedding_channels=config.time_embedding_channels,
+                   normalization_config=config.normalization_config,
+                   activation_type=config.activation_type,
+               )
+           )
+       )
+     self.resnets = nn.ModuleList(resnets)
+     self.attentions = nn.ModuleList(attentions)
+
+   def forward(
+       self, input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None
+   ) -> torch.Tensor:
+     """Forward function of the MidBlock2D.
+
+     Args:
+       input_tensor (torch.Tensor): the input tensor.
+       time_emb (Optional[torch.Tensor]): optional time embedding tensor, if the block
+         is configured to accept time embedding context.
+
+     Returns:
+       output hidden_states tensor after MidBlock2D.
+     """
+     hidden_states = self.resnets[0](input_tensor, time_emb)
+     for attn, resnet in zip(self.attentions, self.resnets[1:]):
+       if attn is not None:
+         hidden_states = attn(hidden_states)
+       hidden_states = resnet(hidden_states, time_emb)
+     return hidden_states
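For orientation, here is a minimal usage sketch of ResidualBlock2D (not part of the diff). The GROUP_NORM and SILU enum members and the NormalizationConfig constructor are assumptions — they live in ai_edge_torch/generative/layers/model_config.py, which this diff only extends by three lines — so treat them as illustrative placeholders.

import torch

import ai_edge_torch.generative.layers.model_config as layers_cfg
import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
from ai_edge_torch.generative.layers.unet.blocks_2d import ResidualBlock2D

# Assumed config values; the real NormalizationConfig fields and enum
# members are defined in generative/layers/model_config.py.
norm = layers_cfg.NormalizationConfig(layers_cfg.NormalizationType.GROUP_NORM)
act = layers_cfg.ActivationType.SILU

block = ResidualBlock2D(
    unet_cfg.ResidualBlock2DConfig(
        in_channels=128,
        out_channels=256,
        normalization_config=norm,
        activation_type=act,
    )
)

x = torch.randn(1, 128, 32, 32)
y = block(x)  # Shape (1, 256, 32, 32); time_emb is omitted because
              # time_embedding_channels was left as None.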
ai_edge_torch/generative/layers/unet/builder.py
@@ -0,0 +1,29 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Builder utils for individual components.
+
+ from torch import nn
+ import torch.nn.functional as F
+
+ import ai_edge_torch.generative.layers.unet.model_config as unet_config
+
+
+ def build_upsampling(config: unet_config.SamplingConfig):
+   if config.mode == unet_config.SamplingType.NEAREST:
+     return nn.UpsamplingNearest2d(scale_factor=config.scale_factor)
+   elif config.mode == unet_config.SamplingType.BILINEAR:
+     return nn.UpsamplingBilinear2d(scale_factor=config.scale_factor)
+   else:
+     raise ValueError("Unsupported upsampling type.")
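A short sketch of how build_upsampling is called; everything referenced here is defined in this hunk, and only the tensor shapes are made up for illustration.

import torch

import ai_edge_torch.generative.layers.unet.builder as unet_builder
import ai_edge_torch.generative.layers.unet.model_config as unet_config

# Build a 2x nearest-neighbor upsampler from a SamplingConfig.
upsampler = unet_builder.build_upsampling(
    unet_config.SamplingConfig(
        scale_factor=2.0, mode=unet_config.SamplingType.NEAREST
    )
)

x = torch.randn(1, 512, 32, 32)
y = upsampler(x)  # Spatial dims double: shape (1, 512, 64, 64).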
ai_edge_torch/generative/layers/unet/model_config.py
@@ -0,0 +1,117 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # UNet configuration class.
+ from dataclasses import dataclass
+ from dataclasses import field
+ import enum
+ from typing import List, Optional
+
+ import ai_edge_torch.generative.layers.model_config as layers_cfg
+
+
+ @enum.unique
+ class SamplingType(enum.Enum):
+   NEAREST = enum.auto()
+   BILINEAR = enum.auto()
+
+
+ @dataclass
+ class SamplingConfig:
+   scale_factor: float
+   mode: SamplingType
+
+
+ @dataclass
+ class ResidualBlock2DConfig:
+   in_channels: int
+   out_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_type: layers_cfg.ActivationType
+   # Optional time embedding channels if the residual block takes a time embedding context as input.
+   time_embedding_channels: Optional[int] = None
+
+
+ @dataclass
+ class AttentionBlock2DConfig:
+   dims: int
+   normalization_config: layers_cfg.NormalizationConfig
+   attention_config: layers_cfg.AttentionConfig
+
+
+ @dataclass
+ class UpDecoderBlock2DConfig:
+   in_channels: int
+   out_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_type: layers_cfg.ActivationType
+   num_layers: int
+   # Optional time embedding channels if the residual blocks take a time embedding context as input.
+   time_embedding_channels: Optional[int] = None
+   # Whether to add an upsample operation after the residual blocks.
+   add_upsample: bool = True
+   # Whether to add a Conv2d layer after the upsample.
+   upsample_conv: bool = True
+   # Optional sampling config; required if add_upsample is True.
+   sampling_config: Optional[SamplingConfig] = None
+
+
+ @dataclass
+ class MidBlock2DConfig:
+   in_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_type: layers_cfg.ActivationType
+   num_layers: int
+   # Optional time embedding channels if the residual blocks take a time embedding context as input.
+   time_embedding_channels: Optional[int] = None
+   # Optional config of attention blocks interleaved with the residual blocks.
+   attention_block_config: Optional[AttentionBlock2DConfig] = None
+
+
+ @dataclass
+ class AutoEncoderConfig:
+   """Configurations of encoder/decoder in the autoencoder model."""
+
+   # The activation type of encoder/decoder blocks.
+   activation_type: layers_cfg.ActivationType
+
+   # The output channels of each block.
+   block_out_channels: List[int]
+
+   # Number of channels in the input image.
+   in_channels: int
+
+   # Number of channels in the output.
+   out_channels: int
+
+   # Number of channels in the latent space.
+   latent_channels: int
+
+   # The component-wise standard deviation of the trained latent space, computed using the first batch of
+   # the training set. This is used to scale the latent space to have unit variance when training the
+   # diffusion model. The latents are scaled with the formula `z = z * scaling_factor` before being passed
+   # to the diffusion model, and scaled back to the original scale with `z = 1 / scaling_factor * z` when
+   # decoding. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image Synthesis
+   # with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+   scaling_factor: float
+
+   # The number of layers in each encoder/decoder block.
+   layers_per_block: int
+
+   # The normalization config.
+   normalization_config: layers_cfg.NormalizationConfig
+
+   # The configuration of the middle blocks, i.e. after the last block of the encoder and before the
+   # first block of the decoder.
+   mid_block_config: MidBlock2DConfig
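To make the relationships between these dataclasses concrete, here is a hedged sketch of decoder-side block configs. The channel counts are illustrative values only, and the NormalizationConfig/ActivationType members are assumptions defined outside this diff; the real wiring lives in generative/examples/stable_diffusion/decoder.py, also changed in this release. Note that the scaling_factor comment above amounts to z = z * scaling_factor before the diffusion model and z = z / scaling_factor when decoding.

import ai_edge_torch.generative.layers.model_config as layers_cfg
import ai_edge_torch.generative.layers.unet.model_config as unet_cfg

# Assumed placeholders; see generative/layers/model_config.py for the
# real NormalizationConfig fields and enum members.
norm = layers_cfg.NormalizationConfig(layers_cfg.NormalizationType.GROUP_NORM)
act = layers_cfg.ActivationType.SILU

# Middle block: residual blocks only, since attention_block_config is omitted.
mid_block = unet_cfg.MidBlock2DConfig(
    in_channels=512,
    normalization_config=norm,
    activation_type=act,
    num_layers=1,
)

# One decoder stage: three residual blocks, then a 2x upsample plus a Conv2d.
up_block = unet_cfg.UpDecoderBlock2DConfig(
    in_channels=512,
    out_channels=256,
    normalization_config=norm,
    activation_type=act,
    num_layers=3,
    add_upsample=True,
    upsample_conv=True,
    sampling_config=unet_cfg.SamplingConfig(
        scale_factor=2.0, mode=unet_cfg.SamplingType.NEAREST
    ),
)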