ai-edge-quantizer-nightly 0.4.0.dev20251002__py3-none-any.whl → 0.4.0.dev20251003__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,6 +17,7 @@
17
17
 
18
18
  from collections.abc import Sequence
19
19
  import copy
20
+ import logging
20
21
 
21
22
  import numpy as np
22
23
 
@@ -24,10 +25,15 @@ from ai_edge_quantizer import qtyping
24
25
  from ai_edge_quantizer import transformation_instruction_generator
25
26
  from ai_edge_quantizer import transformation_performer
26
27
  from ai_edge_quantizer.utils import tfl_flatbuffer_utils
28
+ from ai_edge_quantizer.utils import tfl_interpreter_utils
29
+ from ai_edge_litert import interpreter as tfl # pylint: disable=g-direct-tensorflow-import
27
30
  from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
28
31
  from tensorflow.lite.tools import flatbuffer_utils # pylint: disable=g-direct-tensorflow-import
29
32
 
30
33
 
34
+ _DEQUANT_SUFFIX = "_dequant"
35
+
36
+
31
37
  class ModelModifier:
32
38
  """Model Modifier class that produce the final quantized TFlite model."""
33
39
 
@@ -105,10 +111,94 @@ class ModelModifier:
105
111
  )
106
112
  constant_buffer_size = self._process_constant_map(quantized_model)
107
113
  # we leave 256MB for the model architecture.
108
- if constant_buffer_size > 2**31 - 2**28:
109
- return self._serialize_large_model(quantized_model)
110
- else:
111
- return self._serialize_small_model(quantized_model)
114
+ serialize_fun = (
115
+ self._serialize_large_model
116
+ if constant_buffer_size > 2**31 - 2**28
117
+ else self._serialize_small_model
118
+ )
119
+ serialized_quantized_model = serialize_fun(quantized_model)
120
+
121
+ # Update signature defs if dequant is inserted before output.
122
+ if self._has_dequant_before_output(instructions):
123
+ quantized_model = self._update_signature_defs_for_dequant_output(
124
+ quantized_model, serialized_quantized_model
125
+ )
126
+ serialized_quantized_model = serialize_fun(quantized_model)
127
+
128
+ return serialized_quantized_model
129
+
130
+ def _update_signature_defs_for_dequant_output(
131
+ self, model: schema_py_generated.ModelT, serialized_model: bytearray
132
+ ):
133
+ """Updates the signature definitions in the model.
134
+
135
+ This function is called when a dequantize operation is inserted before
136
+ an output tensor. It updates the tensor index in the signature
137
+ definitions to point to the newly inserted dequantize output tensor.
138
+
139
+ Args:
140
+ model: The TFlite ModelT object.
141
+ serialized_model: The serialized bytearray of the TFlite model.
142
+
143
+ Returns:
144
+ The updated TFlite ModelT object.
145
+ """
146
+ interpreter = tfl.Interpreter(model_content=bytes(serialized_model))
147
+
148
+ for signature_def in model.signatureDefs:
149
+ signature_key = signature_def.signatureKey.decode("utf-8")
150
+ logging.info("Signature = %s", signature_key)
151
+ subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index(
152
+ interpreter, signature_key
153
+ )
154
+ output_details = interpreter.get_signature_runner(
155
+ signature_key
156
+ ).get_output_details()
157
+ subgraph = model.subgraphs[subgraph_idx]
158
+ graph_info = qtyping.GraphInfo(subgraph.tensors, model.buffers)
159
+
160
+ for output in subgraph.outputs:
161
+ tensor_name = tfl_flatbuffer_utils.get_tensor_name(
162
+ graph_info.subgraph_tensors[output]
163
+ )
164
+ logging.info("\tOutput tensor = `%s`", tensor_name)
165
+
166
+ for signature_name, tensor_details in output_details.items():
167
+ if tensor_details["name"] + _DEQUANT_SUFFIX == tensor_name:
168
+ logging.info(
169
+ "\t\tfound tensor mapping: `%s`->`%s` for signature name: `%s`",
170
+ tensor_details["name"],
171
+ tensor_name,
172
+ signature_name,
173
+ )
174
+ for signature_item in signature_def.outputs:
175
+ if signature_item.name.decode("utf-8") == signature_name:
176
+ signature_item.tensorIndex = output
177
+ logging.info(
178
+ "\t\t\tswapped tensor index: %s->%s",
179
+ tensor_details["index"],
180
+ output,
181
+ )
182
+ break
183
+ break
184
+
185
+ return model
186
+
187
+ def _has_dequant_before_output(
188
+ self, instructions: dict[str, qtyping.TensorTransformationInsts]
189
+ ) -> bool:
190
+ """Check if the model has dequant insert to output."""
191
+ for tensor_name, tensor_trans_insts in instructions.items():
192
+ for instr in tensor_trans_insts.instructions:
193
+ if (
194
+ qtyping.QuantTransformation.ADD_DEQUANTIZE == instr.transformation
195
+ and instr.consumers == [-1]
196
+ ):
197
+ logging.info(
198
+ "Found dequant insert to output for tensor: %s", tensor_name
199
+ )
200
+ return True
201
+ return False
112
202
 
