ai-edge-torch-nightly 0.2.0.dev20240611__py3-none-any.whl → 0.2.0.dev20240619__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
- ai_edge_torch/debug/__init__.py +1 -0
- ai_edge_torch/debug/culprit.py +70 -29
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
- ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
- ai_edge_torch/generative/layers/attention.py +154 -26
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
- ai_edge_torch/generative/layers/unet/builder.py +20 -2
- ai_edge_torch/generative/layers/unet/model_config.py +157 -5
- ai_edge_torch/generative/test/test_model_conversion.py +24 -0
- ai_edge_torch/generative/test/test_quantize.py +1 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
- ai_edge_torch/generative/utilities/t5_loader.py +33 -17
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/RECORD +23 -22
- ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/unet/model_config.py

@@ -22,16 +22,28 @@ from typing import List, Optional
 import ai_edge_torch.generative.layers.model_config as layers_cfg


-@
+@enum.unique
 class SamplingType(enum.Enum):
   NEAREST = enum.auto()
   BILINEAR = enum.auto()
+  AVERAGE = enum.auto()
+  CONVOLUTION = enum.auto()


 @dataclass
-class
+class UpSamplingConfig:
+  mode: SamplingType
   scale_factor: float
+
+
+@dataclass
+class DownSamplingConfig:
   mode: SamplingType
+  in_channels: int
+  kernel_size: int
+  stride: int
+  padding: int
+  out_channels: Optional[int] = None


 @dataclass
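For orientation, the two sampling configs introduced in this hunk are plain dataclasses whose fields are exactly what the diff shows; a minimal construction sketch (the `unet_cfg` alias and all numeric values are illustrative, not taken from the package):

```python
import ai_edge_torch.generative.layers.unet.model_config as unet_cfg

# Upsample by a fixed scale factor with one of the enum's interpolation modes.
up_cfg = unet_cfg.UpSamplingConfig(
    mode=unet_cfg.SamplingType.NEAREST,
    scale_factor=2.0,
)

# Downsample with a strided convolution, using the newly added
# SamplingType.CONVOLUTION mode; channel and kernel sizes are arbitrary here.
down_cfg = unet_cfg.DownSamplingConfig(
    mode=unet_cfg.SamplingType.CONVOLUTION,
    in_channels=128,
    kernel_size=3,
    stride=2,
    padding=1,
)
```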
@@ -46,9 +58,38 @@ class ResidualBlock2DConfig:

 @dataclass
 class AttentionBlock2DConfig:
-
+  dim: int
+  normalization_config: layers_cfg.NormalizationConfig
+  attention_config: layers_cfg.AttentionConfig
+  enable_hlfb: bool = True
+  attention_batch_size: int = 1
+
+
+@dataclass
+class CrossAttentionBlock2DConfig:
+  query_dim: int
+  cross_dim: int
   normalization_config: layers_cfg.NormalizationConfig
   attention_config: layers_cfg.AttentionConfig
+  enable_hlfb: bool = True
+  attention_batch_size: int = 1
+
+
+@dataclass
+class FeedForwardBlock2DConfig:
+  dim: int
+  hidden_dim: int
+  normalization_config: layers_cfg.NormalizationConfig
+  activation_config: layers_cfg.ActivationConfig
+  use_bias: bool
+
+
+@dataclass
+class TransformerBlock2Dconfig:
+  pre_conv_normalization_config: layers_cfg.NormalizationConfig
+  attention_block_config: AttentionBlock2DConfig
+  cross_attention_block_config: CrossAttentionBlock2DConfig
+  feed_forward_block_config: FeedForwardBlock2DConfig


 @dataclass
@@ -58,14 +99,62 @@ class UpDecoderBlock2DConfig:
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
   num_layers: int
-  # Optional time embedding channels if the residual blocks take a time embedding
+  # Optional time embedding channels if the residual blocks take a time embedding as input
   time_embedding_channels: Optional[int] = None
   # Whether to add upsample operation after residual blocks
   add_upsample: bool = True
   # Whether to add a conv2d layer after upsample
   upsample_conv: bool = True
   # Optional sampling config if add_upsample is True.
-  sampling_config: Optional[
+  sampling_config: Optional[UpSamplingConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None
+
+
+@dataclass
+class SkipUpDecoderBlock2DConfig:
+  in_channels: int
+  out_channels: int
+  # The dimension of output channels of previous connected block
+  prev_out_channels: int
+  normalization_config: layers_cfg.NormalizationConfig
+  activation_config: layers_cfg.ActivationConfig
+  num_layers: int
+  # Optional time embedding channels if the residual blocks take a time embedding as input
+  time_embedding_channels: Optional[int] = None
+  # Whether to add upsample operation after residual blocks
+  add_upsample: bool = True
+  # Whether to add a conv2d layer after upsample
+  upsample_conv: bool = True
+  # Optional sampling config if add_upsample is True.
+  sampling_config: Optional[UpSamplingConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None
+
+
+@dataclass
+class DownEncoderBlock2DConfig:
+  in_channels: int
+  out_channels: int
+  normalization_config: layers_cfg.NormalizationConfig
+  activation_config: layers_cfg.ActivationConfig
+  num_layers: int
+  # Padding for the downsampling convolution.
+  padding: int = 1
+  # Optional time embedding channels if the residual blocks take a time embedding as input
+  time_embedding_channels: Optional[int] = None
+  # Whether to add downsample operation after residual blocks
+  add_downsample: bool = True
+  # Optional sampling config if add_upsample is True.
+  sampling_config: Optional[DownSamplingConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None


 @dataclass
@@ -78,6 +167,10 @@ class MidBlock2DConfig:
   time_embedding_channels: Optional[int] = None
   # Optional config of attention blocks interleaved with residual blocks
   attention_block_config: Optional[AttentionBlock2DConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None


 @dataclass
@@ -115,3 +208,62 @@ class AutoEncoderConfig:

   # The configuration of middle blocks, that is, after the last block of encoder and before the first block of decoder.
   mid_block_config: MidBlock2DConfig
+
+
+@dataclass
+class DiffusionModelConfig:
+  """Configurations of Diffusion model."""
+
+  # Number of channels in the input tensor.
+  in_channels: int
+
+  # Number of channels in the output tensor.
+  out_channels: int
+
+  # The output channels of each block.
+  block_out_channels: List[int]
+
+  # The layesr number of each block.
+  layers_per_block: int
+
+  # The padding to use for the downsampling.
+  downsample_padding: int
+
+  # Normalization config used in residual blocks.
+  residual_norm_config: layers_cfg.NormalizationConfig
+
+  # Activation config used in residual blocks
+  residual_activation_type: layers_cfg.ActivationType
+
+  # The batch size used in transformer blocks, for attention layers.
+  transformer_batch_size: int
+
+  # The number of attention heads used in transformer blocks.
+  transformer_num_attention_heads: int
+
+  # The dimension of cross attention used in transformer blocks.
+  transformer_cross_attention_dim: int
+
+  # Normalization config used in prev conv layer of transformer blocks.
+  transformer_pre_conv_norm_config: layers_cfg.NormalizationConfig
+
+  # Normalization config used in transformer blocks.
+  transformer_norm_config: layers_cfg.NormalizationConfig
+
+  # Activation type of feed forward used in transformer blocks.
+  transformer_ff_activation_type: layers_cfg.ActivationType
+
+  # Number of layers in mid block.
+  mid_block_layers: int
+
+  # Dimension of time embedding.
+  time_embedding_dim: int
+
+  # Time embedding dimensions for blocks.
+  time_embedding_blocks_dim: int
+
+  # Normalization config used for final layer
+  final_norm_config: layers_cfg.NormalizationConfig
+
+  # Activation type used in final layer
+  final_activation_type: layers_cfg.ActivationType
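The new `DiffusionModelConfig` aggregates a fairly long list of nested configs. A quick way to see what a caller must supply, without constructing one, is to introspect the dataclass; a small sketch assuming only that the module is importable under the path shown in the file list above:

```python
import dataclasses

import ai_edge_torch.generative.layers.unet.model_config as unet_cfg

# Print every field DiffusionModelConfig declares, e.g. block_out_channels,
# transformer_cross_attention_dim, time_embedding_dim, and so on.
for field in dataclasses.fields(unet_cfg.DiffusionModelConfig):
  print(field.name, field.type)
```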
ai_edge_torch/generative/test/test_model_conversion.py

@@ -55,6 +55,30 @@ class TestModelConversion(unittest.TestCase):
           )
       )

+  def test_toy_model_with_multi_batches(self):
+    config = toy_model_with_kv_cache.get_model_config()
+    config.batch_size = 2
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
+    idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
+        [10], dtype=torch.int64
+    )
+
+    edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (idx, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
+
   def test_toy_model_with_kv_cache_with_hlfb(self):
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = True
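The added test stops after conversion because output checking is disabled pending b/338288901. For context, a model converted with `ai_edge_torch.convert` can normally be invoked directly and serialized to a `.tflite` file; a hedged, self-contained sketch using a throwaway module instead of the toy model (the module, tensor shapes, and output path are illustrative only):

```python
import torch
import ai_edge_torch


class TinyModel(torch.nn.Module):
  """A stand-in module just for demonstrating the conversion round trip."""

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.relu(x) + 1.0


sample_inputs = (torch.randn(2, 8),)  # batch of 2, echoing the new test
edge_model = ai_edge_torch.convert(TinyModel().eval(), sample_inputs)
edge_output = edge_model(*sample_inputs)  # run the converted model
edge_model.export("/tmp/tiny_model.tflite")  # write the flatbuffer to disk
```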
ai_edge_torch/generative/test/test_quantize.py

@@ -116,6 +116,7 @@ class TestQuantizeConvert(unittest.TestCase):
       ]
   )
   def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
+    self.skipTest("b/346896669")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
     idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
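The single added line above uses `unittest.TestCase.skipTest`, which reports the test as skipped rather than failed until the referenced bug is resolved; a minimal standalone illustration (the test class here is invented, only the skipTest call mirrors the diff):

```python
import unittest


class ExampleTest(unittest.TestCase):

  def test_pending_bug(self):
    # Bail out of the test body; the runner reports this test as skipped.
    self.skipTest("b/346896669")


if __name__ == "__main__":
  unittest.main()
```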