ai-edge-torch-nightly 0.2.0.dev20240610__py3-none-any.whl → 0.2.0.dev20240617__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (30)
  1. ai_edge_torch/convert/conversion_utils.py +17 -5
  2. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
  3. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
  4. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
  5. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
  6. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
  7. ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
  8. ai_edge_torch/generative/layers/attention.py +154 -26
  9. ai_edge_torch/generative/layers/model_config.py +4 -0
  10. ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
  11. ai_edge_torch/generative/layers/unet/builder.py +20 -2
  12. ai_edge_torch/generative/layers/unet/model_config.py +157 -5
  13. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  14. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +164 -0
  15. ai_edge_torch/generative/quantize/quant_attrs.py +2 -0
  16. ai_edge_torch/generative/quantize/quant_recipe.py +49 -4
  17. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -2
  18. ai_edge_torch/generative/quantize/quant_recipes.py +3 -3
  19. ai_edge_torch/generative/quantize/supported_schemes.py +2 -1
  20. ai_edge_torch/generative/test/test_model_conversion.py +24 -0
  21. ai_edge_torch/generative/test/test_quantize.py +75 -20
  22. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
  23. ai_edge_torch/generative/utilities/t5_loader.py +33 -17
  24. ai_edge_torch/quantize/quant_config.py +11 -15
  25. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/METADATA +1 -1
  26. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/RECORD +29 -27
  27. ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
  28. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/LICENSE +0 -0
  29. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/WHEEL +0 -0
  30. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/unet/model_config.py
@@ -22,16 +22,28 @@ from typing import List, Optional
 import ai_edge_torch.generative.layers.model_config as layers_cfg


-@dataclass
+@enum.unique
 class SamplingType(enum.Enum):
   NEAREST = enum.auto()
   BILINEAR = enum.auto()
+  AVERAGE = enum.auto()
+  CONVOLUTION = enum.auto()


 @dataclass
-class SamplingConfig:
+class UpSamplingConfig:
+  mode: SamplingType
   scale_factor: float
+
+
+@dataclass
+class DownSamplingConfig:
   mode: SamplingType
+  in_channels: int
+  kernel_size: int
+  stride: int
+  padding: int
+  out_channels: Optional[int] = None


 @dataclass
@@ -46,9 +58,38 @@ class ResidualBlock2DConfig:

 @dataclass
 class AttentionBlock2DConfig:
-  dims: int
+  dim: int
+  normalization_config: layers_cfg.NormalizationConfig
+  attention_config: layers_cfg.AttentionConfig
+  enable_hlfb: bool = True
+  attention_batch_size: int = 1
+
+
+@dataclass
+class CrossAttentionBlock2DConfig:
+  query_dim: int
+  cross_dim: int
   normalization_config: layers_cfg.NormalizationConfig
   attention_config: layers_cfg.AttentionConfig
+  enable_hlfb: bool = True
+  attention_batch_size: int = 1
+
+
+@dataclass
+class FeedForwardBlock2DConfig:
+  dim: int
+  hidden_dim: int
+  normalization_config: layers_cfg.NormalizationConfig
+  activation_config: layers_cfg.ActivationConfig
+  use_bias: bool
+
+
+@dataclass
+class TransformerBlock2Dconfig:
+  pre_conv_normalization_config: layers_cfg.NormalizationConfig
+  attention_block_config: AttentionBlock2DConfig
+  cross_attention_block_config: CrossAttentionBlock2DConfig
+  feed_forward_block_config: FeedForwardBlock2DConfig


 @dataclass
@@ -58,14 +99,62 @@ class UpDecoderBlock2DConfig:
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
   num_layers: int
-  # Optional time embedding channels if the residual blocks take a time embedding context as input
+  # Optional time embedding channels if the residual blocks take a time embedding as input
   time_embedding_channels: Optional[int] = None
   # Whether to add upsample operation after residual blocks
   add_upsample: bool = True
   # Whether to add a conv2d layer after upsample
   upsample_conv: bool = True
   # Optional sampling config if add_upsample is True.
-  sampling_config: Optional[SamplingConfig] = None
+  sampling_config: Optional[UpSamplingConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None
+
+
+@dataclass
+class SkipUpDecoderBlock2DConfig:
+  in_channels: int
+  out_channels: int
+  # The dimension of output channels of previous connected block
+  prev_out_channels: int
+  normalization_config: layers_cfg.NormalizationConfig
+  activation_config: layers_cfg.ActivationConfig
+  num_layers: int
+  # Optional time embedding channels if the residual blocks take a time embedding as input
+  time_embedding_channels: Optional[int] = None
+  # Whether to add upsample operation after residual blocks
+  add_upsample: bool = True
+  # Whether to add a conv2d layer after upsample
+  upsample_conv: bool = True
+  # Optional sampling config if add_upsample is True.
+  sampling_config: Optional[UpSamplingConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None
+
+
+@dataclass
+class DownEncoderBlock2DConfig:
+  in_channels: int
+  out_channels: int
+  normalization_config: layers_cfg.NormalizationConfig
+  activation_config: layers_cfg.ActivationConfig
+  num_layers: int
+  # Padding for the downsampling convolution.
+  padding: int = 1
+  # Optional time embedding channels if the residual blocks take a time embedding as input
+  time_embedding_channels: Optional[int] = None
+  # Whether to add downsample operation after residual blocks
+  add_downsample: bool = True
+  # Optional sampling config if add_downsample is True.
+  sampling_config: Optional[DownSamplingConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None


 @dataclass
@@ -78,6 +167,10 @@ class MidBlock2DConfig:
   time_embedding_channels: Optional[int] = None
   # Optional config of attention blocks interleaved with residual blocks
   attention_block_config: Optional[AttentionBlock2DConfig] = None
+  # Optional config of transformer blocks interleaved with residual blocks
+  transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+  # Optional dimension of context tensor if context tensor is given as input.
+  context_dim: Optional[int] = None


 @dataclass
@@ -115,3 +208,62 @@ class AutoEncoderConfig:

   # The configuration of middle blocks, that is, after the last block of encoder and before the first block of decoder.
   mid_block_config: MidBlock2DConfig
+
+
+@dataclass
+class DiffusionModelConfig:
+  """Configurations of Diffusion model."""
+
+  # Number of channels in the input tensor.
+  in_channels: int
+
+  # Number of channels in the output tensor.
+  out_channels: int
+
+  # The output channels of each block.
+  block_out_channels: List[int]
+
+  # The number of layers in each block.
+  layers_per_block: int
+
+  # The padding to use for the downsampling.
+  downsample_padding: int
+
+  # Normalization config used in residual blocks.
+  residual_norm_config: layers_cfg.NormalizationConfig
+
+  # Activation config used in residual blocks.
+  residual_activation_type: layers_cfg.ActivationType
+
+  # The batch size used in transformer blocks, for attention layers.
+  transformer_batch_size: int
+
+  # The number of attention heads used in transformer blocks.
+  transformer_num_attention_heads: int
+
+  # The dimension of cross attention used in transformer blocks.
+  transformer_cross_attention_dim: int
+
+  # Normalization config used in prev conv layer of transformer blocks.
+  transformer_pre_conv_norm_config: layers_cfg.NormalizationConfig
+
+  # Normalization config used in transformer blocks.
+  transformer_norm_config: layers_cfg.NormalizationConfig
+
+  # Activation type of feed forward used in transformer blocks.
+  transformer_ff_activation_type: layers_cfg.ActivationType
+
+  # Number of layers in mid block.
+  mid_block_layers: int
+
+  # Dimension of time embedding.
+  time_embedding_dim: int
+
+  # Time embedding dimensions for blocks.
+  time_embedding_blocks_dim: int
+
+  # Normalization config used for final layer.
+  final_norm_config: layers_cfg.NormalizationConfig
+
+  # Activation type used in final layer.
+  final_activation_type: layers_cfg.ActivationType
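
The sampling configs introduced above are plain dataclasses. A minimal sketch of how they would be instantiated; the channel and kernel values here are illustrative assumptions, not taken from the package:

```python
from ai_edge_torch.generative.layers.unet import model_config as unet_cfg

# 2x bilinear upsampling, as consumed by UpDecoderBlock2DConfig.sampling_config.
up = unet_cfg.UpSamplingConfig(
    mode=unet_cfg.SamplingType.BILINEAR,
    scale_factor=2.0,
)

# Strided downsampling for DownEncoderBlock2DConfig.sampling_config; the
# numbers below are hypothetical placeholders.
down = unet_cfg.DownSamplingConfig(
    mode=unet_cfg.SamplingType.AVERAGE,
    in_channels=128,
    kernel_size=2,
    stride=2,
    padding=0,
)
```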
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py (new file)
@@ -0,0 +1,164 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import json
+
+from ai_edge_quantizer import quantizer
+
+from ai_edge_torch.generative.quantize import quant_attrs
+from ai_edge_torch.generative.quantize import quant_recipe
+
+_OpExecutionMode = quantizer.qtyping.OpExecutionMode
+_OpName = quantizer.qtyping.TFLOperationName
+_TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
+_OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
+
+_DEFAULT_REGEX_STR = '.*'
+_ATTENTION_IDX_REGEX_STR = (
+    'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.attention'
+)
+_FEEDFORWARD_IDX_REGEX_STR = (
+    'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.feed_forward'
+)
+_EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
+_ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'
+
+
+def _get_nbits_from_dtype(dtype: quant_attrs.Dtype) -> int:
+  if dtype == quant_attrs.Dtype.FP32:
+    return 32
+  elif dtype == quant_attrs.Dtype.FP16:
+    return 16
+  elif dtype == quant_attrs.Dtype.INT8:
+    return 8
+  raise ValueError('Unimplemented number of bits')
+
+
+def _get_dtype_from_dtype(dtype: quant_attrs.Dtype) -> quantizer.qtyping.TensorDataType:
+  if dtype == quant_attrs.Dtype.FP32 or dtype == quant_attrs.Dtype.FP16:
+    return quantizer.qtyping.TensorDataType.FLOAT
+  else:
+    return quantizer.qtyping.TensorDataType.INT
+
+
+def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
+  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
+    return _OpExecutionMode.DRQ
+  elif mode == quant_attrs.Mode.WEIGHT_ONLY:
+    return _OpExecutionMode.WEIGHT_ONLY
+  raise ValueError('Unimplemented execution mode')
+
+
+def _get_channelwise_from_granularity(granularity: quant_attrs.Granularity) -> bool:
+  if granularity == quant_attrs.Granularity.CHANNELWISE:
+    return True
+  elif granularity == quant_attrs.Granularity.NONE:
+    return False
+  raise ValueError('Unimplemented granularity')
+
+
+def _get_algorithm_key_from_algorithm(algo: quant_attrs.Algorithm) -> str:
+  if algo == quant_attrs.Algorithm.MIN_MAX:
+    return quantizer.algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT
+  elif algo == quant_attrs.Algorithm.FLOAT_CAST:
+    return quantizer.algorithm_manager.AlgorithmName.FLOAT_CASTING
+  raise ValueError('Unimplemented algorithm')
+
+
+def _set_quant_config(
+    rm: quantizer.recipe_manager.RecipeManager,
+    layer_recipe: quant_recipe.LayerQuantRecipe,
+    regex: str,
+):
+  support_op_list = [_OpName.FULLY_CONNECTED, _OpName.CONV_2D]
+  if layer_recipe.algorithm == quant_attrs.Algorithm.MIN_MAX:
+    support_op_list += [_OpName.BATCH_MATMUL, _OpName.EMBEDDING_LOOKUP]
+  for op_name in support_op_list:
+    rm.add_quantization_config(
+        regex=regex,
+        operation_name=op_name,
+        op_config=_OpQuantConfig(
+            weight_tensor_config=_TensorQuantConfig(
+                num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
+                symmetric=True,
+                channel_wise=_get_channelwise_from_granularity(
+                    layer_recipe.granularity
+                ),
+                dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
+            ),
+            execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
+        ),
+        algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
+        override_algorithm=True,
+    )
+
+
+def translate_to_ai_edge_recipe(
+    recipe: quant_recipe.GenerativeQuantRecipe,
+) -> quantizer.recipe_manager.ModelQuantizationRecipe:
+  rm = quantizer.recipe_manager.RecipeManager()
+
+  if recipe.default is not None:
+    _set_quant_config(rm, recipe.default, _DEFAULT_REGEX_STR)
+
+  if recipe.embedding is not None:
+    _set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)
+
+  if recipe.attention is not None:
+    if isinstance(recipe.attention, dict):
+      for idx, layer in recipe.attention.items():
+        _set_quant_config(rm, layer, _ATTENTION_IDX_REGEX_STR.format(idx))
+    else:
+      _set_quant_config(
+          rm,
+          recipe.attention,
+          _ATTENTION_IDX_REGEX_STR.format(_ANY_TWO_DIGITS_REGEX_STR),
+      )
+
+  if recipe.feedforward is not None:
+    if isinstance(recipe.feedforward, dict):
+      for idx, layer in recipe.feedforward.items():
+        _set_quant_config(rm, layer, _FEEDFORWARD_IDX_REGEX_STR.format(idx))
+    else:
+      _set_quant_config(
+          rm,
+          recipe.feedforward,
+          _FEEDFORWARD_IDX_REGEX_STR.format(_ANY_TWO_DIGITS_REGEX_STR),
+      )
+
+  return rm.get_quantization_recipe()
+
+
+def quantize_model(
+    model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
+) -> bytearray:
+  # TODO(b/336599483): Remove tempfile and use bytearray instead
+  tmp_model_path = '/tmp/tmp.tflite'
+  tmp_recipe_path = '/tmp/recipe.json'
+  with open(tmp_model_path, 'wb') as fp:
+    fp.write(model)
+  with open(tmp_recipe_path, 'w') as rp:
+    rp.write(json.dumps(recipe))
+
+  qt = quantizer.Quantizer(tmp_model_path, tmp_recipe_path)
+  result = qt.quantize()
+
+  # TODO(b/336599483): Remove tempfile and use bytearray instead
+  import os
+
+  os.remove(tmp_model_path)
+  os.remove(tmp_recipe_path)
+
+  return result.quantized_model
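
Taken together, the new glue module exposes a two-step flow: translate a GenerativeQuantRecipe into an ai-edge-quantizer recipe, then apply it to a serialized TFLite model. A hedged sketch, assuming `tflite_model` already holds a float flatbuffer as bytes:

```python
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize import quant_recipe_utils
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe

# Build a recipe with the Edge Generative API types...
recipe = quant_recipe.GenerativeQuantRecipe(
    default=quant_recipe_utils.create_layer_quant_fp16()
)
# ...translate it into the ai-edge-quantizer representation...
aeq_recipe = translate_recipe.translate_to_ai_edge_recipe(recipe)
# ...and quantize the flatbuffer (`tflite_model` is an assumed input here).
quantized = translate_recipe.quantize_model(tflite_model, aeq_recipe)
```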
ai_edge_torch/generative/quantize/quant_attrs.py
@@ -32,9 +32,11 @@ class Algorithm(enum.Enum):
   Attributes:
     MIN_MAX: Maps the min/max of floating point space to the min/max of
       quantized space and quantize uniformly.
+    FLOAT_CAST: Casts a float to another float of a different type.
   """

   MIN_MAX = enum.auto()
+  FLOAT_CAST = enum.auto()


 @enum.unique
ai_edge_torch/generative/quantize/quant_recipe.py
@@ -14,8 +14,7 @@
 # ==============================================================================

 from dataclasses import dataclass
-import enum
-from typing import Optional
+from typing import Optional, Union

 from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import supported_schemes
@@ -80,18 +79,50 @@ class LayerQuantRecipe:


 @dataclass
-class TransformerQuantRecipe:
+class GenerativeQuantRecipe:
   """Quantization recipe for a model composed of the Edge Generative API layers.

+  Some layers can be specified with different `LayerQuantRecipe` for each block by
+  providing a dictionary keyed by the TransformerBlock index, e.g. attention
+  and feedforward. For example,
+
+  ```
+  default = LayerQuantRecipeA
+  attention = { 2: LayerQuantRecipeB }
+  feedforward = { 3: LayerQuantRecipeC }
+  ```
+
+  will apply LayerQuantRecipeA to the entire model, overridden by
+  LayerQuantRecipeB for the TransformerBlock[2].attention layer and
+  LayerQuantRecipeC for the TransformerBlock[3].feedforward layer. Any config
+  with invalid indices will be ignored.
+
   Attributes:
     default: The quantization recipe for global scope of the model.
+    embedding: Recipe for the embedding table.
+    attention: Recipe for the attention blocks. This could be specified with
+      different LayerQuantRecipe for each block by providing a dictionary
+      keyed by the TransformerBlock index.
+    feedforward: Recipe for the feedforward layers. This could be specified with
+      different LayerQuantRecipe for each block by providing a dictionary
+      keyed by the TransformerBlock index.
   """

   default: Optional[LayerQuantRecipe] = None
+  embedding: Optional[LayerQuantRecipe] = None
+  attention: Union[
+      Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
+  ] = None
+  feedforward: Union[
+      Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
+  ] = None

   def __str__(self):
-    return f"""TransformerQuantRecipe(
+    return f"""GenerativeQuantRecipe(
  Default: {self.default}
+  Embedding: {self.embedding}
+  Attention: {self.attention}
+  Feedforward: {self.feedforward}
)"""

   __repr__ = __str__
@@ -104,3 +135,17 @@ class TransformerQuantRecipe:
     """
     if self.default is not None:
       self.default.verify()
+    if self.embedding is not None:
+      self.embedding.verify()
+    if self.attention is not None:
+      if isinstance(self.attention, dict):
+        for recipe in self.attention.values():
+          recipe.verify()
+      else:
+        self.attention.verify()
+    if self.feedforward is not None:
+      if isinstance(self.feedforward, dict):
+        for recipe in self.feedforward.values():
+          recipe.verify()
+      else:
+        self.feedforward.verify()
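
A minimal sketch of the per-block override the docstring describes, with the block index and recipes chosen purely for illustration:

```python
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize import quant_recipe_utils

recipe = quant_recipe.GenerativeQuantRecipe(
    # fp16 weight-only everywhere by default...
    default=quant_recipe_utils.create_layer_quant_fp16(),
    # ...except TransformerBlock[1]'s attention, which gets int8 dynamic-range.
    attention={1: quant_recipe_utils.create_layer_quant_int8_dynamic()},
)
# verify() now also checks embedding/attention/feedforward, including
# dict-keyed per-block entries, per the diff above.
recipe.verify()
```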
ai_edge_torch/generative/quantize/quant_recipe_utils.py
@@ -22,7 +22,7 @@ Typical usage example:

 1. Applying a single layer recipe to the entire model

-  quant_recipe.TransformerQuantRecipe(
+  quant_recipe.GenerativeQuantRecipe(
       default=quant_recipe_utils.create_layer_quant_int8_dynamic()
   )
 """
@@ -46,6 +46,6 @@ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
       activation_dtype=quant_attrs.Dtype.FP32,
       weight_dtype=quant_attrs.Dtype.FP16,
       mode=quant_attrs.Mode.WEIGHT_ONLY,
-      algorithm=quant_attrs.Algorithm.MIN_MAX,
+      algorithm=quant_attrs.Algorithm.FLOAT_CAST,
       granularity=quant_attrs.Granularity.NONE,
   )
ai_edge_torch/generative/quantize/quant_recipes.py
@@ -34,15 +34,15 @@ from ai_edge_torch.quantize import quant_config

 def full_linear_int8_dynamic_recipe() -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
-      transformer_recipe=quant_recipe.TransformerQuantRecipe(
-          default=quant_recipe_utils.create_layer_quant_int8_dynamic()
+      generative_recipe=quant_recipe.GenerativeQuantRecipe(
+          default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
       )
   )


 def full_fp16_recipe() -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
-      transformer_recipe=quant_recipe.TransformerQuantRecipe(
+      generative_recipe=quant_recipe.GenerativeQuantRecipe(
           default=quant_recipe_utils.create_layer_quant_fp16()
       )
   )
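
These renamed helpers plug into conversion through QuantConfig's `generative_recipe` field, mirroring the tests further down. In this sketch, `pytorch_model`, `idx`, and `input_pos` stand in for any authored Generative API model and its sample inputs:

```python
import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes

edge_model = ai_edge_torch.convert(
    pytorch_model,       # an authored Generative API model (assumed)
    (idx, input_pos),    # sample inputs matching the model signature
    quant_config=quant_recipes.full_fp16_recipe(),
)
```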
ai_edge_torch/generative/quantize/supported_schemes.py
@@ -27,5 +27,6 @@ def get_supported_layer_schemes():

   return [
       (_t.FP32, _t.INT8, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.CHANNELWISE),
-      (_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.NONE),
+      (_t.FP32, _t.INT8, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.CHANNELWISE),
+      (_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.FLOAT_CAST, _g.NONE),
   ]
ai_edge_torch/generative/test/test_model_conversion.py
@@ -55,6 +55,30 @@ class TestModelConversion(unittest.TestCase):
         )
     )

+  def test_toy_model_with_multi_batches(self):
+    config = toy_model_with_kv_cache.get_model_config()
+    config.batch_size = 2
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
+    idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
+        [10], dtype=torch.int64
+    )
+
+    edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (idx, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
+
   def test_toy_model_with_kv_cache_with_hlfb(self):
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = True
ai_edge_torch/generative/test/test_quantize.py
@@ -21,11 +21,13 @@ import torch
 import ai_edge_torch
 from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
 from ai_edge_torch.generative.quantize import quant_recipe
+from ai_edge_torch.generative.quantize import quant_recipe_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.quantize.quant_attrs import Algorithm
 from ai_edge_torch.generative.quantize.quant_attrs import Dtype
 from ai_edge_torch.generative.quantize.quant_attrs import Granularity
 from ai_edge_torch.generative.quantize.quant_attrs import Mode
+from ai_edge_torch.quantize import quant_config
 from ai_edge_torch.testing import model_coverage


@@ -34,34 +36,47 @@ class TestVerifyRecipes(unittest.TestCase):

   @parameterized.expand(
       [
-          (Dtype.FP32, Dtype.FP32, Mode.DYNAMIC_RANGE),
-          (Dtype.INT8, Dtype.INT8, Mode.DYNAMIC_RANGE),
-          (Dtype.INT8, Dtype.FP16, Mode.DYNAMIC_RANGE),
-          (Dtype.FP16, Dtype.INT8, Mode.DYNAMIC_RANGE),
-          (Dtype.FP32, Dtype.FP32, Mode.WEIGHT_ONLY),
-          (Dtype.INT8, Dtype.INT8, Mode.WEIGHT_ONLY),
-          (Dtype.FP16, Dtype.INT8, Mode.WEIGHT_ONLY),
-          (Dtype.INT8, Dtype.FP16, Mode.WEIGHT_ONLY),
-          (Dtype.FP16, Dtype.FP16, Mode.WEIGHT_ONLY),
+          (Dtype.FP32, Dtype.FP32),
+          (Dtype.INT8, Dtype.INT8),
+          (Dtype.INT8, Dtype.FP16),
+          (Dtype.FP16, Dtype.INT8),
+          (Dtype.FP16, Dtype.FP16),
       ]
   )
   def test_verify_invalid_recipes(
       self,
       activation,
       weight,
-      mode,
-      algo=Algorithm.MIN_MAX,
-      granularity=Granularity.CHANNELWISE,
   ):
-    with self.assertRaises(ValueError):
-      quant_recipe.LayerQuantRecipe(
-          activation, weight, mode, algo, granularity
-      ).verify()
+    for m in Mode:
+      for a in Algorithm:
+        for g in Granularity:
+          with self.assertRaises(ValueError):
+            quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()

   @parameterized.expand(
       [
-          (Dtype.FP32, Dtype.INT8, Mode.DYNAMIC_RANGE, Granularity.CHANNELWISE),
-          (Dtype.FP32, Dtype.FP16, Mode.WEIGHT_ONLY, Granularity.NONE),
+          (
+              Dtype.FP32,
+              Dtype.INT8,
+              Mode.DYNAMIC_RANGE,
+              Algorithm.MIN_MAX,
+              Granularity.CHANNELWISE,
+          ),
+          (
+              Dtype.FP32,
+              Dtype.INT8,
+              Mode.WEIGHT_ONLY,
+              Algorithm.MIN_MAX,
+              Granularity.CHANNELWISE,
+          ),
+          (
+              Dtype.FP32,
+              Dtype.FP16,
+              Mode.WEIGHT_ONLY,
+              Algorithm.FLOAT_CAST,
+              Granularity.NONE,
+          ),
       ]
   )
   def test_verify_valid_recipes(
@@ -69,8 +84,8 @@ class TestVerifyRecipes(unittest.TestCase):
       activation,
       weight,
       mode,
+      algo,
       granularity,
-      algo=Algorithm.MIN_MAX,
   ):
     quant_recipe.LayerQuantRecipe(activation, weight, mode, algo, granularity).verify()

@@ -78,7 +93,47 @@ class TestVerifyRecipes(unittest.TestCase):
 class TestQuantizeConvert(unittest.TestCase):
   """Test conversion with quantization."""

-  def test_quantize_convert_toy(self):
+  def _attention_1_int8_dynamic_recipe() -> quant_config.QuantConfig:
+    return quant_config.QuantConfig(
+        generative_recipe=quant_recipe.GenerativeQuantRecipe(
+            attention={1: quant_recipe_utils.create_layer_quant_int8_dynamic()},
+        )
+    )
+
+  def _feedforward_0_int8_dynamic_recipe() -> quant_config.QuantConfig:
+    return quant_config.QuantConfig(
+        generative_recipe=quant_recipe.GenerativeQuantRecipe(
+            feedforward={0: quant_recipe_utils.create_layer_quant_int8_dynamic()},
+        )
+    )
+
+  @parameterized.expand(
+      [
+          (quant_recipes.full_fp16_recipe(), 0.75),
+          (quant_recipes.full_linear_int8_dynamic_recipe(), 0.64),
+          (_attention_1_int8_dynamic_recipe(), 0.95),
+          (_feedforward_0_int8_dynamic_recipe(), 0.87),
+      ]
+  )
+  def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
+    self.skipTest("b/346896669")
+    config = toy_model_with_kv_cache.get_model_config()
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
+    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+        [10], dtype=torch.int64
+    )
+
+    quantized_model = ai_edge_torch.convert(
+        pytorch_model, (idx, input_pos), quant_config=quant_config
+    )
+    float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    self.assertAlmostEqual(
+        len(quantized_model._tflite_model) / len(float_model._tflite_model),
+        expected_compression,
+        delta=0.01,
+    )
+
+  def test_quantize_convert_compare_toy(self):
     self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)