ai-edge-torch-nightly 0.2.0.dev20240610__py3-none-any.whl → 0.2.0.dev20240611__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

@@ -24,6 +24,7 @@ from typing import Any, Dict, Optional, Tuple, Union
24
24
  import torch
25
25
  from torch_xla import stablehlo
26
26
 
27
+ from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
27
28
  from ai_edge_torch.quantize import quant_config as qcfg
28
29
 
29
30
  try:
@@ -249,11 +250,6 @@ def _set_tfl_converter_quant_flags(
249
250
  converter._experimental_qdq_conversion_mode = "DYNAMIC"
250
251
  elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
251
252
  converter._experimental_qdq_conversion_mode = "STATIC"
252
- elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.TFLITE_DYNAMIC:
253
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
254
- elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.TFLITE_FP16:
255
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
256
- converter.target_spec.supported_types = [tf.float16]
257
253
 
258
254
 
259
255
  def convert_stablehlo_to_tflite(
@@ -323,8 +319,24 @@ def convert_stablehlo_to_tflite(
323
319
  converter._experimental_enable_composite_direct_lowering = True
324
320
 
325
321
  _set_tfl_converter_quant_flags(converter, quant_config)
322
+ if (
323
+ quant_config is not None
324
+ and quant_config._quantizer_mode
325
+ == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
326
+ ):
327
+ translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
328
+ quant_config.generative_recipe
329
+ )
330
+
326
331
  _apply_tfl_backdoor_flags(converter, _tfl_converter_flags)
327
332
 
328
333
  tflite_model = converter.convert()
329
334
 
335
+ if (
336
+ quant_config is not None
337
+ and quant_config._quantizer_mode
338
+ == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
339
+ ):
340
+ tflite_model = translate_recipe.quantize_model(tflite_model, translated_recipe)
341
+
330
342
  return tflite_model
@@ -27,6 +27,7 @@ class ActivationType(enum.Enum):
27
27
  SILU = enum.auto()
28
28
  GELU = enum.auto()
29
29
  GELU_TANH = enum.auto()
30
+ GELU_QUICK = enum.auto()
30
31
  GE_GLU = enum.auto()
31
32
  RELU = enum.auto()
32
33
 
@@ -0,0 +1,164 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import json
17
+
18
+ from ai_edge_quantizer import quantizer
19
+
20
+ from ai_edge_torch.generative.quantize import quant_attrs
21
+ from ai_edge_torch.generative.quantize import quant_recipe
22
+
23
# Short aliases for the ai_edge_quantizer types used throughout this module.
_OpExecutionMode = quantizer.qtyping.OpExecutionMode
_OpName = quantizer.qtyping.TFLOperationName
_TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
_OpQuantConfig = quantizer.qtyping.OpQuantizationConfig

# Regex patterns handed to RecipeManager.add_quantization_config(regex=...) to
# select which part of the model a layer recipe applies to (presumably matched
# against op/tensor names of the converted model -- verify in ai_edge_quantizer).
# The `{}` placeholders are filled with a TransformerBlock index via str.format.
#
# Raw strings are required: '\[', '\]' and '\d' are invalid escape sequences in
# ordinary string literals (SyntaxWarning since Python 3.12). The resulting
# string values are byte-identical to the non-raw versions.
_DEFAULT_REGEX_STR = '.*'
_ATTENTION_IDX_REGEX_STR = (
    r'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.attention'
)
_FEEDFORWARD_IDX_REGEX_STR = (
    r'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.feed_forward'
)
_EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
# Matches any block index with one or two digits (i.e. 0-99).
_ANY_TWO_DIGITS_REGEX_STR = r'\d{1,2}'
37
+
38
+
39
def _get_nbits_from_dtype(dtype: quant_attrs.Dtype) -> int:
  """Returns the bit width of a recipe dtype.

  Args:
    dtype: A `quant_attrs.Dtype` taken from a `LayerQuantRecipe`.

  Returns:
    32 for FP32, 16 for FP16 and 8 for INT8.

  Raises:
    ValueError: If `dtype` is none of the above.
  """
  for candidate, bit_width in (
      (quant_attrs.Dtype.FP32, 32),
      (quant_attrs.Dtype.FP16, 16),
      (quant_attrs.Dtype.INT8, 8),
  ):
    if dtype == candidate:
      return bit_width
  raise ValueError('Unimplemented number of bits')
47
+
48
+
49
def _get_dtype_from_dtype(
    dtype: quant_attrs.Dtype,
) -> quantizer.qtyping.TensorDataType:
  """Maps a recipe dtype to the ai_edge_quantizer tensor data type.

  Args:
    dtype: A `quant_attrs.Dtype` taken from a `LayerQuantRecipe`.

  Returns:
    `TensorDataType.FLOAT` for FP32 and FP16, `TensorDataType.INT` for INT8.

  Raises:
    ValueError: If `dtype` is not FP32, FP16 or INT8. Unsupported dtypes are
      rejected explicitly (like the other translation helpers in this module)
      instead of silently being treated as integer types.
  """
  if dtype in (quant_attrs.Dtype.FP32, quant_attrs.Dtype.FP16):
    return quantizer.qtyping.TensorDataType.FLOAT
  if dtype == quant_attrs.Dtype.INT8:
    return quantizer.qtyping.TensorDataType.INT
  raise ValueError('Unimplemented dtype')
54
+
55
+
56
def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
  """Translates a recipe `Mode` into an ai_edge_quantizer execution mode.

  Args:
    mode: The `quant_attrs.Mode` of a `LayerQuantRecipe`.

  Returns:
    `_OpExecutionMode.WEIGHT_ONLY` for WEIGHT_ONLY and `_OpExecutionMode.DRQ`
    for DYNAMIC_RANGE.

  Raises:
    ValueError: For any other mode.
  """
  if mode == quant_attrs.Mode.WEIGHT_ONLY:
    return _OpExecutionMode.WEIGHT_ONLY
  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
    return _OpExecutionMode.DRQ
  raise ValueError('Unimplemented execution mode')
62
+
63
+
64
def _get_channelwise_from_granularity(
    granularity: quant_attrs.Granularity,
) -> bool:
  """Returns the `channel_wise` flag for a recipe granularity.

  Args:
    granularity: The `quant_attrs.Granularity` of a `LayerQuantRecipe`.

  Returns:
    True for CHANNELWISE, False for NONE.

  Raises:
    ValueError: For any other granularity.
  """
  if granularity == quant_attrs.Granularity.NONE:
    return False
  if granularity == quant_attrs.Granularity.CHANNELWISE:
    return True
  raise ValueError('Unimplemented granularity')
70
+
71
+
72
def _get_algorithm_key_from_algorithm(algo: quant_attrs.Algorithm) -> str:
  """Looks up the ai_edge_quantizer algorithm name for a recipe algorithm.

  Args:
    algo: The `quant_attrs.Algorithm` of a `LayerQuantRecipe`.

  Returns:
    `AlgorithmName.FLOAT_CASTING` for FLOAT_CAST and
    `AlgorithmName.MIN_MAX_UNIFORM_QUANT` for MIN_MAX.

  Raises:
    ValueError: For any other algorithm.
  """
  if algo == quant_attrs.Algorithm.FLOAT_CAST:
    return quantizer.algorithm_manager.AlgorithmName.FLOAT_CASTING
  if algo == quant_attrs.Algorithm.MIN_MAX:
    return quantizer.algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT
  raise ValueError('Unimplemented algorithm')
78
+
79
+
80
def _set_quant_config(
    rm: quantizer.recipe_manager.RecipeManager,
    layer_recipe: quant_recipe.LayerQuantRecipe,
    regex: str,
):
  """Registers `layer_recipe` with `rm` for every op type it can apply to.

  FULLY_CONNECTED and CONV_2D are always configured. BATCH_MATMUL and
  EMBEDDING_LOOKUP are added only when the recipe uses the MIN_MAX algorithm.
  Every op receives a symmetric weight-tensor config derived from the recipe,
  and is registered with `override_algorithm=True`.

  Args:
    rm: The RecipeManager accumulating the model quantization recipe.
    layer_recipe: The layer-level recipe to translate.
    regex: Pattern (passed as `regex=`) selecting where the config applies.

  Raises:
    ValueError: Propagated from the translation helpers when the recipe's
      dtype, granularity, mode or algorithm is not supported.
  """
  target_ops = [_OpName.FULLY_CONNECTED, _OpName.CONV_2D]
  if layer_recipe.algorithm == quant_attrs.Algorithm.MIN_MAX:
    target_ops.extend([_OpName.BATCH_MATMUL, _OpName.EMBEDDING_LOOKUP])

  for target_op in target_ops:
    # A fresh config object is built for every op rather than shared.
    weight_config = _TensorQuantConfig(
        num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
        symmetric=True,
        channel_wise=_get_channelwise_from_granularity(
            layer_recipe.granularity
        ),
        dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
    )
    op_config = _OpQuantConfig(
        weight_tensor_config=weight_config,
        execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
    )
    rm.add_quantization_config(
        regex=regex,
        operation_name=target_op,
        op_config=op_config,
        algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
        override_algorithm=True,
    )
106
+
107
+
108
def _add_block_quant_configs(rm, layer_recipes, regex_template):
  """Registers a per-TransformerBlock recipe (single or indexed) with `rm`.

  Args:
    rm: The RecipeManager accumulating the model quantization recipe.
    layer_recipes: Either one `LayerQuantRecipe` applied to every block, or a
      dict mapping a TransformerBlock index to the recipe for that block.
    regex_template: Pattern with a single `{}` placeholder for the block index.
  """
  if isinstance(layer_recipes, dict):
    for block_idx, layer_recipe in layer_recipes.items():
      _set_quant_config(rm, layer_recipe, regex_template.format(block_idx))
  else:
    _set_quant_config(
        rm,
        layer_recipes,
        regex_template.format(_ANY_TWO_DIGITS_REGEX_STR),
    )


def translate_to_ai_edge_recipe(
    recipe: quant_recipe.GenerativeQuantRecipe,
) -> quantizer.recipe_manager.ModelQuantizationRecipe:
  """Converts a `GenerativeQuantRecipe` into an ai_edge_quantizer recipe.

  Configs are registered in the order default -> embedding -> attention ->
  feedforward. Presumably later (more specific) entries take precedence inside
  RecipeManager; keep this order unless that is verified otherwise.

  Args:
    recipe: The Edge Generative API quantization recipe. Its `attention` and
      `feedforward` fields may each be a single layer recipe or a dict keyed
      by TransformerBlock index.

  Returns:
    The recipe produced by `RecipeManager.get_quantization_recipe()`.
  """
  rm = quantizer.recipe_manager.RecipeManager()

  if recipe.default is not None:
    _set_quant_config(rm, recipe.default, _DEFAULT_REGEX_STR)
  if recipe.embedding is not None:
    _set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)
  if recipe.attention is not None:
    _add_block_quant_configs(rm, recipe.attention, _ATTENTION_IDX_REGEX_STR)
  if recipe.feedforward is not None:
    _add_block_quant_configs(
        rm, recipe.feedforward, _FEEDFORWARD_IDX_REGEX_STR
    )

  return rm.get_quantization_recipe()
142
+
143
+
144
def quantize_model(
    model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
) -> bytearray:
  """Quantizes a serialized TFLite model with ai_edge_quantizer.

  Args:
    model: The serialized TFLite model (as returned by the TFLite converter).
    recipe: An ai_edge_quantizer recipe, e.g. from
      `translate_to_ai_edge_recipe`. It is serialized with `json.dumps`.

  Returns:
    The `quantized_model` attribute of the quantizer's result.
  """
  # TODO(b/336599483): Remove tempfile and use bytearray instead
  # The Quantizer is constructed from file paths, so the model and recipe are
  # staged on disk. A private temporary directory (instead of the fixed
  # /tmp/tmp.tflite and /tmp/recipe.json paths) prevents concurrent
  # conversions from overwriting each other's files, avoids predictable paths
  # in a shared /tmp, works where /tmp does not exist, and is removed even if
  # quantization raises.
  import os
  import tempfile

  with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_model_path = os.path.join(tmp_dir, 'tmp.tflite')
    tmp_recipe_path = os.path.join(tmp_dir, 'recipe.json')
    with open(tmp_model_path, 'wb') as fp:
      fp.write(model)
    with open(tmp_recipe_path, 'w', encoding='utf-8') as rp:
      rp.write(json.dumps(recipe))

    qt = quantizer.Quantizer(tmp_model_path, tmp_recipe_path)
    result = qt.quantize()
    # Read the result before the staging directory is deleted.
    return result.quantized_model
@@ -32,9 +32,11 @@ class Algorithm(enum.Enum):
32
32
  Attributes:
33
33
  MIN_MAX: Maps the min/max of floating point space to the min/max of
34
34
  quantized space and quantize uniformly.
35
+ FLOAT_CAST: Casts a float to another float of a different type.
35
36
  """
36
37
 
37
38
  MIN_MAX = enum.auto()
39
+ FLOAT_CAST = enum.auto()
38
40
 
39
41
 
40
42
  @enum.unique
@@ -14,8 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from dataclasses import dataclass
17
- import enum
18
- from typing import Optional
17
+ from typing import Optional, Union
19
18
 
20
19
  from ai_edge_torch.generative.quantize import quant_attrs
21
20
  from ai_edge_torch.generative.quantize import supported_schemes
@@ -80,18 +79,50 @@ class LayerQuantRecipe:
80
79
 
81
80
 
82
81
  @dataclass
83
- class TransformerQuantRecipe:
82
+ class GenerativeQuantRecipe:
84
83
  """Quantization recipe for a model composed of the Edge Generative API layers.
85
84
 
85
+ Some layers can be specified with different `LayerQuantRecipe` for each block by
86
+ providing a dictionary keyed by the TransformerBlock index, e.g. attention
87
+ and feedforward. For example,
88
+
89
+ ```
90
+ default = LayerQuantRecipeA
91
+ attention = { 2: LayerQuantRecipeB }
92
+ feedforward = { 3: LayerQuantRecipeC }
93
+ ```
94
+
95
+ will apply LayerQuantRecipeA to the entire model, overriden by
96
+ LayerQuantRecipeB for the TransformerBlock[2].attention layer and
97
+ LayerQuantRecipeC for the TransformerBlock[3].feedforward layer. Any config
98
+ with invalid indices will be ignored.
99
+
86
100
  Attributes:
87
101
  default: The quantization recipe for global scope of the model.
102
+ embedding: Recipe for the embedding table.
103
+ attention: Recipe for the attention blocks. This could be specified with
104
+ different LayerQuantRecipe for each block by providing a dictionary
105
+ keyed by the TransformerBlock index.
106
+ feedforward: Recipe for the feedforward layers. This could be specified with
107
+ different LayerQuantRecipe for each block by providing a dictionary
108
+ keyed by the TransformerBlock index.
88
109
  """
89
110
 
90
111
  default: Optional[LayerQuantRecipe] = None
112
+ embedding: Optional[LayerQuantRecipe] = None
113
+ attention: Union[
114
+ Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
115
+ ] = None
116
+ feedforward: Union[
117
+ Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
118
+ ] = None
91
119
 
92
120
  def __str__(self):
93
- return f"""TransformerQuantRecipe(
121
+ return f"""GenerativeQuantRecipe(
94
122
  Default: {self.default}
123
+ Embedding: {self.embedding}
124
+ Attention: {self.attention}
125
+ Feedforward: {self.feedforward}
95
126
  )"""
96
127
 
97
128
  __repr__ = __str__
@@ -104,3 +135,17 @@ class TransformerQuantRecipe:
104
135
  """
105
136
  if self.default is not None:
106
137
  self.default.verify()
138
+ if self.embedding is not None:
139
+ self.embedding.verify()
140
+ if self.attention is not None:
141
+ if isinstance(self.attention, dict):
142
+ for recipe in self.attention.values():
143
+ recipe.verify()
144
+ else:
145
+ self.attention.verify()
146
+ if self.feedforward is not None:
147
+ if isinstance(self.feedforward, dict):
148
+ for recipe in self.feedforward.values():
149
+ recipe.verify()
150
+ else:
151
+ self.feedforward.verify()
@@ -22,7 +22,7 @@ Typical usage example:
22
22
 
23
23
  1. Applying a single layer recipe to the entire model
24
24
 
25
- quant_recipe.TransformerQuantRecipe(
25
+ quant_recipe.GenerativeQuantRecipe(
26
26
  default=quant_recipe_utils.create_layer_quant_int8_dynamic()
27
27
  )
28
28
  """
@@ -46,6 +46,6 @@ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
46
46
  activation_dtype=quant_attrs.Dtype.FP32,
47
47
  weight_dtype=quant_attrs.Dtype.FP16,
48
48
  mode=quant_attrs.Mode.WEIGHT_ONLY,
49
- algorithm=quant_attrs.Algorithm.MIN_MAX,
49
+ algorithm=quant_attrs.Algorithm.FLOAT_CAST,
50
50
  granularity=quant_attrs.Granularity.NONE,
51
51
  )
@@ -34,15 +34,15 @@ from ai_edge_torch.quantize import quant_config
34
34
 
35
35
  def full_linear_int8_dynamic_recipe() -> quant_config.QuantConfig:
36
36
  return quant_config.QuantConfig(
37
- transformer_recipe=quant_recipe.TransformerQuantRecipe(
38
- default=quant_recipe_utils.create_layer_quant_int8_dynamic()
37
+ generative_recipe=quant_recipe.GenerativeQuantRecipe(
38
+ default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
39
39
  )
40
40
  )
41
41
 
42
42
 
43
43
  def full_fp16_recipe() -> quant_config.QuantConfig:
44
44
  return quant_config.QuantConfig(
45
- transformer_recipe=quant_recipe.TransformerQuantRecipe(
45
+ generative_recipe=quant_recipe.GenerativeQuantRecipe(
46
46
  default=quant_recipe_utils.create_layer_quant_fp16()
47
47
  )
48
48
  )
@@ -27,5 +27,6 @@ def get_supported_layer_schemes():
27
27
 
28
28
  return [
29
29
  (_t.FP32, _t.INT8, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.CHANNELWISE),
30
- (_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.NONE),
30
+ (_t.FP32, _t.INT8, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.CHANNELWISE),
31
+ (_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.FLOAT_CAST, _g.NONE),
31
32
  ]
@@ -21,11 +21,13 @@ import torch
21
21
  import ai_edge_torch
22
22
  from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
23
23
  from ai_edge_torch.generative.quantize import quant_recipe
24
+ from ai_edge_torch.generative.quantize import quant_recipe_utils
24
25
  from ai_edge_torch.generative.quantize import quant_recipes
25
26
  from ai_edge_torch.generative.quantize.quant_attrs import Algorithm
26
27
  from ai_edge_torch.generative.quantize.quant_attrs import Dtype
27
28
  from ai_edge_torch.generative.quantize.quant_attrs import Granularity
28
29
  from ai_edge_torch.generative.quantize.quant_attrs import Mode
30
+ from ai_edge_torch.quantize import quant_config
29
31
  from ai_edge_torch.testing import model_coverage
30
32
 
31
33
 
@@ -34,34 +36,47 @@ class TestVerifyRecipes(unittest.TestCase):
34
36
 
35
37
  @parameterized.expand(
36
38
  [
37
- (Dtype.FP32, Dtype.FP32, Mode.DYNAMIC_RANGE),
38
- (Dtype.INT8, Dtype.INT8, Mode.DYNAMIC_RANGE),
39
- (Dtype.INT8, Dtype.FP16, Mode.DYNAMIC_RANGE),
40
- (Dtype.FP16, Dtype.INT8, Mode.DYNAMIC_RANGE),
41
- (Dtype.FP32, Dtype.FP32, Mode.WEIGHT_ONLY),
42
- (Dtype.INT8, Dtype.INT8, Mode.WEIGHT_ONLY),
43
- (Dtype.FP16, Dtype.INT8, Mode.WEIGHT_ONLY),
44
- (Dtype.INT8, Dtype.FP16, Mode.WEIGHT_ONLY),
45
- (Dtype.FP16, Dtype.FP16, Mode.WEIGHT_ONLY),
39
+ (Dtype.FP32, Dtype.FP32),
40
+ (Dtype.INT8, Dtype.INT8),
41
+ (Dtype.INT8, Dtype.FP16),
42
+ (Dtype.FP16, Dtype.INT8),
43
+ (Dtype.FP16, Dtype.FP16),
46
44
  ]
47
45
  )
48
46
  def test_verify_invalid_recipes(
49
47
  self,
50
48
  activation,
51
49
  weight,
52
- mode,
53
- algo=Algorithm.MIN_MAX,
54
- granularity=Granularity.CHANNELWISE,
55
50
  ):
56
- with self.assertRaises(ValueError):
57
- quant_recipe.LayerQuantRecipe(
58
- activation, weight, mode, algo, granularity
59
- ).verify()
51
+ for m in Mode:
52
+ for a in Algorithm:
53
+ for g in Granularity:
54
+ with self.assertRaises(ValueError):
55
+ quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
60
56
 
61
57
  @parameterized.expand(
62
58
  [
63
- (Dtype.FP32, Dtype.INT8, Mode.DYNAMIC_RANGE, Granularity.CHANNELWISE),
64
- (Dtype.FP32, Dtype.FP16, Mode.WEIGHT_ONLY, Granularity.NONE),
59
+ (
60
+ Dtype.FP32,
61
+ Dtype.INT8,
62
+ Mode.DYNAMIC_RANGE,
63
+ Algorithm.MIN_MAX,
64
+ Granularity.CHANNELWISE,
65
+ ),
66
+ (
67
+ Dtype.FP32,
68
+ Dtype.INT8,
69
+ Mode.WEIGHT_ONLY,
70
+ Algorithm.MIN_MAX,
71
+ Granularity.CHANNELWISE,
72
+ ),
73
+ (
74
+ Dtype.FP32,
75
+ Dtype.FP16,
76
+ Mode.WEIGHT_ONLY,
77
+ Algorithm.FLOAT_CAST,
78
+ Granularity.NONE,
79
+ ),
65
80
  ]
66
81
  )
67
82
  def test_verify_valid_recipes(
@@ -69,8 +84,8 @@ class TestVerifyRecipes(unittest.TestCase):
69
84
  activation,
70
85
  weight,
71
86
  mode,
87
+ algo,
72
88
  granularity,
73
- algo=Algorithm.MIN_MAX,
74
89
  ):
75
90
  quant_recipe.LayerQuantRecipe(activation, weight, mode, algo, granularity).verify()
76
91
 
@@ -78,7 +93,46 @@ class TestVerifyRecipes(unittest.TestCase):
78
93
  class TestQuantizeConvert(unittest.TestCase):
79
94
  """Test conversion with quantization."""
80
95
 
81
- def test_quantize_convert_toy(self):
96
+ def _attention_1_int8_dynamic_recipe() -> quant_config.QuantConfig:
97
+ return quant_config.QuantConfig(
98
+ generative_recipe=quant_recipe.GenerativeQuantRecipe(
99
+ attention={1: quant_recipe_utils.create_layer_quant_int8_dynamic()},
100
+ )
101
+ )
102
+
103
+ def _feedforward_0_int8_dynamic_recipe() -> quant_config.QuantConfig:
104
+ return quant_config.QuantConfig(
105
+ generative_recipe=quant_recipe.GenerativeQuantRecipe(
106
+ feedforward={0: quant_recipe_utils.create_layer_quant_int8_dynamic()},
107
+ )
108
+ )
109
+
110
+ @parameterized.expand(
111
+ [
112
+ (quant_recipes.full_fp16_recipe(), 0.75),
113
+ (quant_recipes.full_linear_int8_dynamic_recipe(), 0.64),
114
+ (_attention_1_int8_dynamic_recipe(), 0.95),
115
+ (_feedforward_0_int8_dynamic_recipe(), 0.87),
116
+ ]
117
+ )
118
+ def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
119
+ config = toy_model_with_kv_cache.get_model_config()
120
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
121
+ idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
122
+ [10], dtype=torch.int64
123
+ )
124
+
125
+ quantized_model = ai_edge_torch.convert(
126
+ pytorch_model, (idx, input_pos), quant_config=quant_config
127
+ )
128
+ float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
129
+ self.assertAlmostEqual(
130
+ len(quantized_model._tflite_model) / len(float_model._tflite_model),
131
+ expected_compression,
132
+ delta=0.01,
133
+ )
134
+
135
+ def test_quantize_convert_compare_toy(self):
82
136
  self.skipTest("b/338288901")
83
137
  config = toy_model_with_kv_cache.get_model_config()
84
138
  pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
@@ -32,27 +32,26 @@ class QuantConfig:
32
32
  pt2e_quantizer: The instance of PT2EQuantizer used to quantize the model
33
33
  with PT2E quantization. This method of quantization is not applicable to
34
34
  models created with the Edge Generative API.
35
- transformer_recipe: Quantization recipe to be applied on a model created
35
+ generative_recipe: Quantization recipe to be applied on a model created
36
36
  with the Edge Generative API.
37
37
  """
38
38
 
39
39
  pt2e_quantizer: pt2eq.PT2EQuantizer = None
40
- transformer_recipe: quant_recipe.TransformerQuantRecipe = None
40
+ generative_recipe: quant_recipe.GenerativeQuantRecipe = None
41
41
 
42
42
  @enum.unique
43
43
  class _QuantizerMode(enum.Enum):
44
44
  NONE = enum.auto()
45
45
  PT2E_DYNAMIC = enum.auto()
46
46
  PT2E_STATIC = enum.auto()
47
- TFLITE_DYNAMIC = enum.auto()
48
- TFLITE_FP16 = enum.auto()
47
+ AI_EDGE_QUANTIZER = enum.auto()
49
48
 
50
49
  _quantizer_mode: _QuantizerMode = _QuantizerMode.NONE
51
50
 
52
51
  def __init__(
53
52
  self,
54
53
  pt2e_quantizer: Optional[pt2eq.PT2EQuantizer] = None,
55
- transformer_recipe: Optional[quant_recipe.TransformerQuantRecipe] = None,
54
+ generative_recipe: Optional[quant_recipe.GenerativeQuantRecipe] = None,
56
55
  ):
57
56
  """Initializes some internal states based on selected quantization method.
58
57
 
@@ -61,8 +60,8 @@ class QuantConfig:
61
60
  is properly setup. Additionally sets up an utility enum _quantizer_mode to
62
61
  guide certain conversion processes.
63
62
  """
64
- if pt2e_quantizer is not None and transformer_recipe is not None:
65
- raise ValueError('Cannot set both pt2e_quantizer and transformer_recipe.')
63
+ if pt2e_quantizer is not None and generative_recipe is not None:
64
+ raise ValueError('Cannot set both pt2e_quantizer and generative_recipe.')
66
65
  elif pt2e_quantizer is not None:
67
66
  object.__setattr__(self, 'pt2e_quantizer', pt2e_quantizer)
68
67
  object.__setattr__(
@@ -74,12 +73,9 @@ class QuantConfig:
74
73
  else self._QuantizerMode.PT2E_STATIC
75
74
  ),
76
75
  )
77
- elif transformer_recipe is not None:
78
- transformer_recipe.verify()
79
- object.__setattr__(self, 'transformer_recipe', transformer_recipe)
80
- if self.transformer_recipe.default.mode == quant_attrs.Mode.DYNAMIC_RANGE:
81
- object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.TFLITE_DYNAMIC)
82
- elif self.transformer_recipe.default.weight_dtype == quant_attrs.Dtype.FP16:
83
- object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.TFLITE_FP16)
76
+ elif generative_recipe is not None:
77
+ generative_recipe.verify()
78
+ object.__setattr__(self, 'generative_recipe', generative_recipe)
79
+ object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.AI_EDGE_QUANTIZER)
84
80
  else:
85
- raise ValueError('Either pt2e_quantizer or transformer_recipe must be set.')
81
+ raise ValueError('Either pt2e_quantizer or generative_recipe must be set.')
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240610
3
+ Version: 0.2.0.dev20240611
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=FPMmuFU3pyMREtjB_san1fy_0PFtAsgA0VZfOYvDrb4,100
2
2
  ai_edge_torch/model.py,sha256=kmcgELjsYl8YzF8nUF6P7q4i8MWS-pLGpfsy-yTUXmE,4243
3
3
  ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
4
4
  ai_edge_torch/convert/conversion.py,sha256=GN2Js232u_5Y118wg3qIfEoYewxbxLl3TpSnO6osi8c,4029
5
- ai_edge_torch/convert/conversion_utils.py,sha256=NpVm3Ms81_cIW5IYgGsr0BVganJJgBKWVBDe5h_ZaGE,11021
5
+ ai_edge_torch/convert/conversion_utils.py,sha256=9BqCL38DErv1vEVGtT3BIJVhdwZjw2EQ-_m5UpvVVYE,11294
6
6
  ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
7
7
  ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
8
8
  ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
@@ -70,7 +70,7 @@ ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNl
70
70
  ai_edge_torch/generative/layers/builder.py,sha256=jAyrR5hsSI0aimKZumyvxdJ1GovERIfsK0g-dezX2gs,4163
71
71
  ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
72
72
  ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
73
- ai_edge_torch/generative/layers/model_config.py,sha256=g_XJXcQOCkE-mt58fSH4-T4GY_uLeMilg6mxwDMCfz4,4557
73
+ ai_edge_torch/generative/layers/model_config.py,sha256=toWECENDWgay9hsZcy4C89qph0KI3CpaeFqFc8Fr-Xk,4584
74
74
  ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
75
75
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
76
76
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
@@ -80,15 +80,17 @@ ai_edge_torch/generative/layers/unet/builder.py,sha256=iH0_nuY9TF2ap5h1JbGNCOonP
80
80
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=sbtbDEHmMV9GLKngwjsNvqm8wovLxnlidkQbXdXkXKs,4060
81
81
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
82
82
  ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
83
- ai_edge_torch/generative/quantize/quant_attrs.py,sha256=ffBALrrbrfiG_mrOr-f3B1Gc6PlAma9gtvVnfP7SDzI,1862
84
- ai_edge_torch/generative/quantize/quant_recipe.py,sha256=BOk4E0FW-_YD8Y-oPVmIDsgXx_bPtvzsP_V1av5DvgU,3327
85
- ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=9ktL7fT8C5j1dnY_7fkiFL4oWNLVs1dMWXkS_EuyA3Y,1913
86
- ai_edge_torch/generative/quantize/quant_recipes.py,sha256=CRA2ENevS-3usHqidWDe2wrf_epILE_7Hx-XfZQ9buk,1798
87
- ai_edge_torch/generative/quantize/supported_schemes.py,sha256=OQ4ghQXknA1PPjuY-xBgAmOpaIBgYFM8F2YAIot06hE,1345
83
+ ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
84
+ ai_edge_torch/generative/quantize/quant_recipe.py,sha256=Y8zahKw7b_h7ajPaJZVef4jG-MoqImRCpVSbFtV_i24,5139
85
+ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=-vd6Qp0BdXJVKg4f0_hhwbKOi3QPIAPVqyXnJ-ZnISQ,1915
86
+ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=9ItD70jQRXMEhWod-nUfEeoWGJUUu6V9YOffF07VU9g,1795
87
+ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
88
+ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
89
+ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=qUB4f2DoB14dLkNPWf6TZodpT81mfAJeWM-lCAmkuHY,5735
88
90
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
89
91
  ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
90
92
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=i_SAW-hD8SaHuopMZI9IuXXDFn5uSTJa1nKZhaC3dAQ,6811
91
- ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
93
+ ai_edge_torch/generative/test/test_quantize.py,sha256=NVlMixAxVpDUabEvp6zTHHgIDgHFsMRwlf5MuyDwrPg,5355
92
94
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
93
95
  ai_edge_torch/generative/utilities/autoencoder_loader.py,sha256=G2Nosy33JzkjGALPR4JjvffdFX1JWOj2zjbbuaDJEgg,10065
94
96
  ai_edge_torch/generative/utilities/loader.py,sha256=Hs92478j1g4jQGvbdP1aWvOy907HjwqQZE-NFy6HELo,11326
@@ -103,12 +105,12 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=aUAPKnH4_Jxpp
103
105
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
104
106
  ai_edge_torch/quantize/pt2e_quantizer.py,sha256=ye1f5vAZ0Vr4RWAtfrgU1o3JLs03Sa4inHRq3YxJDGo,15602
105
107
  ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=yjzKoptnfEeW_sN7sODUfj3nCtUMXVzq3vHKxblsd5Y,36046
106
- ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCziCfhsoMPA,3435
108
+ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDdN5XtvHwjc,3148
107
109
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
108
110
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
109
111
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
110
- ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
111
- ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/METADATA,sha256=6hL5PV3S56VU2l6xqS-YrmzMZeajtXsikIdR7kDYcWE,1748
112
- ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
113
- ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
114
- ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/RECORD,,
112
+ ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
113
+ ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/METADATA,sha256=WPGu2pq6N57fBtpunyFhunPe73UK_SVbqlZQsZwjWGo,1748
114
+ ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
115
+ ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
116
+ ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/RECORD,,