113
203
  def _process_constant_map(
114
204
  self, quantized_model: schema_py_generated.ModelT
@@ -142,7 +232,7 @@ class ModelModifier:
142
232
  remainder = len(bytearr) % 16
143
233
  if remainder != 0:
144
234
  padding_size = 16 - remainder
145
- bytearr.extend(b'\0' * padding_size)
235
+ bytearr.extend(b"\0" * padding_size)
146
236
 
147
237
  # TODO: b/333797307 - support > 2GB output model
148
238
  def _serialize_large_model(
@@ -125,6 +125,86 @@ class ModelModifierTest(parameterized.TestCase):
125
125
  loosen_mem_use_factor = 4.5
126
126
  self.assertLess(mem_peak / len(self._model_content), loosen_mem_use_factor)
127
127
 
128
+ def test_has_dequant_before_output_true(self):
129
+ instructions = {
130
+ 'tensor1': qtyping.TensorTransformationInsts(
131
+ 'tensor1',
132
+ 0,
133
+ instructions=[
134
+ qtyping.TransformationInst(
135
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
136
+ tensor_id=0,
137
+ producer=0,
138
+ consumers=[-1],
139
+ )
140
+ ],
141
+ )
142
+ }
143
+ self.assertTrue(
144
+ self._model_modifier._has_dequant_before_output(instructions)
145
+ )
146
+
147
+ def test_has_dequant_before_output_false(self):
148
+ instructions = {
149
+ 'tensor1': qtyping.TensorTransformationInsts(
150
+ 'tensor1',
151
+ 0,
152
+ instructions=[
153
+ qtyping.TransformationInst(
154
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
155
+ tensor_id=0,
156
+ producer=0,
157
+ consumers=[1],
158
+ )
159
+ ],
160
+ )
161
+ }
162
+ self.assertFalse(
163
+ self._model_modifier._has_dequant_before_output(instructions)
164
+ )
165
+
166
+ def test_pad_bytearray(self):
167
+ arr = bytearray(b'\x01\x02\x03')
168
+ self._model_modifier._pad_bytearray(arr)
169
+ self.assertLen(arr, 16)
170
+ self.assertEqual(arr, b'\x01\x02\x03' + b'\0' * 13)
171
+
172
+ arr = bytearray(b'\x01' * 16)
173
+ self._model_modifier._pad_bytearray(arr)
174
+ self.assertLen(arr, 16)
175
+
176
+ arr = bytearray(b'\x01' * 17)
177
+ self._model_modifier._pad_bytearray(arr)
178
+ self.assertLen(arr, 32)
179
+
180
+
181
+ class ModelModifierTestWithSignature(parameterized.TestCase):
182
+
183
+ def setUp(self):
184
+ super().setUp()
185
+ self._model_path = os.path.join(
186
+ TEST_DATA_PREFIX_PATH,
187
+ 'tests/models/single_fc.tflite',
188
+ )
189
+ self._model_content: bytes = tfl_flatbuffer_utils.get_model_content(
190
+ self._model_path
191
+ )
192
+ self._model_modifier = model_modifier.ModelModifier(self._model_content)
193
+
194
+ def test_update_signature_defs_for_dequant_output_succeeds(self):
195
+ # This is a simplified test that only checks if the function runs without
196
+ # crashing and returns a model. A more thorough test with a model
197
+ # with a known signature was added in `quantizer_test`.
198
+ model_bytearray = flatbuffer_utils.read_model_from_bytearray(
199
+ self._model_content
200
+ )
201
+ updated_model = (
202
+ self._model_modifier._update_signature_defs_for_dequant_output(
203
+ model_bytearray, bytearray(self._model_content)
204
+ )
205
+ )
206
+ self.assertIsNotNone(updated_model)
207
+
128
208
 
129
209
  if __name__ == '__main__':
130
210
  googletest.main()
@@ -51,6 +51,30 @@ def _get_calibration_data(num_samples: int = 16):
51
51
  return calibration_data
52
52
 
53
53
 
54
+ def _is_all_signature_defs_inputs_float(model_content: bytes):
55
+ tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_content)
56
+ for signature_key in tfl_interpreter.get_signature_list():
57
+ input_details = tfl_interpreter.get_signature_runner(
58
+ signature_key
59
+ ).get_input_details()
60
+ for tensor_details in input_details.values():
61
+ if tensor_details['dtype'] != np.float32:
62
+ return False
63
+ return True
64
+
65
+
66
+ def _is_all_signature_defs_outputs_float(model_content: bytes):
67
+ tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_content)
68
+ for signature_key in tfl_interpreter.get_signature_list():
69
+ output_details = tfl_interpreter.get_signature_runner(
70
+ signature_key
71
+ ).get_output_details()
72
+ for tensor_details in output_details.values():
73
+ if tensor_details['dtype'] != np.float32:
74
+ return False
75
+ return True
76
+
77
+
54
78
  class QuantizerTest(parameterized.TestCase):
