ai-edge-quantizer-nightly 0.4.0.dev20251001__py3-none-any.whl → 0.4.0.dev20251003__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +2 -0
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +28 -0
- ai_edge_quantizer/default_policy.py +4 -2
- ai_edge_quantizer/model_modifier.py +95 -5
- ai_edge_quantizer/model_modifier_test.py +80 -0
- ai_edge_quantizer/qtyping.py +1 -0
- ai_edge_quantizer/quantizer_test.py +118 -10
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +1 -1
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +1 -0
- {ai_edge_quantizer_nightly-0.4.0.dev20251001.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.4.0.dev20251001.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info}/RECORD +14 -14
- {ai_edge_quantizer_nightly-0.4.0.dev20251001.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.4.0.dev20251001.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.4.0.dev20251001.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithm_manager.py CHANGED

@@ -131,6 +131,7 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
     _TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
     _TFLOpName.EQUAL: common_quantize.materialize_equal,
     _TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
+    _TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
 }
 for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
   register_quantized_op(
@@ -284,6 +285,7 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
     _TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
     _TFLOpName.EQUAL: common_quantize.materialize_equal,
     _TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
+    _TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
 })

 for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py CHANGED

@@ -748,6 +748,34 @@ def materialize_padv2(
   )


+def materialize_mirror_pad(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.mirror_pad.
+
+  Args:
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    A list of `qtyping.TensorTransformationParams` for the tensors in the op.
+  """
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
 def materialize_squared_difference(
     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     op_info: qtyping.OpInfo,
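MIRROR_PAD only copies existing input values into the padded region, so its output can safely reuse the input's quantization parameters; that is what the SAME_AS_INPUT_SCALE constraint encodes, and why only the paddings tensor (input 1) is skipped. A minimal numeric sketch of the idea, using made-up parameters and plain NumPy rather than the quantizer's API:

import numpy as np

scale, zero_point = 0.05, 3  # made-up example quantization parameters
q_input = np.array([[10, -4], [7, 25]], dtype=np.int8)

# Mirror padding only duplicates values that already exist in the input.
q_output = np.pad(q_input, pad_width=1, mode="reflect")


def dequant(q):
  return (q.astype(np.float32) - zero_point) * scale


# Dequantizing input and output with the same (scale, zero_point) is therefore
# consistent: the padded tensor contains no values outside the input's range.
assert set(np.unique(dequant(q_output))) <= set(np.unique(dequant(q_input)))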
ai_edge_quantizer/default_policy.py CHANGED

@@ -199,7 +199,8 @@ DEFAULT_JSON_POLICY = """
       "PADV2",
       "REDUCE_MIN",
       "EQUAL",
-      "NOT_EQUAL"
+      "NOT_EQUAL",
+      "MIRROR_PAD"
     ],
     "static_wi8_ai8": [
       "ADD",
@@ -248,7 +249,8 @@ DEFAULT_JSON_POLICY = """
       "PADV2",
       "REDUCE_MIN",
       "EQUAL",
-      "NOT_EQUAL"
+      "NOT_EQUAL",
+      "MIRROR_PAD"
     ],
     "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
     "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
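Because DEFAULT_JSON_POLICY is a plain JSON string, the addition can be sanity-checked by parsing it. A sketch only; the exact nesting of the op lists is not visible in this excerpt, so the snippet checks parseability and the raw text rather than a specific key path:

import json

from ai_edge_quantizer import default_policy

# The policy must still be valid JSON after the trailing-comma edits above...
json.loads(default_policy.DEFAULT_JSON_POLICY)
# ...and it should now mention the newly supported op.
assert "MIRROR_PAD" in default_policy.DEFAULT_JSON_POLICY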
ai_edge_quantizer/model_modifier.py CHANGED

@@ -17,6 +17,7 @@

 from collections.abc import Sequence
 import copy
+import logging

 import numpy as np

@@ -24,10 +25,15 @@ from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import transformation_instruction_generator
 from ai_edge_quantizer import transformation_performer
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
+from ai_edge_litert import interpreter as tfl  # pylint: disable=g-direct-tensorflow-import
 from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
 from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import


+_DEQUANT_SUFFIX = "_dequant"
+
+
 class ModelModifier:
   """Model Modifier class that produce the final quantized TFlite model."""

@@ -105,10 +111,94 @@ class ModelModifier:
     )
     constant_buffer_size = self._process_constant_map(quantized_model)
     # we leave 256MB for the model architecture.
-
-
-
-
+    serialize_fun = (
+        self._serialize_large_model
+        if constant_buffer_size > 2**31 - 2**28
+        else self._serialize_small_model
+    )
+    serialized_quantized_model = serialize_fun(quantized_model)
+
+    # Update signature defs if dequant is inserted before output.
+    if self._has_dequant_before_output(instructions):
+      quantized_model = self._update_signature_defs_for_dequant_output(
+          quantized_model, serialized_quantized_model
+      )
+      serialized_quantized_model = serialize_fun(quantized_model)
+
+    return serialized_quantized_model
+
+  def _update_signature_defs_for_dequant_output(
+      self, model: schema_py_generated.ModelT, serialized_model: bytearray
+  ):
+    """Updates the signature definitions in the model.
+
+    This function is called when a dequantize operation is inserted before
+    an output tensor. It updates the tensor index in the signature
+    definitions to point to the newly inserted dequantize output tensor.
+
+    Args:
+      model: The TFlite ModelT object.
+      serialized_model: The serialized bytearray of the TFlite model.
+
+    Returns:
+      The updated TFlite ModelT object.
+    """
+    interpreter = tfl.Interpreter(model_content=bytes(serialized_model))
+
+    for signature_def in model.signatureDefs:
+      signature_key = signature_def.signatureKey.decode("utf-8")
+      logging.info("Signature = %s", signature_key)
+      subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index(
+          interpreter, signature_key
+      )
+      output_details = interpreter.get_signature_runner(
+          signature_key
+      ).get_output_details()
+      subgraph = model.subgraphs[subgraph_idx]
+      graph_info = qtyping.GraphInfo(subgraph.tensors, model.buffers)
+
+      for output in subgraph.outputs:
+        tensor_name = tfl_flatbuffer_utils.get_tensor_name(
+            graph_info.subgraph_tensors[output]
+        )
+        logging.info("\tOutput tensor = `%s`", tensor_name)
+
+        for signature_name, tensor_details in output_details.items():
+          if tensor_details["name"] + _DEQUANT_SUFFIX == tensor_name:
+            logging.info(
+                "\t\tfound tensor mapping: `%s`->`%s` for signature name: `%s`",
+                tensor_details["name"],
+                tensor_name,
+                signature_name,
+            )
+            for signature_item in signature_def.outputs:
+              if signature_item.name.decode("utf-8") == signature_name:
+                signature_item.tensorIndex = output
+                logging.info(
+                    "\t\t\tswapped tensor index: %s->%s",
+                    tensor_details["index"],
+                    output,
+                )
+                break
+            break

+    return model
+
+  def _has_dequant_before_output(
+      self, instructions: dict[str, qtyping.TensorTransformationInsts]
+  ) -> bool:
+    """Check if the model has dequant insert to output."""
+    for tensor_name, tensor_trans_insts in instructions.items():
+      for instr in tensor_trans_insts.instructions:
+        if (
+            qtyping.QuantTransformation.ADD_DEQUANTIZE == instr.transformation
+            and instr.consumers == [-1]
+        ):
+          logging.info(
+              "Found dequant insert to output for tensor: %s", tensor_name
+          )
+          return True
+    return False

   def _process_constant_map(
       self, quantized_model: schema_py_generated.ModelT
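The large-model serializer kicks in once the constant buffers alone approach the 2 GiB flatbuffer limit; the comment above reserves 256 MiB of that budget for everything that is not constant data. Spelling out the constant used in the branch (plain arithmetic, not code from the package):

# 2**31 bytes is the 2 GiB flatbuffer ceiling; 2**28 bytes (256 MiB) is held
# back for the non-constant parts of the model.
threshold = 2**31 - 2**28
print(threshold)          # 1879048192 bytes
print(threshold / 2**20)  # 1792.0 MiB, i.e. 1.75 GiB of constant data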
@@ -142,7 +232,7 @@ class ModelModifier:
     remainder = len(bytearr) % 16
     if remainder != 0:
       padding_size = 16 - remainder
-      bytearr.extend(b
+      bytearr.extend(b"\0" * padding_size)

   # TODO: b/333797307 - support > 2GB output model
   def _serialize_large_model(
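The signature fix-up added in `_update_signature_defs_for_dequant_output` above hinges on a naming convention: the float tensor produced by a dequantize op inserted in front of a graph output is named after the original signature output plus `_dequant`. A stripped-down sketch of that matching rule, using made-up tensor names and indices instead of the real flatbuffer objects:

_DEQUANT_SUFFIX = "_dequant"

# Hypothetical state after quantization: the signature still points at the
# quantized tensor (index 7), while the subgraph's real output is the inserted
# dequantize result (index 12, named "<original>_dequant").
signature_outputs = {"logits": 7}           # signature output name -> tensorIndex
subgraph_outputs = {12: "logits_dequant"}   # subgraph output index -> tensor name

for tensor_index, tensor_name in subgraph_outputs.items():
  for sig_name in signature_outputs:
    if sig_name + _DEQUANT_SUFFIX == tensor_name:
      # Re-point the signature entry at the float tensor from the dequant op.
      signature_outputs[sig_name] = tensor_index

assert signature_outputs == {"logits": 12}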
ai_edge_quantizer/model_modifier_test.py CHANGED

@@ -125,6 +125,86 @@ class ModelModifierTest(parameterized.TestCase):
     loosen_mem_use_factor = 4.5
     self.assertLess(mem_peak / len(self._model_content), loosen_mem_use_factor)

+  def test_has_dequant_before_output_true(self):
+    instructions = {
+        'tensor1': qtyping.TensorTransformationInsts(
+            'tensor1',
+            0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                    tensor_id=0,
+                    producer=0,
+                    consumers=[-1],
+                )
+            ],
+        )
+    }
+    self.assertTrue(
+        self._model_modifier._has_dequant_before_output(instructions)
+    )
+
+  def test_has_dequant_before_output_false(self):
+    instructions = {
+        'tensor1': qtyping.TensorTransformationInsts(
+            'tensor1',
+            0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                    tensor_id=0,
+                    producer=0,
+                    consumers=[1],
+                )
+            ],
+        )
+    }
+    self.assertFalse(
+        self._model_modifier._has_dequant_before_output(instructions)
+    )
+
+  def test_pad_bytearray(self):
+    arr = bytearray(b'\x01\x02\x03')
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 16)
+    self.assertEqual(arr, b'\x01\x02\x03' + b'\0' * 13)
+
+    arr = bytearray(b'\x01' * 16)
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 16)
+
+    arr = bytearray(b'\x01' * 17)
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 32)
+
+
+class ModelModifierTestWithSignature(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'tests/models/single_fc.tflite',
+    )
+    self._model_content: bytes = tfl_flatbuffer_utils.get_model_content(
+        self._model_path
+    )
+    self._model_modifier = model_modifier.ModelModifier(self._model_content)
+
+  def test_update_signature_defs_for_dequant_output_succeeds(self):
+    # This is a simplified test that only checks if the function runs without
+    # crashing and returns a model. A more thorough test with a model
+    # with a known signature was added in `quantizer_test`.
+    model_bytearray = flatbuffer_utils.read_model_from_bytearray(
+        self._model_content
+    )
+    updated_model = (
+        self._model_modifier._update_signature_defs_for_dequant_output(
+            model_bytearray, bytearray(self._model_content)
+        )
+    )
+    self.assertIsNotNone(updated_model)
+

 if __name__ == '__main__':
   googletest.main()
ai_edge_quantizer/qtyping.py CHANGED

ai_edge_quantizer/quantizer_test.py CHANGED

@@ -51,6 +51,30 @@ def _get_calibration_data(num_samples: int = 16):
   return calibration_data


+def _is_all_signature_defs_inputs_float(model_content: bytes):
+  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_content)
+  for signature_key in tfl_interpreter.get_signature_list():
+    input_details = tfl_interpreter.get_signature_runner(
+        signature_key
+    ).get_input_details()
+    for tensor_details in input_details.values():
+      if tensor_details['dtype'] != np.float32:
+        return False
+  return True
+
+
+def _is_all_signature_defs_outputs_float(model_content: bytes):
+  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_content)
+  for signature_key in tfl_interpreter.get_signature_list():
+    output_details = tfl_interpreter.get_signature_runner(
+        signature_key
+    ).get_output_details()
+    for tensor_details in output_details.values():
+      if tensor_details['dtype'] != np.float32:
+        return False
+  return True
+
+
 class QuantizerTest(parameterized.TestCase):

   def setUp(self):
@@ -547,21 +571,21 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
         'signature_1': [{
             'cache_0': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
             'cache_1': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
-            'positions':
-                np.int32
+            'positions': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
-            'tokens':
-                np.int32
+            'tokens': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
         }],
         'signature_2': [{
             'cache_0': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
             'cache_1': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
-            'positions':
-                np.int32
+            'positions': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
            ),
-            'tokens':
-                np.int32
+            'tokens': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
         }],
     }
@@ -578,8 +602,8 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
     )

     self._quantizer.update_quantization_recipe(
-        regex='
-        operation_name=qtyping.TFLOperationName.
+        regex='.*',
+        operation_name=qtyping.TFLOperationName.OUTPUT,
         algorithm_key=_AlgorithmName.NO_QUANTIZE,
     )

@@ -591,6 +615,90 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
     self._quantizer.quantize(calib_result)
     self.assertIsNotNone(self._quantizer._result.quantized_model)

+  def test_toy_gemma2_update_signature_defs_succeeds(self):
+
+    self.assertTrue(
+        _is_all_signature_defs_outputs_float(
+            open(self._test_model_path, 'rb').read()
+        )
+    )
+    calib_result = self._quantizer.calibrate(
+        self._toy_gemma2_calibration_dataset
+    )
+    self.assertIsNotNone(calib_result)
+    self._quantizer.quantize(calib_result)
+    self.assertIsNotNone(self._quantizer._result.quantized_model)
+    self.assertTrue(
+        _is_all_signature_defs_outputs_float(
+            self._quantizer._result.quantized_model
+        )
+    )
+
+
+class QuantizerFullyConnectedTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._tmp_save_path = self.create_tempdir().full_path
+    self._test_model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'tests/models/single_fc.tflite',
+    )
+
+    self._test_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'recipes/default_a8w8_recipe.json',
+    )
+    with open(self._test_recipe_path) as json_file:
+      self._test_recipe = json.load(json_file)
+
+    self._quantizer = quantizer.Quantizer(
+        self._test_model_path, self._test_recipe_path
+    )
+
+    self._quantizer.update_quantization_recipe(
+        regex='.*',
+        operation_name=qtyping.TFLOperationName.INPUT,
+        algorithm_key=_AlgorithmName.NO_QUANTIZE,
+    )
+    self._quantizer.update_quantization_recipe(
+        regex='.*',
+        operation_name=qtyping.TFLOperationName.OUTPUT,
+        algorithm_key=_AlgorithmName.NO_QUANTIZE,
+    )
+
+  def test_fully_connected_quantization_succeeds(self):
+    calib_result = self._quantizer.calibrate(
+        tfl_interpreter_utils.create_random_normal_input_data(
+            self._test_model_path, num_samples=4
+        )
+    )
+    self.assertIsNotNone(calib_result)
+    self._quantizer.quantize(calib_result)
+    self.assertIsNotNone(self._quantizer._result.quantized_model)
+
+  def test_fully_connected_quantization_update_signature_defs_succeeds(self):
+
+    model_content = open(self._test_model_path, 'rb').read()
+    self.assertTrue(_is_all_signature_defs_inputs_float(model_content))
+    self.assertTrue(_is_all_signature_defs_outputs_float(model_content))
+
+    calib_result = self._quantizer.calibrate(
+        tfl_interpreter_utils.create_random_normal_input_data(
+            self._test_model_path, num_samples=4
+        )
+    )
+    self.assertIsNotNone(calib_result)
+    quant_result = self._quantizer.quantize(calib_result)
+    self.assertIsNotNone(quant_result.quantized_model)
+
+    self.assertTrue(
+        _is_all_signature_defs_inputs_float(quant_result.quantized_model)
+    )
+    self.assertTrue(
+        _is_all_signature_defs_outputs_float(quant_result.quantized_model)
+    )
+

 if __name__ == '__main__':
   googletest.main()
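The new tests gate on helpers that walk every signature and require float32 inputs and outputs. Outside the test suite, roughly the same check can be done against a serialized model with the LiteRT interpreter directly, much as the model modifier does; the snippet below is a sketch using only calls that appear in the hunks above:

import numpy as np
from ai_edge_litert import interpreter as tfl


def all_signature_io_is_float(model_content: bytes) -> bool:
  """Returns True if every signature input and output tensor is float32."""
  interp = tfl.Interpreter(model_content=bytes(model_content))
  for signature_key in interp.get_signature_list():
    runner = interp.get_signature_runner(signature_key)
    details = list(runner.get_input_details().values()) + list(
        runner.get_output_details().values()
    )
    if any(d['dtype'] != np.float32 for d in details):
      return False
  return True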
ai_edge_quantizer/utils/constrained_ops_utils_test.py CHANGED

@@ -28,7 +28,7 @@ class ConstrainedOpsUtilsTest(parameterized.TestCase):
       dict(
           testcase_name="same_as_input_scale",
           constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
-          expected_num_ops=
+          expected_num_ops=16,
       ),
       dict(
          testcase_name="same_as_output_scale",
ai_edge_quantizer/utils/tfl_flatbuffer_utils.py CHANGED

@@ -74,6 +74,7 @@ TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
     _TFLOpName.REDUCE_MIN: schema.BuiltinOperator.REDUCE_MIN,
     _TFLOpName.EQUAL: schema.BuiltinOperator.EQUAL,
     _TFLOpName.NOT_EQUAL: schema.BuiltinOperator.NOT_EQUAL,
+    _TFLOpName.MIRROR_PAD: schema.BuiltinOperator.MIRROR_PAD,
 })

 TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
{ai_edge_quantizer_nightly-0.4.0.dev20251001.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-quantizer-nightly
-Version: 0.4.0.dev20251001
+Version: 0.4.0.dev20251003
 Summary: A quantizer for advanced developers to quantize converted AI Edge models.
 Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
 Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
{ai_edge_quantizer_nightly-0.4.0.dev20251001.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info}/RECORD CHANGED

@@ -1,20 +1,20 @@
 ai_edge_quantizer/__init__.py,sha256=4pFSkukSwahYyzwqia0yPRyz8TnFQfGRthVJhYpMWas,793
-ai_edge_quantizer/algorithm_manager.py,sha256=
+ai_edge_quantizer/algorithm_manager.py,sha256=Ri4bNqbSTmtlsYZiJYHtkjNEsl8h5tZ_1uV_stJ3HUY,16156
 ai_edge_quantizer/algorithm_manager_api.py,sha256=u903TG0s1uIDhJqfeJne3CFl8A93phZrwgV2-hwdcXU,9247
 ai_edge_quantizer/algorithm_manager_api_test.py,sha256=w6bSONvXkX6bzXAGc0-7b6gNDt9oz9ieq97KP8Sg_JU,7666
 ai_edge_quantizer/calibrator.py,sha256=Sms7_AIHPH9G5xFaz5Ef3a5gPhxuIWQI8d2LUM8C96I,12071
 ai_edge_quantizer/calibrator_test.py,sha256=ZLzIMWB2FSFU4TOatDioYuwp_kLh8iSCefZ5_Q9FU7s,11900
 ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
-ai_edge_quantizer/default_policy.py,sha256=
-ai_edge_quantizer/model_modifier.py,sha256=
-ai_edge_quantizer/model_modifier_test.py,sha256=
+ai_edge_quantizer/default_policy.py,sha256=6eJA0eX5Npv8lw_0EDS5iPldInoURQKEDhDZ272VG1Q,11770
+ai_edge_quantizer/model_modifier.py,sha256=U70JByv6CItP8tg4bdyMfX-R3UlwylAGSviZkF_FSAM,10468
+ai_edge_quantizer/model_modifier_test.py,sha256=CV4pgMEQkBJr_qbYR720TO8HBCutbEYLHptDHgdQMUE,7274
 ai_edge_quantizer/model_validator.py,sha256=Hj0_5o-Oa3dSlJ3ryVjRhvsyelHNyek1GrtG9buMczg,13153
 ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
 ai_edge_quantizer/params_generator.py,sha256=0w-sDGk84sVNkXoduon1wDqq30sGOHVgBVbdg44QVF4,20153
 ai_edge_quantizer/params_generator_test.py,sha256=RDYoRZDJfEZRtjlTAU2kZ_4t3JHOqEHxfJX9V4ETAhg,40597
-ai_edge_quantizer/qtyping.py,sha256=
+ai_edge_quantizer/qtyping.py,sha256=7aEMPA4qr4CGD3NXtZgG2fDoQX5NzK9jwSv1yWNqQV4,17149
 ai_edge_quantizer/quantizer.py,sha256=ckAEOnnBxuCKZuvlzdChevCKPuE-IeDPHCNtFTWr250,17857
-ai_edge_quantizer/quantizer_test.py,sha256=
+ai_edge_quantizer/quantizer_test.py,sha256=bh4IowxRF249p_XKIKQ0f17PmeDddfcOUzvQ2ht1L0E,26530
 ai_edge_quantizer/recipe.py,sha256=MEkfQ2Sg3KAE9LAORHWcbjYNPg06EUbwc1d-VspQA2U,6461
 ai_edge_quantizer/recipe_manager.py,sha256=6l2uq8KL23KLu9OQDmPGkxrFiwHrdDB9xnn-ni8WdEM,15036
 ai_edge_quantizer/recipe_manager_test.py,sha256=qjgGUF-wggXnSXqZ5khmqrDMIQI5CShk52IVWTahq6s,36817
@@ -28,7 +28,7 @@ ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCP
 ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py,sha256=Bs9CK7wZAw6jNaZ8xEtbwO2vM34VYXNZSMVWvxJo9nw,9297
 ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=EqIHGEZ1LgUrTN7zf880RuAzEv3Qy7kgh5ivObJGHSo,22646
 ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
-ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=
+ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=rUHraqNi0iJ0AyUQfAiYkyXG2rSy7GlhnqmwJeoLStg,38952
 ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=GGf_n3wIeg3GB_eGsmyNJ0fTcxgpeMMbugTMRONK6TQ,3553
 ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=BDdn_uBZakfHyzdMJPKadsOqxqyC-s6W2ZzFH99L4fE,8652
 ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=sT5eX5TLZEHTtPfnSkCPDlS0sQxlTFWbCsbvOuj--yY,8889
@@ -66,16 +66,16 @@ ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V
 ai_edge_quantizer/utils/calibration_utils.py,sha256=iMf_bSCf-O86MzDt5D9hLKqbTydqLwirluaC6BJ9yHo,11553
 ai_edge_quantizer/utils/calibration_utils_test.py,sha256=4BlksXl7b4yptL8xPR67hmJCnjhN9V10a2PunzfHrUE,9372
 ai_edge_quantizer/utils/constrained_ops_utils.py,sha256=EAITCf7Ku_PFZcw3K-wd-8hGbyuRd5W5UtNdGvalwAE,4478
-ai_edge_quantizer/utils/constrained_ops_utils_test.py,sha256=
+ai_edge_quantizer/utils/constrained_ops_utils_test.py,sha256=HNZstSm6-7xgSmM-7ilHRjuOKsq6tivpxayphm9Oghs,1756
 ai_edge_quantizer/utils/test_utils.py,sha256=a4Nk-wbeB09dFjTDZiA0K67d26j5DD0UDH_GIVmVG_4,8685
-ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=
+ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=LN-WonrcJLP9bB4lULd5VIg_8YLTcp891ZuDZ5nDGe8,12006
 ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
 ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=EoVjI_hplX_Rml3hfRsGmQOihexmizeJqt4SQcET9aA,14925
 ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
 ai_edge_quantizer/utils/validation_utils.py,sha256=yJH9Cvepr_XWn-3Hsh91j7HuC5iLQHAyskyQ48bGNoc,4797
 ai_edge_quantizer/utils/validation_utils_test.py,sha256=1sblJWHLTYTbn1Qi9rwnrREOSXRy5KwHAWSwgI1e_aU,3697
-ai_edge_quantizer_nightly-0.4.0.
-ai_edge_quantizer_nightly-0.4.0.
-ai_edge_quantizer_nightly-0.4.0.
-ai_edge_quantizer_nightly-0.4.0.
-ai_edge_quantizer_nightly-0.4.0.
+ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info/METADATA,sha256=4JgzYaleb7eOcjr1cR3DHEefcppO47fcd1Rt4YBEnsY,1508
+ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
+ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info/RECORD,,