ai-edge-torch-nightly 0.2.0.dev20240610__py3-none-any.whl → 0.2.0.dev20240617__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- ai_edge_torch/convert/conversion_utils.py +17 -5
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
- ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
- ai_edge_torch/generative/layers/attention.py +154 -26
- ai_edge_torch/generative/layers/model_config.py +4 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
- ai_edge_torch/generative/layers/unet/builder.py +20 -2
- ai_edge_torch/generative/layers/unet/model_config.py +157 -5
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +164 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +2 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +49 -4
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -2
- ai_edge_torch/generative/quantize/quant_recipes.py +3 -3
- ai_edge_torch/generative/quantize/supported_schemes.py +2 -1
- ai_edge_torch/generative/test/test_model_conversion.py +24 -0
- ai_edge_torch/generative/test/test_quantize.py +75 -20
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
- ai_edge_torch/generative/utilities/t5_loader.py +33 -17
- ai_edge_torch/quantize/quant_config.py +11 -15
- {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/RECORD +29 -27
- ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
- {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/unet/model_config.py

@@ -22,16 +22,28 @@ from typing import List, Optional
 import ai_edge_torch.generative.layers.model_config as layers_cfg
 
 
-@enum.unique
+@enum.unique
 class SamplingType(enum.Enum):
   NEAREST = enum.auto()
   BILINEAR = enum.auto()
+  AVERAGE = enum.auto()
+  CONVOLUTION = enum.auto()
 
 
 @dataclass
-class SamplingConfig:
+class UpSamplingConfig:
+  mode: SamplingType
   scale_factor: float
+
+
+@dataclass
+class DownSamplingConfig:
   mode: SamplingType
+  in_channels: int
+  kernel_size: int
+  stride: int
+  padding: int
+  out_channels: Optional[int] = None
 
 
 @dataclass
@@ -46,9 +58,38 @@ class ResidualBlock2DConfig:
 
 @dataclass
 class AttentionBlock2DConfig:
-  dims: int
+  dim: int
+  normalization_config: layers_cfg.NormalizationConfig
+  attention_config: layers_cfg.AttentionConfig
+  enable_hlfb: bool = True
+  attention_batch_size: int = 1
+
+
+@dataclass
+class CrossAttentionBlock2DConfig:
+  query_dim: int
+  cross_dim: int
   normalization_config: layers_cfg.NormalizationConfig
   attention_config: layers_cfg.AttentionConfig
+  enable_hlfb: bool = True
+  attention_batch_size: int = 1
+
+
+@dataclass
+class FeedForwardBlock2DConfig:
+  dim: int
+  hidden_dim: int
+  normalization_config: layers_cfg.NormalizationConfig
+  activation_config: layers_cfg.ActivationConfig
+  use_bias: bool
+
+
+@dataclass
+class TransformerBlock2Dconfig:
+  pre_conv_normalization_config: layers_cfg.NormalizationConfig
+  attention_block_config: AttentionBlock2DConfig
+  cross_attention_block_config: CrossAttentionBlock2DConfig
+  feed_forward_block_config: FeedForwardBlock2DConfig
 
 
 @dataclass
@@ -58,14 +99,62 @@ class UpDecoderBlock2DConfig:
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
   num_layers: int
-  # Optional time embedding channels if the residual blocks take a time embedding
+  # Optional time embedding channels if the residual blocks take a time embedding as input
   time_embedding_channels: Optional[int] = None
   # Whether to add upsample operation after residual blocks
   add_upsample: bool = True
   # Whether to add a conv2d layer after upsample
   upsample_conv: bool = True
   # Optional sampling config if add_upsample is True.
-  sampling_config: Optional[SamplingConfig] = None
+  sampling_config: Optional[UpSamplingConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None
+
+
+@dataclass
+class SkipUpDecoderBlock2DConfig:
+  in_channels: int
+  out_channels: int
+  # The dimension of output channels of previous connected block
+  prev_out_channels: int
+  normalization_config: layers_cfg.NormalizationConfig
+  activation_config: layers_cfg.ActivationConfig
+  num_layers: int
+  # Optional time embedding channels if the residual blocks take a time embedding as input
+  time_embedding_channels: Optional[int] = None
+  # Whether to add upsample operation after residual blocks
+  add_upsample: bool = True
+  # Whether to add a conv2d layer after upsample
+  upsample_conv: bool = True
+  # Optional sampling config if add_upsample is True.
+  sampling_config: Optional[UpSamplingConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None
+
+
+@dataclass
+class DownEncoderBlock2DConfig:
+  in_channels: int
+  out_channels: int
+  normalization_config: layers_cfg.NormalizationConfig
+  activation_config: layers_cfg.ActivationConfig
+  num_layers: int
+  # Padding for the downsampling convolution.
+  padding: int = 1
+  # Optional time embedding channels if the residual blocks take a time embedding as input
+  time_embedding_channels: Optional[int] = None
+  # Whether to add downsample operation after residual blocks
+  add_downsample: bool = True
+  # Optional sampling config if add_upsample is True.
+  sampling_config: Optional[DownSamplingConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None
 
 
 @dataclass
@@ -78,6 +167,10 @@ class MidBlock2DConfig:
   time_embedding_channels: Optional[int] = None
   # Optional config of attention blocks interleaved with residual blocks
   attention_block_config: Optional[AttentionBlock2DConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None
 
 
 @dataclass
@@ -115,3 +208,62 @@ class AutoEncoderConfig:
 
   # The configuration of middle blocks, that is, after the last block of encoder and before the first block of decoder.
   mid_block_config: MidBlock2DConfig
+
+
+@dataclass
+class DiffusionModelConfig:
+  """Configurations of Diffusion model."""
+
+  # Number of channels in the input tensor.
+  in_channels: int
+
+  # Number of channels in the output tensor.
+  out_channels: int
+
+  # The output channels of each block.
+  block_out_channels: List[int]
+
+  # The layesr number of each block.
+  layers_per_block: int
+
+  # The padding to use for the downsampling.
+  downsample_padding: int
+
+  # Normalization config used in residual blocks.
+  residual_norm_config: layers_cfg.NormalizationConfig
+
+  # Activation config used in residual blocks
+  residual_activation_type: layers_cfg.ActivationType
+
+  # The batch size used in transformer blocks, for attention layers.
+  transformer_batch_size: int
+
+  # The number of attention heads used in transformer blocks.
+  transformer_num_attention_heads: int
+
+  # The dimension of cross attention used in transformer blocks.
+  transformer_cross_attention_dim: int
+
+  # Normalization config used in prev conv layer of transformer blocks.
+  transformer_pre_conv_norm_config: layers_cfg.NormalizationConfig
+
+  # Normalization config used in transformer blocks.
+  transformer_norm_config: layers_cfg.NormalizationConfig
+
+  # Activation type of feed forward used in transformer blocks.
+  transformer_ff_activation_type: layers_cfg.ActivationType
+
+  # Number of layers in mid block.
+  mid_block_layers: int
+
+  # Dimension of time embedding.
+  time_embedding_dim: int
+
+  # Time embedding dimensions for blocks.
+  time_embedding_blocks_dim: int
+
+  # Normalization config used for final layer
+  final_norm_config: layers_cfg.NormalizationConfig
+
+  # Activation type used in final layer
+  final_activation_type: layers_cfg.ActivationType
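For orientation, a minimal sketch of how the new sampling config dataclasses above could be instantiated; the module alias and every field value below are illustrative, not taken from the package:

```
# Illustrative only: values are made up; fields match the dataclasses in this diff.
import ai_edge_torch.generative.layers.unet.model_config as unet_cfg

down = unet_cfg.DownSamplingConfig(
    mode=unet_cfg.SamplingType.AVERAGE,  # AVERAGE is new in this release
    in_channels=128,
    kernel_size=2,
    stride=2,
    padding=0,
)
up = unet_cfg.UpSamplingConfig(
    mode=unet_cfg.SamplingType.NEAREST,
    scale_factor=2.0,
)
```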
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py

File without changes (new empty file).
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py

@@ -0,0 +1,164 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import json
+
+from ai_edge_quantizer import quantizer
+
+from ai_edge_torch.generative.quantize import quant_attrs
+from ai_edge_torch.generative.quantize import quant_recipe
+
+_OpExecutionMode = quantizer.qtyping.OpExecutionMode
+_OpName = quantizer.qtyping.TFLOperationName
+_TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
+_OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
+
+_DEFAULT_REGEX_STR = '.*'
+_ATTENTION_IDX_REGEX_STR = (
+    'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.attention'
+)
+_FEEDFORWARD_IDX_REGEX_STR = (
+    'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.feed_forward'
+)
+_EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
+_ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'
+
+
+def _get_nbits_from_dtype(dtype: quant_attrs.Dtype) -> int:
+  if dtype == quant_attrs.Dtype.FP32:
+    return 32
+  elif dtype == quant_attrs.Dtype.FP16:
+    return 16
+  elif dtype == quant_attrs.Dtype.INT8:
+    return 8
+  raise ValueError('Unimplemented number of bits')
+
+
+def _get_dtype_from_dtype(dtype: quant_attrs.Dtype) -> quantizer.qtyping.TensorDataType:
+  if dtype == quant_attrs.Dtype.FP32 or dtype == quant_attrs.Dtype.FP16:
+    return quantizer.qtyping.TensorDataType.FLOAT
+  else:
+    return quantizer.qtyping.TensorDataType.INT
+
+
+def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
+  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
+    return _OpExecutionMode.DRQ
+  elif mode == quant_attrs.Mode.WEIGHT_ONLY:
+    return _OpExecutionMode.WEIGHT_ONLY
+  raise ValueError('Unimplemented execution mode')
+
+
+def _get_channelwise_from_granularity(granularity: quant_attrs.Granularity) -> bool:
+  if granularity == quant_attrs.Granularity.CHANNELWISE:
+    return True
+  elif granularity == quant_attrs.Granularity.NONE:
+    return False
+  raise ValueError('Unimplemented granularity')
+
+
+def _get_algorithm_key_from_algorithm(algo: quant_attrs.Algorithm) -> str:
+  if algo == quant_attrs.Algorithm.MIN_MAX:
+    return quantizer.algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT
+  elif algo == quant_attrs.Algorithm.FLOAT_CAST:
+    return quantizer.algorithm_manager.AlgorithmName.FLOAT_CASTING
+  raise ValueError('Unimplemented algorithm')
+
+
+def _set_quant_config(
+    rm: quantizer.recipe_manager.RecipeManager,
+    layer_recipe: quant_recipe.LayerQuantRecipe,
+    regex: str,
+):
+  support_op_list = [_OpName.FULLY_CONNECTED, _OpName.CONV_2D]
+  if layer_recipe.algorithm == quant_attrs.Algorithm.MIN_MAX:
+    support_op_list += [_OpName.BATCH_MATMUL, _OpName.EMBEDDING_LOOKUP]
+  for op_name in support_op_list:
+    rm.add_quantization_config(
+        regex=regex,
+        operation_name=op_name,
+        op_config=_OpQuantConfig(
+            weight_tensor_config=_TensorQuantConfig(
+                num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
+                symmetric=True,
+                channel_wise=_get_channelwise_from_granularity(
+                    layer_recipe.granularity
+                ),
+                dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
+            ),
+            execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
+        ),
+        algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
+        override_algorithm=True,
+    )
+
+
+def translate_to_ai_edge_recipe(
+    recipe: quant_recipe.GenerativeQuantRecipe,
+) -> quantizer.recipe_manager.ModelQuantizationRecipe:
+  rm = quantizer.recipe_manager.RecipeManager()
+
+  if recipe.default is not None:
+    _set_quant_config(rm, recipe.default, _DEFAULT_REGEX_STR)
+
+  if recipe.embedding is not None:
+    _set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)
+
+  if recipe.attention is not None:
+    if isinstance(recipe.attention, dict):
+      for idx, layer in recipe.attention.items():
+        _set_quant_config(rm, layer, _ATTENTION_IDX_REGEX_STR.format(idx))
+    else:
+      _set_quant_config(
+          rm,
+          recipe.attention,
+          _ATTENTION_IDX_REGEX_STR.format(_ANY_TWO_DIGITS_REGEX_STR),
+      )
+
+  if recipe.feedforward is not None:
+    if isinstance(recipe.feedforward, dict):
+      for idx, layer in recipe.feedforward.items():
+        _set_quant_config(rm, layer, _FEEDFORWARD_IDX_REGEX_STR.format(idx))
+    else:
+      _set_quant_config(
+          rm,
+          recipe.feedforward,
+          _FEEDFORWARD_IDX_REGEX_STR.format(_ANY_TWO_DIGITS_REGEX_STR),
+      )
+
+  return rm.get_quantization_recipe()
+
+
+def quantize_model(
+    model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
+) -> bytearray:
+  # TODO(b/336599483): Remove tempfile and use bytearray instead
+  tmp_model_path = '/tmp/tmp.tflite'
+  tmp_recipe_path = '/tmp/recipe.json'
+  with open(tmp_model_path, 'wb') as fp:
+    fp.write(model)
+  with open(tmp_recipe_path, 'w') as rp:
+    rp.write(json.dumps(recipe))
+
+  qt = quantizer.Quantizer(tmp_model_path, tmp_recipe_path)
+  result = qt.quantize()
+
+  # TODO(b/336599483): Remove tempfile and use bytearray instead
+  import os
+
+  os.remove(tmp_model_path)
+  os.remove(tmp_recipe_path)
+
+  return result.quantized_model
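A sketch of how this glue module is meant to be driven, assuming a float .tflite flatbuffer already produced by ai_edge_torch.convert() and assuming QuantConfig exposes the recipe as a .generative_recipe attribute; the file path is hypothetical:

```
from ai_edge_torch.generative.quantize import quant_recipes
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe

quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
aeq_recipe = translate_recipe.translate_to_ai_edge_recipe(
    quant_config.generative_recipe
)

with open('/tmp/float_model.tflite', 'rb') as f:  # hypothetical input path
  float_model = bytearray(f.read())
quantized_model = translate_recipe.quantize_model(float_model, aeq_recipe)
```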
ai_edge_torch/generative/quantize/quant_attrs.py

@@ -32,9 +32,11 @@ class Algorithm(enum.Enum):
   Attributes:
     MIN_MAX: Maps the min/max of floating point space to the min/max of
       quantized space and quantize uniformly.
+    FLOAT_CAST: Casts a float to another float of a different type.
   """
 
   MIN_MAX = enum.auto()
+  FLOAT_CAST = enum.auto()
 
 
 @enum.unique
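For intuition, a rough sketch of what the two algorithms do to a weight tensor; this is illustrative arithmetic based on the docstrings above, not the quantizer's actual code:

```
import numpy as np

w = np.array([0.02, -1.5, 3.7], dtype=np.float32)

# MIN_MAX: map the observed float range onto the integer grid uniformly
# (symmetric int8 shown here).
scale = np.abs(w).max() / 127.0
w_int8 = np.clip(np.round(w / scale), -127, 127).astype(np.int8)

# FLOAT_CAST: no integer grid at all; simply cast to the narrower float type.
w_fp16 = w.astype(np.float16)
```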
ai_edge_torch/generative/quantize/quant_recipe.py

@@ -14,8 +14,7 @@
 # ==============================================================================
 
 from dataclasses import dataclass
-import
-from typing import Optional
+from typing import Optional, Union
 
 from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import supported_schemes
@@ -80,18 +79,50 @@ class LayerQuantRecipe:
 
 
 @dataclass
-class TransformerQuantRecipe:
+class GenerativeQuantRecipe:
   """Quantization recipe for a model composed of the Edge Generative API layers.
 
+  Some layers can be specified with different `LayerQuantRecipe` for each block by
+  providing a dictionary keyed by the TransformerBlock index, e.g. attention
+  and feedforward. For example,
+
+  ```
+  default = LayerQuantRecipeA
+  attention = { 2: LayerQuantRecipeB }
+  feedforward = { 3: LayerQuantRecipeC }
+  ```
+
+  will apply LayerQuantRecipeA to the entire model, overriden by
+  LayerQuantRecipeB for the TransformerBlock[2].attention layer and
+  LayerQuantRecipeC for the TransformerBlock[3].feedforward layer. Any config
+  with invalid indices will be ignored.
+
   Attributes:
     default: The quantization recipe for global scope of the model.
+    embedding: Recipe for the embedding table.
+    attention: Recipe for the attention blocks. This could be specified with
+      different LayerQuantRecipe for each block by providing a dictionary
+      keyed by the TransformerBlock index.
+    feedforward: Recipe for the feedforward layers. This could be specified with
+      different LayerQuantRecipe for each block by providing a dictionary
+      keyed by the TransformerBlock index.
   """
 
   default: Optional[LayerQuantRecipe] = None
+  embedding: Optional[LayerQuantRecipe] = None
+  attention: Union[
+      Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
+  ] = None
+  feedforward: Union[
+      Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
+  ] = None
 
   def __str__(self):
-    return f"""
+    return f"""GenerativeQuantRecipe(
   Default: {self.default}
+  Embedding: {self.embedding}
+  Attention: {self.attention}
+  Feedforward: {self.feedforward}
 )"""
 
   __repr__ = __str__
@@ -104,3 +135,17 @@ class TransformerQuantRecipe:
     """
     if self.default is not None:
       self.default.verify()
+    if self.embedding is not None:
+      self.embedding.verify()
+    if self.attention is not None:
+      if isinstance(self.attention, dict):
+        for recipe in self.attention.values():
+          recipe.verify()
+      else:
+        self.attention.verify()
+    if self.feedforward is not None:
+      if isinstance(self.feedforward, dict):
+        for recipe in self.feedforward.values():
+          recipe.verify()
+      else:
+        self.feedforward.verify()
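Putting the renamed class and the new per-block dict form together, a sketch of a mixed recipe; the particular layer choices are illustrative:

```
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize import quant_recipe_utils

recipe = quant_recipe.GenerativeQuantRecipe(
    default=quant_recipe_utils.create_layer_quant_fp16(),
    embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
    # Override only TransformerBlock[1]'s attention and TransformerBlock[0]'s
    # feedforward; every other block falls back to `default`.
    attention={1: quant_recipe_utils.create_layer_quant_int8_dynamic()},
    feedforward={0: quant_recipe_utils.create_layer_quant_int8_dynamic()},
)
recipe.verify()  # raises ValueError on unsupported combinations
```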
ai_edge_torch/generative/quantize/quant_recipe_utils.py

@@ -22,7 +22,7 @@ Typical usage example:
 
   1. Applying a single layer recipe to the entire model
 
-    quant_recipe.TransformerQuantRecipe(
+    quant_recipe.GenerativeQuantRecipe(
         default=quant_recipe_utils.create_layer_quant_int8_dynamic()
     )
 """
@@ -46,6 +46,6 @@ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
       activation_dtype=quant_attrs.Dtype.FP32,
       weight_dtype=quant_attrs.Dtype.FP16,
       mode=quant_attrs.Mode.WEIGHT_ONLY,
-      algorithm=quant_attrs.Algorithm.MIN_MAX,
+      algorithm=quant_attrs.Algorithm.FLOAT_CAST,
       granularity=quant_attrs.Granularity.NONE,
   )
ai_edge_torch/generative/quantize/quant_recipes.py

@@ -34,15 +34,15 @@ from ai_edge_torch.quantize import quant_config
 
 def full_linear_int8_dynamic_recipe() -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
-      transformer_recipe=quant_recipe.TransformerQuantRecipe(
-          default=quant_recipe_utils.create_layer_quant_int8_dynamic()
+      generative_recipe=quant_recipe.GenerativeQuantRecipe(
+          default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
       )
   )
 
 
 def full_fp16_recipe() -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
-      transformer_recipe=quant_recipe.TransformerQuantRecipe(
+      generative_recipe=quant_recipe.GenerativeQuantRecipe(
           default=quant_recipe_utils.create_layer_quant_fp16()
       )
   )
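These presets then plug into conversion through QuantConfig; a sketch using the toy example model from the package's test_models (output path is hypothetical):

```
import torch
import ai_edge_torch
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
from ai_edge_torch.generative.quantize import quant_recipes

config = toy_model_with_kv_cache.get_model_config()
model = toy_model_with_kv_cache.ToyModelWithKV(config)
idx = torch.tensor([[1]], dtype=torch.long)
input_pos = torch.tensor([10], dtype=torch.int64)

edge_model = ai_edge_torch.convert(
    model, (idx, input_pos), quant_config=quant_recipes.full_fp16_recipe()
)
edge_model.export('/tmp/toy_fp16.tflite')  # hypothetical output path
```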
ai_edge_torch/generative/quantize/supported_schemes.py

@@ -27,5 +27,6 @@ def get_supported_layer_schemes():
 
   return [
       (_t.FP32, _t.INT8, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.CHANNELWISE),
-      (_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.NONE),
+      (_t.FP32, _t.INT8, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.CHANNELWISE),
+      (_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.FLOAT_CAST, _g.NONE),
   ]
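Each tuple reads as (activation_dtype, weight_dtype, mode, algorithm, granularity). A sketch of a LayerQuantRecipe matching the newly added fp16 float-cast scheme:

```
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize.quant_attrs import (
    Algorithm,
    Dtype,
    Granularity,
    Mode,
)

fp16_layer = quant_recipe.LayerQuantRecipe(
    activation_dtype=Dtype.FP32,
    weight_dtype=Dtype.FP16,
    mode=Mode.WEIGHT_ONLY,
    algorithm=Algorithm.FLOAT_CAST,
    granularity=Granularity.NONE,
)
fp16_layer.verify()  # succeeds: the tuple is in the supported list above
```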
ai_edge_torch/generative/test/test_model_conversion.py

@@ -55,6 +55,30 @@ class TestModelConversion(unittest.TestCase):
         )
     )
 
+  def test_toy_model_with_multi_batches(self):
+    config = toy_model_with_kv_cache.get_model_config()
+    config.batch_size = 2
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
+    idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
+        [10], dtype=torch.int64
+    )
+
+    edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (idx, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
+
   def test_toy_model_with_kv_cache_with_hlfb(self):
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = True
ai_edge_torch/generative/test/test_quantize.py

@@ -21,11 +21,13 @@ import torch
 import ai_edge_torch
 from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
 from ai_edge_torch.generative.quantize import quant_recipe
+from ai_edge_torch.generative.quantize import quant_recipe_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.quantize.quant_attrs import Algorithm
 from ai_edge_torch.generative.quantize.quant_attrs import Dtype
 from ai_edge_torch.generative.quantize.quant_attrs import Granularity
 from ai_edge_torch.generative.quantize.quant_attrs import Mode
+from ai_edge_torch.quantize import quant_config
 from ai_edge_torch.testing import model_coverage
 
 
@@ -34,34 +36,47 @@ class TestVerifyRecipes(unittest.TestCase):
 
   @parameterized.expand(
       [
-          (Dtype.FP32, Dtype.FP32, Mode.DYNAMIC_RANGE),
-          (Dtype.INT8, Dtype.INT8, Mode.DYNAMIC_RANGE),
-          (Dtype.INT8, Dtype.FP16, Mode.DYNAMIC_RANGE),
-          (Dtype.FP16, Dtype.INT8, Mode.DYNAMIC_RANGE),
-          (Dtype.FP16, Dtype.FP16, Mode.DYNAMIC_RANGE),
-          (Dtype.INT8, Dtype.INT8, Mode.WEIGHT_ONLY),
-          (Dtype.FP16, Dtype.INT8, Mode.WEIGHT_ONLY),
-          (Dtype.INT8, Dtype.FP16, Mode.WEIGHT_ONLY),
-          (Dtype.FP16, Dtype.FP16, Mode.WEIGHT_ONLY),
+          (Dtype.FP32, Dtype.FP32),
+          (Dtype.INT8, Dtype.INT8),
+          (Dtype.INT8, Dtype.FP16),
+          (Dtype.FP16, Dtype.INT8),
+          (Dtype.FP16, Dtype.FP16),
       ]
   )
   def test_verify_invalid_recipes(
       self,
       activation,
       weight,
-      mode,
-      algo=Algorithm.MIN_MAX,
-      granularity=Granularity.CHANNELWISE,
   ):
-    with self.assertRaises(ValueError):
-      quant_recipe.LayerQuantRecipe(
-          activation, weight, mode, algo, granularity
-      ).verify()
+    for m in Mode:
+      for a in Algorithm:
+        for g in Granularity:
+          with self.assertRaises(ValueError):
+            quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
 
   @parameterized.expand(
       [
-          (Dtype.FP32, Dtype.INT8, Mode.DYNAMIC_RANGE, Granularity.CHANNELWISE),
-          (Dtype.FP32, Dtype.FP16, Mode.WEIGHT_ONLY, Granularity.NONE),
+          (
+              Dtype.FP32,
+              Dtype.INT8,
+              Mode.DYNAMIC_RANGE,
+              Algorithm.MIN_MAX,
+              Granularity.CHANNELWISE,
+          ),
+          (
+              Dtype.FP32,
+              Dtype.INT8,
+              Mode.WEIGHT_ONLY,
+              Algorithm.MIN_MAX,
+              Granularity.CHANNELWISE,
+          ),
+          (
+              Dtype.FP32,
+              Dtype.FP16,
+              Mode.WEIGHT_ONLY,
+              Algorithm.FLOAT_CAST,
+              Granularity.NONE,
+          ),
       ]
   )
   def test_verify_valid_recipes(
@@ -69,8 +84,8 @@ class TestVerifyRecipes(unittest.TestCase):
       activation,
       weight,
       mode,
+      algo,
       granularity,
-      algo=Algorithm.MIN_MAX,
   ):
     quant_recipe.LayerQuantRecipe(activation, weight, mode, algo, granularity).verify()
 
@@ -78,7 +93,47 @@ class TestVerifyRecipes(unittest.TestCase):
 class TestQuantizeConvert(unittest.TestCase):
   """Test conversion with quantization."""
 
-  def test_quantize_convert_toy(self):
+  def _attention_1_int8_dynamic_recipe() -> quant_config.QuantConfig:
+    return quant_config.QuantConfig(
+        generative_recipe=quant_recipe.GenerativeQuantRecipe(
+            attention={1: quant_recipe_utils.create_layer_quant_int8_dynamic()},
+        )
+    )
+
+  def _feedforward_0_int8_dynamic_recipe() -> quant_config.QuantConfig:
+    return quant_config.QuantConfig(
+        generative_recipe=quant_recipe.GenerativeQuantRecipe(
+            feedforward={0: quant_recipe_utils.create_layer_quant_int8_dynamic()},
+        )
+    )
+
+  @parameterized.expand(
+      [
+          (quant_recipes.full_fp16_recipe(), 0.75),
+          (quant_recipes.full_linear_int8_dynamic_recipe(), 0.64),
+          (_attention_1_int8_dynamic_recipe(), 0.95),
+          (_feedforward_0_int8_dynamic_recipe(), 0.87),
+      ]
+  )
+  def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
+    self.skipTest("b/346896669")
+    config = toy_model_with_kv_cache.get_model_config()
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
+    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+        [10], dtype=torch.int64
+    )
+
+    quantized_model = ai_edge_torch.convert(
+        pytorch_model, (idx, input_pos), quant_config=quant_config
+    )
+    float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    self.assertAlmostEqual(
+        len(quantized_model._tflite_model) / len(float_model._tflite_model),
+        expected_compression,
+        delta=0.01,
+    )
+
+  def test_quantize_convert_compare_toy(self):
     self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)