55
79
 
56
80
  def setUp(self):
@@ -547,21 +571,21 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
547
571
  'signature_1': [{
548
572
  'cache_0': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
549
573
  'cache_1': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
550
- 'positions': _RNG.integers(low=0, high=10, size=(1, 100)).astype(
551
- np.int32
574
+ 'positions': (
575
+ _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
552
576
  ),
553
- 'tokens': _RNG.integers(low=0, high=10, size=(1, 100)).astype(
554
- np.int32
577
+ 'tokens': (
578
+ _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
555
579
  ),
556
580
  }],
557
581
  'signature_2': [{
558
582
  'cache_0': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
559
583
  'cache_1': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
560
- 'positions': _RNG.integers(low=0, high=10, size=(1, 100)).astype(
561
- np.int32
584
+ 'positions': (
585
+ _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
562
586
  ),
563
- 'tokens': _RNG.integers(low=0, high=10, size=(1, 100)).astype(
564
- np.int32
587
+ 'tokens': (
588
+ _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
565
589
  ),
566
590
  }],
567
591
  }
@@ -578,8 +602,8 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
578
602
  )
579
603
 
580
604
  self._quantizer.update_quantization_recipe(
581
- regex='StatefulPartitionedCall',
582
- operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
605
+ regex='.*',
606
+ operation_name=qtyping.TFLOperationName.OUTPUT,
583
607
  algorithm_key=_AlgorithmName.NO_QUANTIZE,
584
608
  )
585
609
 
@@ -591,6 +615,90 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
591
615
  self._quantizer.quantize(calib_result)
592
616
  self.assertIsNotNone(self._quantizer._result.quantized_model)
593
617
 
618
+ def test_toy_gemma2_update_signature_defs_succeeds(self):
619
+
620
+ self.assertTrue(
621
+ _is_all_signature_defs_outputs_float(
622
+ open(self._test_model_path, 'rb').read()
623
+ )
624
+ )
625
+ calib_result = self._quantizer.calibrate(
626
+ self._toy_gemma2_calibration_dataset
627
+ )
628
+ self.assertIsNotNone(calib_result)
629
+ self._quantizer.quantize(calib_result)
630
+ self.assertIsNotNone(self._quantizer._result.quantized_model)
631
+ self.assertTrue(
632
+ _is_all_signature_defs_outputs_float(
633
+ self._quantizer._result.quantized_model
634
+ )
635
+ )
636
+
637
+
638
+ class QuantizerFullyConnectedTest(parameterized.TestCase):
639
+
640
+ def setUp(self):
641
+ super().setUp()
642
+ self._tmp_save_path = self.create_tempdir().full_path
643
+ self._test_model_path = os.path.join(
644
+ TEST_DATA_PREFIX_PATH,
645
+ 'tests/models/single_fc.tflite',
646
+ )
647
+
648
+ self._test_recipe_path = os.path.join(
649
+ TEST_DATA_PREFIX_PATH,
650
+ 'recipes/default_a8w8_recipe.json',
651
+ )
652
+ with open(self._test_recipe_path) as json_file:
653
+ self._test_recipe = json.load(json_file)
654
+
655
+ self._quantizer = quantizer.Quantizer(
656
+ self._test_model_path, self._test_recipe_path
657
+ )
658
+
659
+ self._quantizer.update_quantization_recipe(
660
+ regex='.*',
661
+ operation_name=qtyping.TFLOperationName.INPUT,
662
+ algorithm_key=_AlgorithmName.NO_QUANTIZE,
663
+ )
664
+ self._quantizer.update_quantization_recipe(
665
+ regex='.*',
666
+ operation_name=qtyping.TFLOperationName.OUTPUT,
667
+ algorithm_key=_AlgorithmName.NO_QUANTIZE,
668
+ )
669
+
670
+ def test_fully_connected_quantization_succeeds(self):
671
+ calib_result = self._quantizer.calibrate(
672
+ tfl_interpreter_utils.create_random_normal_input_data(
673
+ self._test_model_path, num_samples=4
674
+ )
675
+ )
676
+ self.assertIsNotNone(calib_result)
677
+ self._quantizer.quantize(calib_result)
678
+ self.assertIsNotNone(self._quantizer._result.quantized_model)
679
+
680
+ def test_fully_connected_quantization_update_signature_defs_succeeds(self):
681
+
682
+ model_content = open(self._test_model_path, 'rb').read()
683
+ self.assertTrue(_is_all_signature_defs_inputs_float(model_content))
684
+ self.assertTrue(_is_all_signature_defs_outputs_float(model_content))
685
+
686
+ calib_result = self._quantizer.calibrate(
687
+ tfl_interpreter_utils.create_random_normal_input_data(
688
+ self._test_model_path, num_samples=4
689
+ )
690
+ )
691
+ self.assertIsNotNone(calib_result)
692
+ quant_result = self._quantizer.quantize(calib_result)
693
+ self.assertIsNotNone(quant_result.quantized_model)
694
+
695
+ self.assertTrue(
696
+ _is_all_signature_defs_inputs_float(quant_result.quantized_model)
697
+ )
698
+ self.assertTrue(
699
+ _is_all_signature_defs_outputs_float(quant_result.quantized_model)
700
+ )
701
+
594
702
 
595
703
  if __name__ == '__main__':
596
704
  googletest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.4.0.dev20251002
3
+ Version: 0.4.0.dev20251003
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -6,15 +6,15 @@ ai_edge_quantizer/calibrator.py,sha256=Sms7_AIHPH9G5xFaz5Ef3a5gPhxuIWQI8d2LUM8C9
6
6
  ai_edge_quantizer/calibrator_test.py,sha256=ZLzIMWB2FSFU4TOatDioYuwp_kLh8iSCefZ5_Q9FU7s,11900
7
7
  ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
8
8
  ai_edge_quantizer/default_policy.py,sha256=6eJA0eX5Npv8lw_0EDS5iPldInoURQKEDhDZ272VG1Q,11770
9
- ai_edge_quantizer/model_modifier.py,sha256=teGa8I6kGvn6TQY6Xv53YFIc_pQEhNvM9Zb4bvhezyw,7110
10
- ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
9
+ ai_edge_quantizer/model_modifier.py,sha256=U70JByv6CItP8tg4bdyMfX-R3UlwylAGSviZkF_FSAM,10468
10
+ ai_edge_quantizer/model_modifier_test.py,sha256=CV4pgMEQkBJr_qbYR720TO8HBCutbEYLHptDHgdQMUE,7274
11
11
  ai_edge_quantizer/model_validator.py,sha256=Hj0_5o-Oa3dSlJ3ryVjRhvsyelHNyek1GrtG9buMczg,13153
12
12
  ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
13
13
  ai_edge_quantizer/params_generator.py,sha256=0w-sDGk84sVNkXoduon1wDqq30sGOHVgBVbdg44QVF4,20153
14
14
  ai_edge_quantizer/params_generator_test.py,sha256=RDYoRZDJfEZRtjlTAU2kZ_4t3JHOqEHxfJX9V4ETAhg,40597
15
15
  ai_edge_quantizer/qtyping.py,sha256=7aEMPA4qr4CGD3NXtZgG2fDoQX5NzK9jwSv1yWNqQV4,17149
16
16
  ai_edge_quantizer/quantizer.py,sha256=ckAEOnnBxuCKZuvlzdChevCKPuE-IeDPHCNtFTWr250,17857
17
- ai_edge_quantizer/quantizer_test.py,sha256=m6f4ayyaF3yQb9i4V0aFAbmGw0OKZ2Zam1RoTPh-u24,22917
17
+ ai_edge_quantizer/quantizer_test.py,sha256=bh4IowxRF249p_XKIKQ0f17PmeDddfcOUzvQ2ht1L0E,26530
18
18
  ai_edge_quantizer/recipe.py,sha256=MEkfQ2Sg3KAE9LAORHWcbjYNPg06EUbwc1d-VspQA2U,6461
19
19
  ai_edge_quantizer/recipe_manager.py,sha256=6l2uq8KL23KLu9OQDmPGkxrFiwHrdDB9xnn-ni8WdEM,15036
20
20
  ai_edge_quantizer/recipe_manager_test.py,sha256=qjgGUF-wggXnSXqZ5khmqrDMIQI5CShk52IVWTahq6s,36817
@@ -74,8 +74,8 @@ ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=EoVjI_hplX_Rml3hfRsGmQOi
74
74
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
75
75
  ai_edge_quantizer/utils/validation_utils.py,sha256=yJH9Cvepr_XWn-3Hsh91j7HuC5iLQHAyskyQ48bGNoc,4797
76
76
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=1sblJWHLTYTbn1Qi9rwnrREOSXRy5KwHAWSwgI1e_aU,3697
77
- ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
78
- ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info/METADATA,sha256=gx_gBIYVh7XDUrBl-uDmPRRRrawHIroH_14pjZmhL4w,1508
79
- ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
80
- ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
81
- ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info/RECORD,,
77
+ ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
78
+ ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info/METADATA,sha256=4JgzYaleb7eOcjr1cR3DHEefcppO47fcd1Rt4YBEnsY,1508
79
+ ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
80
+ ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
81
+ ai_edge_quantizer_nightly-0.4.0.dev20251003.dist-info/RECORD,,