ai_edge_quantizer_nightly-0.3.0.dev20250622-py3-none-any.whl → ai_edge_quantizer_nightly-0.3.0.dev20250624-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
@@ -113,6 +113,7 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
      ),
      _TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
      _TFLOpName.RESIZE_BILINEAR: common_quantize.materialize_resize_bilinear,
+     _TFLOpName.GATHER_ND: common_quantize.materialize_gather_nd,
  }
  for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
    register_quantized_op(
@@ -252,6 +253,7 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
      ),
      _TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
      _TFLOpName.RESIZE_BILINEAR: common_quantize.materialize_resize_bilinear,
+     _TFLOpName.GATHER_ND: common_quantize.materialize_gather_nd,
  })

  for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
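Both tables above feed the same registration loop, so a single dict entry is all that is needed to expose GATHER_ND to both the min/max and OCTAV algorithms. A minimal sketch of this dict-driven registration pattern follows; the names and signature are illustrative toys, not the package's real `register_quantized_op`, which takes additional calibration and materialization hooks:

```python
from typing import Callable

# Toy registry illustrating the pattern used above.
_REGISTRY: dict[tuple[str, str], Callable[..., list]] = {}

def register_quantized_op(
    algorithm: str, op_name: str, materialize_func: Callable[..., list]
) -> None:
  """Maps (algorithm, op) to the function that quantizes that op's tensors."""
  _REGISTRY[(algorithm, op_name)] = materialize_func

MATERIALIZE_FUNCS = {"GATHER_ND": lambda *a, **kw: []}  # stand-in function
for op_name, func in MATERIALIZE_FUNCS.items():
  register_quantized_op("MIN_MAX_UNIFORM_QUANT", op_name, func)
```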
@@ -745,6 +745,23 @@ def materialize_resize_bilinear(
  )


+ def materialize_gather_nd(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.gather_nd."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+       inputs_to_ignore=[1],  # Gather indices do not need to be quantized.
+   )
+
+
  def _get_tensor_shape_for_blockwise(
      tensor_shape: Sequence[int], quantized_dim: int, block_size: int
  ) -> list[int]:
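The `SAME_AS_INPUT_SCALE` constraint and the skipped index input follow directly from what gather_nd does: it only copies elements, so quantized values pass through untouched and the integer indices never need a scale or zero point. A toy numpy illustration of this (not quantizer code):

```python
import numpy as np

scale, zero_point = 0.05, 0
data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
q = np.clip(np.round(data / scale) + zero_point, -128, 127).astype(np.int8)

# Indices stay plain integers (input 1 is ignored by the quantizer above).
indices = np.array([[1, 0], [0, 1]])
gathered_q = q[indices[:, 0], indices[:, 1]]  # gather in the integer domain

# Dequantizing with the *input* scale recovers the gathered float values,
# which is why the output can simply reuse the input quantization params.
print((gathered_q.astype(np.float32) - zero_point) * scale)  # [3. 2.]
```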
@@ -293,17 +293,46 @@ def _materialize_standard_op_with_same_as_input_scale(
        get_tensor_quant_params_fn=get_tensor_quant_params_fn,
    )
    op_tensor_params.append(input_tensor_params)
-   # Use input quantization params for all output tensors.
-   _materialize_op_tensors(
-       op_tensor_params,
-       output_tensors,
-       is_inbounding_tensor=False,
-       op_info=op_info,
-       graph_info=graph_info,
-       tensor_name_to_qsv=tensor_name_to_qsv,
-       get_tensor_quant_params_fn=get_tensor_quant_params_fn,
-       quant_params=input_tensor_params.consumers[0].parameters,
+   # Use input quantization params for all output tensors but without
+   # quantized_data in case the input is a constant tensor.
+   input_quant_params = dataclasses.replace(
+       input_tensor_params.consumers[0].parameters,
+       quantized_data=None,
    )
+   if not isinstance(input_quant_params, qtyping.UniformQuantParams):
+     raise ValueError(
+         "_materialize_standard_op_with_same_as_input_scale only supports"
+         f" UniformQuantParams. For tensor {input_tensor_params.tensor_name},"
+         f" got {type(input_quant_params)}"
+     )
+   # Materialize each of the output tensors separately in case there are
+   # constants among them, requiring updating `quantized_data` first.
+   for output_tensor in output_tensors:
+     output_tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+         output_tensor, graph_info.buffers
+     )
+     # Quantize constant outputs' data with the input quantization params.
+     if output_tensor_data is None:
+       quant_params = input_quant_params
+     else:
+       quantized_data = uniform_quantize_tensor.uniform_quantize(
+           output_tensor_data, input_quant_params
+       )
+       quant_params = dataclasses.replace(
+           input_quant_params,
+           quantized_data=quantized_data,
+       )
+     _materialize_op_tensors(
+         op_tensor_params,
+         [output_tensor],
+         is_inbounding_tensor=False,
+         op_info=op_info,
+         graph_info=graph_info,
+         tensor_name_to_qsv=tensor_name_to_qsv,
+         get_tensor_quant_params_fn=get_tensor_quant_params_fn,
+         quant_params=quant_params,
+     )
+
    # Change output qsv to be the same as input qsv. This is safe since TFL
    # subgraph is acyclic.
    input_tensor_qsv = tensor_name_to_qsv[input_tensor_params.tensor_name]
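The rework hinges on `dataclasses.replace`: outputs inherit the input's scale and zero point, but any `quantized_data` cached on the input params must not leak through, and constant outputs then get freshly quantized data of their own. A self-contained sketch of that copy-then-patch pattern, with a toy dataclass standing in for `qtyping.UniformQuantParams`:

```python
import dataclasses
from typing import Optional

import numpy as np

@dataclasses.dataclass(frozen=True)
class ToyQuantParams:  # stand-in for qtyping.UniformQuantParams
  num_bits: int
  scale: np.ndarray
  zero_point: np.ndarray
  quantized_data: Optional[np.ndarray] = None

input_params = ToyQuantParams(
    num_bits=8,
    scale=np.array([0.1]),
    zero_point=np.array([0]),
    quantized_data=np.array([42], dtype=np.int8),  # stale constant data
)

# Copy the input params for the outputs, dropping the cached data...
output_params = dataclasses.replace(input_params, quantized_data=None)

# ...then re-attach data only for outputs that are themselves constants.
const_output = np.array([1.5], dtype=np.float32)
q = np.round(const_output / input_params.scale).astype(np.int8)
const_output_params = dataclasses.replace(output_params, quantized_data=q)
```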
@@ -841,13 +870,6 @@ def get_tensor_transformations(
        transformations = [_QuantTransformation.QUANTIZE_TENSOR]
      else:
        transformations = [_QuantTransformation.NO_QUANTIZE]
-   elif (
-       op_quant_config.weight_tensor_config is not None
-       and op_quant_config.weight_tensor_config.granularity
-       == qtyping.QuantGranularity.BLOCKWISE
-       and is_constant
-   ):
-     transformations = [_QuantTransformation.EMULATED_SUBCHANNEL]
    # Check if WEIGHT_ONLY.
    elif (
        op_quant_config.compute_precision == qtyping.ComputePrecision.FLOAT
@@ -186,7 +186,8 @@ DEFAULT_JSON_POLICY = """
      "STABLEHLO_COMPOSITE",
      "PAD",
      "MAX_POOL_2D",
-     "RESIZE_BILINEAR"
+     "RESIZE_BILINEAR",
+     "GATHER_ND"
    ],
    "static_wi8_ai8": [
      "ADD",
@@ -221,7 +222,8 @@ DEFAULT_JSON_POLICY = """
      "PAD",
      "SQUARED_DIFFERENCE",
      "MAX_POOL_2D",
-     "RESIZE_BILINEAR"
+     "RESIZE_BILINEAR",
+     "GATHER_ND"
    ],
    "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
    "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
@@ -66,6 +66,7 @@ class TFLOperationName(str, enum.Enum):
    SQUARED_DIFFERENCE = 'SQUARED_DIFFERENCE'
    MAX_POOL_2D = 'MAX_POOL_2D'
    RESIZE_BILINEAR = 'RESIZE_BILINEAR'
+   GATHER_ND = 'GATHER_ND'


  class QuantizeMode(enum.Enum):
@@ -110,8 +111,8 @@ class QuantTransformation(enum.Enum):
    ADD_DEQUANTIZE = 2
    # Quantize the float tensor: float_tensor -> quantized_tensor.
    QUANTIZE_TENSOR = 3
-   # Create pattern for emulated subchannel quantization, only support fully
-   # connected op.
+   # (Deprecated) Create pattern for emulated subchannel quantization,
+   # only support fully connected op.
    EMULATED_SUBCHANNEL = 4
    # Duplicate the buffer.
    DUPLICATE_BUFFER = 5
@@ -673,7 +673,6 @@ class TransformationInstructionsGenerator:
    """
    is_tensor_unquantized = False
    is_tensor_quantized = False
-   is_operator_emulated = False
    for instruction in instructions:
      transform_type = instruction.transformation
      if transform_type == qtyping.QuantTransformation.NO_QUANTIZE:
@@ -683,17 +682,10 @@ class TransformationInstructionsGenerator:
          or transform_type == qtyping.QuantTransformation.ADD_DEQUANTIZE
      ):
        is_tensor_quantized = True
-     elif transform_type == qtyping.QuantTransformation.EMULATED_SUBCHANNEL:
-       is_operator_emulated = True
    if is_tensor_unquantized and is_tensor_quantized:
      raise ValueError(
          "Tensor %s can not be both quantized and unquantized" % tensor_name
      )
-   if is_operator_emulated and len(instructions) > 1:
-     raise ValueError(
-         "Tensor %s : op replacement transformation can not be combined with"
-         " other transformations." % tensor_name
-     )

  def _check_tensor_transformation_instructions_valid(
      self,
@@ -953,33 +953,6 @@ class InstructionGeneratorTest(parameterized.TestCase):
        instructions["StatefulPartitionedCall:0"], output_transformation
    )

-   def test_raise_error_on_op_replacement_transformation_is_not_unique(self):
-     test_model_path = os.path.join(
-         TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
-     )
-     quant_parameters = {}
-     quant_parameters["tfl.quantize"] = qtyping.TensorTransformationParams(
-         "tfl.quantize",
-         qtyping.OpToTensorParams(
-             subgraph_op_id=0,
-             transformations=[
-                 qtyping.QuantTransformation.ADD_DEQUANTIZE,
-                 qtyping.QuantTransformation.EMULATED_SUBCHANNEL,
-             ],
-             parameters=qtyping.UniformQuantParams(
-                 8, None, np.array([1]), np.array([0])
-             ),
-         ),
-         [],
-     )
-     ins_gen = instruction_generator.TransformationInstructionsGenerator(
-         test_model_path
-     )
-     with self.assertRaisesRegex(
-         ValueError, "op replacement transformation can not be combined"
-     ):
-       ins_gen.quant_params_to_transformation_insts(quant_parameters)
-
    def test_raise_error_on_no_quant_conflict(self):
      test_model_path = os.path.join(
          TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
@@ -24,7 +24,6 @@ from ai_edge_quantizer import qtyping
  from ai_edge_quantizer.transformations import dequant_insert
  from ai_edge_quantizer.transformations import duplicate_buffer
  from ai_edge_quantizer.transformations import duplicate_tensor
- from ai_edge_quantizer.transformations import emulated_subchannel
  from ai_edge_quantizer.transformations import insert_hadamard_rotation
  from ai_edge_quantizer.transformations import quant_insert
  from ai_edge_quantizer.transformations import quantize_tensor
@@ -72,7 +71,7 @@ class TransformationPerformer:
          quantize_tensor.quantize_tensor
      ),
      qtyping.QuantTransformation.EMULATED_SUBCHANNEL: (
-         emulated_subchannel.emulated_subchannel
+         transformation_utils.raise_deprecated_error
      ),
      qtyping.QuantTransformation.ADD_QUANTIZE: quant_insert.insert_quant,
      qtyping.QuantTransformation.DUPLICATE_BUFFER: (
@@ -203,3 +203,10 @@ def add_new_activation_tensor(
    new_tensor_id = len(subgraph.tensors)
    subgraph.tensors.append(new_tensor)
    return new_tensor_id
+
+
+ def raise_deprecated_error(_: TransformationInput):
+   raise NotImplementedError(
+       'This transformation is deprecated. Please contact AI Edge Quantizer team'
+       ' if you see this error.'
+   )
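Routing the retired enum value to `raise_deprecated_error` keeps the dispatch table total: old recipes that still request EMULATED_SUBCHANNEL fail with a clear message instead of a KeyError. A minimal sketch of the pattern, with toy names:

```python
import enum

class Transform(enum.Enum):
  QUANTIZE_TENSOR = 1
  EMULATED_SUBCHANNEL = 2  # enum value kept so old recipes still parse

def _raise_deprecated_error(_):
  raise NotImplementedError("EMULATED_SUBCHANNEL is deprecated.")

DISPATCH = {
    Transform.QUANTIZE_TENSOR: lambda t: t,  # placeholder for a real handler
    Transform.EMULATED_SUBCHANNEL: _raise_deprecated_error,
}

DISPATCH[Transform.QUANTIZE_TENSOR]("tensor")    # fine
# DISPATCH[Transform.EMULATED_SUBCHANNEL](None)  # raises NotImplementedError
```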
@@ -60,6 +60,7 @@ TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
      _TFLOpName.SQUARED_DIFFERENCE: schema.BuiltinOperator.SQUARED_DIFFERENCE,
      _TFLOpName.MAX_POOL_2D: schema.BuiltinOperator.MAX_POOL_2D,
      _TFLOpName.RESIZE_BILINEAR: schema.BuiltinOperator.RESIZE_BILINEAR,
+     _TFLOpName.GATHER_ND: schema.BuiltinOperator.GATHER_ND,
  })

  TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
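The trailing context line shows that a reverse map, TFL_OP_CODE_TO_NAME, is derived from this dict, so one new entry updates both directions. A sketch of that derivation under the assumption that the reverse map is a simple inversion; the numeric builtin code shown is illustrative, the real table uses the TFLite schema enums:

```python
from immutabledict import immutabledict

# Forward map; 107 is believed to be GATHER_ND's builtin code in the TFLite
# schema, but the value here is for illustration only.
TFL_OP_NAME_TO_CODE = immutabledict({"GATHER_ND": 107})

# Reverse map derived by inverting the forward one.
TFL_OP_CODE_TO_NAME = immutabledict(
    {code: name for name, code in TFL_OP_NAME_TO_CODE.items()}
)
```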
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-quantizer-nightly
- Version: 0.3.0.dev20250622
+ Version: 0.3.0.dev20250624
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -1,34 +1,34 @@
  ai_edge_quantizer/__init__.py,sha256=4pFSkukSwahYyzwqia0yPRyz8TnFQfGRthVJhYpMWas,793
- ai_edge_quantizer/algorithm_manager.py,sha256=rMTM89YDPkmLKlUQV_Rjr7B2KpcvldAHzfpgUqaOqdU,12216
+ ai_edge_quantizer/algorithm_manager.py,sha256=UZVS6ZClIAyaX9RzhXvvymbQv_scR0ybMPYl2CgSPVo,12346
  ai_edge_quantizer/algorithm_manager_api.py,sha256=u903TG0s1uIDhJqfeJne3CFl8A93phZrwgV2-hwdcXU,9247
  ai_edge_quantizer/algorithm_manager_api_test.py,sha256=w6bSONvXkX6bzXAGc0-7b6gNDt9oz9ieq97KP8Sg_JU,7666
  ai_edge_quantizer/calibrator.py,sha256=Sms7_AIHPH9G5xFaz5Ef3a5gPhxuIWQI8d2LUM8C96I,12071
  ai_edge_quantizer/calibrator_test.py,sha256=C_oWOaRugPKYX74jF-eRFH-k6nGOdA8I9_uPiocaOuE,11900
  ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
- ai_edge_quantizer/default_policy.py,sha256=zghBh9dTB-ouPFumV-0siBSnEbp0WxF6tGOsn3TLirg,11242
+ ai_edge_quantizer/default_policy.py,sha256=0Am2TrgyV7gNl7dbul07rVp58OKDuPyJW9SIqRTrD2g,11280
  ai_edge_quantizer/model_modifier.py,sha256=teGa8I6kGvn6TQY6Xv53YFIc_pQEhNvM9Zb4bvhezyw,7110
  ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
  ai_edge_quantizer/model_validator.py,sha256=Hj0_5o-Oa3dSlJ3ryVjRhvsyelHNyek1GrtG9buMczg,13153
  ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
  ai_edge_quantizer/params_generator.py,sha256=gC7G6Ne4Fumc8RSmIAbx96ZBhszZlHqBKSmE9p6RPTo,20099
  ai_edge_quantizer/params_generator_test.py,sha256=RDYoRZDJfEZRtjlTAU2kZ_4t3JHOqEHxfJX9V4ETAhg,40597
- ai_edge_quantizer/qtyping.py,sha256=kX1AoD-YlHYbDI1RfGVXIbPn-CYT7HUF2x77-hPtKBM,16565
+ ai_edge_quantizer/qtyping.py,sha256=vq-9jwDViSndHhcC1_RVu2Bk0qu5MgYPGLTRO9z2Naw,16604
  ai_edge_quantizer/quantizer.py,sha256=g3DMqFMrMpt9jQttCE0WcdNbMtk0JZnmN5MmCHrNdyM,13202
  ai_edge_quantizer/quantizer_test.py,sha256=K_HBA56JkFI3HL8VLWCqGEfC0ISh5ldMKoNyBdGRAJg,20368
  ai_edge_quantizer/recipe.py,sha256=FR0uJceumZrnle2VRSOQZ1uXup4S1cTYKRH-N53mWRo,2919
  ai_edge_quantizer/recipe_manager.py,sha256=qcGUD7e7BISKdsY9WH2rdaRR3acmzSA5qMezGNbzlpo,8931
  ai_edge_quantizer/recipe_manager_test.py,sha256=GVOfGFZPRciUb4EF4GkSi6d96LdjS6PbUkAJ0ayy0k8,32243
  ai_edge_quantizer/recipe_test.py,sha256=Fg_sfxovI2fRjk5qdu18ghOvXdUvhDR1TxbE0GHDczc,3381
- ai_edge_quantizer/transformation_instruction_generator.py,sha256=B_TQQe9_Qs7UKXLjMMuz5lORUvXyZOxBS2SpntTnkI8,28077
- ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=E0QSDCav6N6izlJ-a1ZJOsb2VEUxuxBmTbt0-EgDdxY,49890
- ai_edge_quantizer/transformation_performer.py,sha256=nkkqbs81ITB5u2FoWeG9z5d8EtLtCiltOxcQ34okN8E,13091
+ ai_edge_quantizer/transformation_instruction_generator.py,sha256=iMGXy7_ufqgQRzu4drAfO31VGdze35peEFh1BMZlVHk,27714
+ ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=Zw3EOSnvzjuB4NWeo129eJZxK_EHno9oF9OtEQ-0dnM,48905
+ ai_edge_quantizer/transformation_performer.py,sha256=o4J6OUbI0dLoobVYjkOFw5Po3yH0gZJXrfuTIYais4o,13029
  ai_edge_quantizer/transformation_performer_test.py,sha256=xk6A3LStCyPclN51--9uO7XjSxNfZmpdfvrzOL0maNM,20349
  ai_edge_quantizer/algorithms/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
  ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
  ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py,sha256=Bs9CK7wZAw6jNaZ8xEtbwO2vM34VYXNZSMVWvxJo9nw,9297
  ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=EqIHGEZ1LgUrTN7zf880RuAzEv3Qy7kgh5ivObJGHSo,22646
  ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=rImKK2ax7LrRx6XurSdvRTk0h6WtFGtQn9sYNJcn-uw,30222
+ ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=juXQWnzBTvasxydeDTNxWuE9ag9j6GOmfHMjC4JQu1Y,30799
  ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=GGf_n3wIeg3GB_eGsmyNJ0fTcxgpeMMbugTMRONK6TQ,3553
  ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=BDdn_uBZakfHyzdMJPKadsOqxqyC-s6W2ZzFH99L4fE,8652
  ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=sT5eX5TLZEHTtPfnSkCPDlS0sQxlTFWbCsbvOuj--yY,8889
@@ -41,7 +41,7 @@ ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py,sha256=sha1d99Xk87bI
  ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py,sha256=W2QbXP96xeleAmA7qFwco1iq_bOtArGDK6Qj_g6kNl8,15986
  ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py,sha256=MgG7Qh2_z4I6InBqEEDSVlaR0q48aMz4xqAlxeG2EMk,12436
  ai_edge_quantizer/algorithms/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
- ai_edge_quantizer/algorithms/utils/common_utils.py,sha256=UoZxeAQmZk3b3hK51KFwq6XfdbeduXVjdYIxAxlAzB8,34982
+ ai_edge_quantizer/algorithms/utils/common_utils.py,sha256=QrEeCuvA7gY_vK1nbKtqassNDClyAjN1ClZIiw63k5U,35895
  ai_edge_quantizer/algorithms/utils/common_utils_test.py,sha256=zqapGEfYhjQWe9cNGPLmdbwtEUUYQRhlO_kNe0cXX6E,18104
  ai_edge_quantizer/transformations/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
  ai_edge_quantizer/transformations/dequant_insert.py,sha256=sL1LHFVzBDSd9jgrzlHz38LWU0bwmVX7iBkaNcui0ts,3566
@@ -50,28 +50,26 @@ ai_edge_quantizer/transformations/duplicate_buffer.py,sha256=TvTHbm24IiICNkWOlvR
  ai_edge_quantizer/transformations/duplicate_buffer_test.py,sha256=YYWl3Q5WF60s8T8pLzzA8TCSxz-i7dqc03dJt1LtMw4,3880
  ai_edge_quantizer/transformations/duplicate_tensor.py,sha256=WKhf2LIAL0MnZe88b6942A37lvHXe1cFjUDqE5VNmvU,2490
  ai_edge_quantizer/transformations/duplicate_tensor_test.py,sha256=s-RqSxNBMfVJyCunXz2eb7-KA6UiBmbOmL7phLslENQ,5056
- ai_edge_quantizer/transformations/emulated_subchannel.py,sha256=HVaRxoC8PCAvy3xeMv3OIymukUy_yW1zK0xN8Ann6I4,13602
- ai_edge_quantizer/transformations/emulated_subchannel_test.py,sha256=gZP6u9NdPXl7s19qB_Un8evou9ZZV6I9Gy0E1rdobHM,7722
  ai_edge_quantizer/transformations/insert_hadamard_rotation.py,sha256=rBbKgcVKHie38NT2UQ7KQ1xCb2tRu_rVl0yFloOAW_A,7562
  ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py,sha256=iV1p3nZfHUATV2YRoBOYurnu3pLy8n3aFppLWGQOPdA,7268
  ai_edge_quantizer/transformations/quant_insert.py,sha256=jn6HsJaV-sqBiFPY-Aqbd64t8zgcYVkEkZI375x_FWY,3958
  ai_edge_quantizer/transformations/quant_insert_test.py,sha256=X9ptPDvJCFkR5tejKnD1SlHFGPazQTW-wNNMV9MEAuw,10107
  ai_edge_quantizer/transformations/quantize_tensor.py,sha256=kjaNrw9mnrn0t8u0vey9S_uPz3iVUicwy4rluxVqV3E,7617
  ai_edge_quantizer/transformations/quantize_tensor_test.py,sha256=mHLO3_MRt36A8-ZN8ADn5tBBJlqjTWa7ZUN8Mmu5Rcw,9116
- ai_edge_quantizer/transformations/transformation_utils.py,sha256=GwIaKVsePZYgVG2lSanOswcaZYMjvgyqstDVwXl9DGY,6923
+ ai_edge_quantizer/transformations/transformation_utils.py,sha256=efJdAkA24wlg6Vj5NFO7_7MDuvQLSNn-l11Vs_JPktI,7123
  ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=MWgq29t7rvxRQIfi4ny9IoODFCTcbpjnIwoCL40zDKk,8698
  ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
  ai_edge_quantizer/utils/calibration_utils.py,sha256=e3dG7Nm94Ix0hkTWTWPUhEG6a8QR_cAM3PSwblfJV5g,15106
  ai_edge_quantizer/utils/calibration_utils_test.py,sha256=4BlksXl7b4yptL8xPR67hmJCnjhN9V10a2PunzfHrUE,9372
  ai_edge_quantizer/utils/test_utils.py,sha256=spqUmSNciOKPQHCBkHE7Zo34eMFq_BfBCAnMT3jAulU,8615
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=pZv8FMWyjBSLN5MGJ2K_dZ6oqkJGbp9RI4CfnlPuPII,10830
+ ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=NnD57Gkx9upNP8Mso-_yp8Z3x1AqlIWb06jPg-hyRkc,10890
  ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
  ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=vTyy6-4PgfFPL3C8uTq_iPFBwdxCjhrWzUiec4DdFPw,14323
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
  ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
- ai_edge_quantizer_nightly-0.3.0.dev20250622.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- ai_edge_quantizer_nightly-0.3.0.dev20250622.dist-info/METADATA,sha256=0-WpgPHWtwW_Wvysp7yPgMXb6nNP6sXI-vJphPlKrBs,1528
- ai_edge_quantizer_nightly-0.3.0.dev20250622.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_quantizer_nightly-0.3.0.dev20250622.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
- ai_edge_quantizer_nightly-0.3.0.dev20250622.dist-info/RECORD,,
+ ai_edge_quantizer_nightly-0.3.0.dev20250624.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ ai_edge_quantizer_nightly-0.3.0.dev20250624.dist-info/METADATA,sha256=aigmp6Hzdxwsj0hwX5ARfya0brbvfrjYq-nMdlcQja4,1528
+ ai_edge_quantizer_nightly-0.3.0.dev20250624.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_quantizer_nightly-0.3.0.dev20250624.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
+ ai_edge_quantizer_nightly-0.3.0.dev20250624.dist-info/RECORD,,
@@ -1,363 +0,0 @@
- # Copyright 2024 The AI Edge Quantizer Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Transformation pattern for emulated subchannel quantization."""
-
- from typing import cast
- import numpy as np
- from ai_edge_quantizer import qtyping
- from ai_edge_quantizer.transformations import quantize_tensor
- from ai_edge_quantizer.transformations import transformation_utils
- from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
-
-
- def emulated_subchannel(
-     transformation_input: transformation_utils.TransformationInput,
- ) -> qtyping.TransformationInfo:
-   """Emulated subchannel quantization for fully_connected op.
-
-   The input tensor must also be the weight tensor of the fully_connected op.
-
-   after the transformation, the fully connected op will be replaced by:
-   reshape -> batch_matmul -> mul -> sum -> add (if bias is present) ->
-   activation (if fused activation function exist, only support ReLU for now)
-
-   Args:
-     transformation_input: input structure that contains all information needed
-       for the transformation.
-
-   Returns:
-     The transformation info.
-   """
-   # only apply to a single fully_connected op
-   if len(transformation_input.consumers) > 1:
-     raise ValueError('Emulated Subchannel transformation only support one op')
-   if isinstance(
-       transformation_input.quant_params, qtyping.NonLinearQuantParams
-   ):
-     raise ValueError(
-         'Emulated Subchannel transformation only support uniform quantization'
-     )
-   if (
-       transformation_input.op_codes[
-           transformation_input.subgraph.operators[
-               transformation_input.consumers[0]
-           ].opcodeIndex
-       ].builtinCode
-       != schema_py_generated.BuiltinOperator.FULLY_CONNECTED
-   ):
-     raise ValueError(
-         'Emulated Subchannel transformation only support fully_connected op'
-     )
-   if transformation_input.producer != -1:
-     raise ValueError(
-         'Emulated Subchannel transformation only support constant tensor'
-     )
-
-   # insert all tne necessary op codes into the model
-   reshape_op_code_idx = transformation_utils.add_op_code(
-       schema_py_generated.BuiltinOperator.RESHAPE, transformation_input.op_codes
-   )
-   bmm_op_code_idx = transformation_utils.add_op_code(
-       schema_py_generated.BuiltinOperator.BATCH_MATMUL,
-       transformation_input.op_codes,
-   )
-   mul_op_code_idx = transformation_utils.add_op_code(
-       schema_py_generated.BuiltinOperator.MUL, transformation_input.op_codes
-   )
-   sum_op_code_idx = transformation_utils.add_op_code(
-       schema_py_generated.BuiltinOperator.SUM, transformation_input.op_codes
-   )
-
-   original_fc_op_idx = transformation_input.consumers[0]
-   if cast(
-       schema_py_generated.FullyConnectedOptionsT,
-       transformation_input.subgraph.operators[
-           original_fc_op_idx
-       ].builtinOptions,
-   ).fusedActivationFunction not in (
-       schema_py_generated.ActivationFunctionType.NONE,
-       schema_py_generated.ActivationFunctionType.RELU,
-   ):
-     raise ValueError(
-         'Emulated Subchannel transformation only support'
-         ' fusedActivationFunction NONE and RELU for now'
-     )
-
-   weight_tensor = transformation_input.subgraph.tensors[
-       transformation_input.tensor_id
-   ]
-   weight_tensor.type = quantize_tensor.quant_params_to_tflite_type(
-       transformation_input.quant_params.num_bits
-   )
-
-   # modify the weight tensor with the correct quantization parameters
-   transformation_input.buffers[weight_tensor.buffer].data = np.frombuffer(
-       cast(
-           np.ndarray, transformation_input.quant_params.quantized_data
-       ).tobytes(),
-       dtype=np.uint8,
-   )
-   weight_tensor.shape = cast(
-       np.ndarray, transformation_input.quant_params.quantized_data
-   ).shape
-   weight_tensor.quantization.scale = np.ones(shape=[1], dtype=np.float32)
-   weight_tensor.quantization.zeroPoint = np.zeros(
-       shape=[1], dtype=np.int64
-   ).flatten()
-
-   # assuming zero point is 0, so no need to add a zero point tensor
-   for val in transformation_input.quant_params.zero_point.flatten():
-     if val != 0:
-       raise ValueError(
-           'Emulated Subchannel transformation only support zero point 0 for now'
-       )
-
-   scale_tensor_id = transformation_utils.add_new_constant_tensor(
-       weight_tensor.name + b'_scale',
-       transformation_input.quant_params.scale,
-       schema_py_generated.TensorType.FLOAT32,
-       transformation_input.subgraph,
-       transformation_input.buffers,
-   )
-
-   # for fully connected op, the reduce axis is always 1
-   reduce_axes_data = np.array([1], dtype=np.int32)
-   reduce_axes_tensor_id = transformation_utils.add_new_constant_tensor(
-       weight_tensor.name + b'_reduce_axes',
-       reduce_axes_data,
-       schema_py_generated.TensorType.INT32,
-       transformation_input.subgraph,
-       transformation_input.buffers,
-   )
-
-   # find the input and output tensor of the fully connected op
-   activation_input_id = transformation_input.subgraph.operators[
-       transformation_input.consumers[0]
-   ].inputs[0]
-   activation_output_id = transformation_input.subgraph.operators[
-       transformation_input.consumers[0]
-   ].outputs[0]
-   activation_input = transformation_input.subgraph.tensors[activation_input_id]
-   activation_output = transformation_input.subgraph.tensors[
-       activation_output_id
-   ]
-
-   if len(activation_input.shape) != 3:
-     raise ValueError(
-         'Emulated Subchannel transformation only support 3D input tensor'
-     )
-   bmm_input_shape = [
-       activation_input.shape[0] * activation_input.shape[1],
-       weight_tensor.shape[1],
-       1,
-       weight_tensor.shape[2],
-   ]
-   intermediate_tensor_shape = [
-       activation_input.shape[0] * activation_input.shape[1],
-       weight_tensor.shape[1],
-       1,
-       weight_tensor.shape[3],
-   ]
-   sum_output_shape = [
-       activation_input.shape[0] * activation_input.shape[1],
-       1,
-       1,
-       weight_tensor.shape[3],
-   ]
-
-   # create constant tensors for reshape
-   reshape1_shape_id = transformation_utils.add_new_constant_tensor(
-       activation_output.name + b'_reshape_op1_shape',
-       np.array(bmm_input_shape, dtype=np.int32),
-       schema_py_generated.TensorType.INT32,
-       transformation_input.subgraph,
-       transformation_input.buffers,
-   )
-   reshape2_shape_id = transformation_utils.add_new_constant_tensor(
-       activation_output.name + b'_reshape_op2_shape',
-       np.array(activation_output.shape, dtype=np.int32),
-       schema_py_generated.TensorType.INT32,
-       transformation_input.subgraph,
-       transformation_input.buffers,
-   )
-
-   # create all intermediate tensors
-   bmm_input_id = transformation_utils.add_new_activation_tensor(
-       activation_output.name + b'_bmm_input',
-       bmm_input_shape,
-       schema_py_generated.TensorType.FLOAT32,
-       transformation_input.subgraph,
-   )
-   mul_input_id = transformation_utils.add_new_activation_tensor(
-       activation_output.name + b'_mul_input',
-       intermediate_tensor_shape,
-       schema_py_generated.TensorType.FLOAT32,
-       transformation_input.subgraph,
-   )
-   sum_input_id = transformation_utils.add_new_activation_tensor(
-       activation_output.name + b'_reduce_sum_input',
-       intermediate_tensor_shape,
-       schema_py_generated.TensorType.FLOAT32,
-       transformation_input.subgraph,
-   )
-   reshape_op2_input_id = transformation_utils.add_new_activation_tensor(
-       activation_output.name + b'_reshape_op2_input',
-       sum_output_shape,
-       schema_py_generated.TensorType.FLOAT32,
-       transformation_input.subgraph,
-   )
-
-   # reshape
-   reshape_op1 = schema_py_generated.OperatorT()
-   reshape_op1.opcodeIndex = reshape_op_code_idx
-   reshape_op1_option = schema_py_generated.ReshapeOptionsT()
-   reshape_op1_option.newShape = bmm_input_shape
-   reshape_op1.inputs = [activation_input_id, reshape1_shape_id]
-   reshape_op1.outputs = [bmm_input_id]
-   reshape_op1.builtinOptionsType = (
-       schema_py_generated.BuiltinOptions.ReshapeOptions
-   )  # reshape option index
-   reshape_op1.builtinOptions = reshape_op1_option
-
-   # batch_matmul
-   bmm_op = schema_py_generated.OperatorT()
-   bmm_op.opcodeIndex = bmm_op_code_idx
-   bmm_op.inputs = [bmm_input_id, transformation_input.tensor_id]
-   bmm_op.outputs = [mul_input_id]
-   bmm_op.builtinOptionsType = (
-       schema_py_generated.BuiltinOptions.BatchMatMulOptions
-   )
-   bmm_op.builtinOptions = schema_py_generated.BatchMatMulOptionsT()
-
-   # mul
-   mul_op = schema_py_generated.OperatorT()
-   mul_op.opcodeIndex = mul_op_code_idx
-   mul_option = schema_py_generated.MulOptionsT()
-   mul_option.fusedActivationFunction = (
-       schema_py_generated.ActivationFunctionType.NONE
-   )
-   mul_op.inputs = [mul_input_id, scale_tensor_id]
-   mul_op.outputs = [sum_input_id]
-   mul_op.builtinOptionsType = schema_py_generated.BuiltinOptions.MulOptions
-   mul_op.builtinOptions = mul_option
-
-   # sum
-   sum_op = schema_py_generated.OperatorT()
-   sum_op.opcodeIndex = sum_op_code_idx
-   sum_op.inputs = [sum_input_id, reduce_axes_tensor_id]
-   sum_op.outputs = [reshape_op2_input_id]
-   sum_op.builtinOptionsType = schema_py_generated.BuiltinOptions.ReducerOptions
-   sum_op.builtinOptions = schema_py_generated.ReducerOptionsT()
-   sum_op.builtinOptions.keepDims = True
-
-   # reshape
-   reshape_op2 = schema_py_generated.OperatorT()
-   reshape_op2.opcodeIndex = reshape_op_code_idx
-   reshape_op2_option = schema_py_generated.ReshapeOptionsT()
-   reshape_op2_option.newShape = activation_output.shape
-   reshape_op2.inputs = [reshape_op2_input_id, reshape2_shape_id]
-   reshape_op2.outputs = [activation_output_id]
-   reshape_op2.builtinOptionsType = (
-       schema_py_generated.BuiltinOptions.ReshapeOptions
-   )
-   reshape_op2.builtinOptions = reshape_op2_option
-
-   transformation_input.subgraph.operators.insert(
-       original_fc_op_idx, reshape_op1
-   )
-   transformation_input.subgraph.operators.insert(original_fc_op_idx + 1, bmm_op)
-   transformation_input.subgraph.operators.insert(original_fc_op_idx + 2, mul_op)
-   transformation_input.subgraph.operators.insert(original_fc_op_idx + 3, sum_op)
-   transformation_input.subgraph.operators.insert(
-       original_fc_op_idx + 4, reshape_op2
-   )
-   ops_added = 5
-   last_op = reshape_op2
-
-   # If there is a bias tensor (the third input to the original fc op),
-   # we need an add to process it. The current fc op id need to be recalculated
-   # because we added operators in front of it.
-   current_fc_op_id = original_fc_op_idx + ops_added
-   if (
-       len(transformation_input.subgraph.operators[current_fc_op_id].inputs) > 2
-       and transformation_input.subgraph.operators[current_fc_op_id].inputs[2]
-       != -1
-   ):
-     add_op_code_idx = transformation_utils.add_op_code(
-         schema_py_generated.BuiltinOperator.ADD, transformation_input.op_codes
-     )
-     reshape_op2_output_id = transformation_utils.add_new_activation_tensor(
-         activation_output.name + b'_reshape_op2_output',
-         activation_output.shape,
-         schema_py_generated.TensorType.FLOAT32,
-         transformation_input.subgraph,
-     )
-     last_op.outputs = [reshape_op2_output_id]
-     add_op = schema_py_generated.OperatorT()
-     add_op.opcodeIndex = add_op_code_idx
-     add_option = schema_py_generated.AddOptionsT()
-     add_op.builtinOptionsType = schema_py_generated.BuiltinOptions.AddOptions
-     add_op.builtinOptions = add_option
-     add_op.inputs = [
-         reshape_op2_output_id,
-         transformation_input.subgraph.operators[
-             original_fc_op_idx + ops_added
-         ].inputs[2],
-     ]
-     add_op.outputs = [activation_output_id]
-     transformation_input.subgraph.operators.insert(
-         original_fc_op_idx + ops_added, add_op
-     )
-     ops_added += 1
-     last_op = add_op
-
-   # If the fused activation function is RELU, we need to add a relu op.
-   # The current fc op id need to be recalculated because we added operators
-   # in front of it.
-   fc_fused_activation_function = cast(
-       schema_py_generated.FullyConnectedOptionsT,
-       transformation_input.subgraph.operators[
-           original_fc_op_idx + ops_added
-       ].builtinOptions,
-   ).fusedActivationFunction
-   if (
-       fc_fused_activation_function
-       == schema_py_generated.ActivationFunctionType.RELU
-   ):
-     activation_output.name += b'_relu'
-     relu_input_id = transformation_utils.add_new_activation_tensor(
-         activation_output.name + b'_relu_input',
-         activation_output.shape,
-         schema_py_generated.TensorType.FLOAT32,
-         transformation_input.subgraph,
-     )
-     last_op.outputs = [relu_input_id]
-     relu_op = schema_py_generated.OperatorT()
-     relu_op.opcodeIndex = transformation_utils.add_op_code(
-         schema_py_generated.BuiltinOperator.RELU, transformation_input.op_codes
-     )
-     relu_op.inputs = [relu_input_id]
-     relu_op.outputs = [activation_output_id]
-     transformation_input.subgraph.operators.insert(
-         original_fc_op_idx + ops_added, relu_op
-     )
-     ops_added += 1
-     last_op = relu_op
-
-   del transformation_input.subgraph.operators[original_fc_op_idx + ops_added]
-   return qtyping.TransformationInfo(
-       original_fc_op_idx, ops_added - 1, activation_output_id
-   )
@@ -1,212 +0,0 @@
- # Copyright 2024 The AI Edge Quantizer Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Tests for emulated_subchannel."""
-
- import os
- import numpy as np
- from tensorflow.python.platform import googletest
- from ai_edge_quantizer import qtyping
- from ai_edge_quantizer.transformations import emulated_subchannel
- from ai_edge_quantizer.transformations import transformation_utils
- from ai_edge_quantizer.utils import test_utils
- from ai_edge_quantizer.utils import tfl_flatbuffer_utils
- from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
-
- TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")
-
-
- class EmulatedSubchannelTest(googletest.TestCase):
-   """Tests for emulated_subchannel."""
-
-   def setUp(self):
-     super().setUp()
-     self.params = qtyping.UniformQuantParams(
-         num_bits=8,
-         quantized_dimension=None,
-         scale=np.ones([1, 1, 1, 4], dtype=np.float32),
-         zero_point=np.zeros([1, 1, 1, 4], dtype=np.int64),
-         symmetric=True,
-         quantized_data=np.ones([1, 4, 2, 4], dtype=np.int8),
-     )
-
-   def test_emulate_subchannel_without_bias_succeeds(self):
-     """Tests the emulated_subchannel function."""
-     self._model_path = os.path.join(
-         TEST_DATA_PREFIX_PATH, "tests/models/single_fc_no_bias.tflite"
-     )
-     self._model = tfl_flatbuffer_utils.read_model(self._model_path)
-     subgraph = self._model.subgraphs[0]
-     model = self._model
-     ret = emulated_subchannel.emulated_subchannel(
-         transformation_utils.TransformationInput(
-             tensor_id=1,
-             op_codes=model.operatorCodes,
-             buffers=model.buffers,
-             subgraph=subgraph,
-             producer=-1,
-             consumers=[0],
-             quant_params=self.params,
-         )
-     )
-     self.assertEqual(ret.op_id, 0)
-     self.assertEqual(ret.num_ops_added, 4)
-     self.assertEqual(ret.output_tensor_id, 2)
-     self.assertEqual(
-         model.operatorCodes[subgraph.operators[0].opcodeIndex].builtinCode,
-         schema_py_generated.BuiltinOperator.RESHAPE,
-     )
-     self.assertEqual(
-         model.operatorCodes[subgraph.operators[1].opcodeIndex].builtinCode,
-         schema_py_generated.BuiltinOperator.BATCH_MATMUL,
-     )
-     self.assertEqual(
-         model.operatorCodes[subgraph.operators[2].opcodeIndex].builtinCode,
-         schema_py_generated.BuiltinOperator.MUL,
-     )
-     self.assertEqual(
-         model.operatorCodes[subgraph.operators[3].opcodeIndex].builtinCode,
-         schema_py_generated.BuiltinOperator.SUM,
-     )
-     self.assertEqual(
-         model.operatorCodes[subgraph.operators[4].opcodeIndex].builtinCode,
-         schema_py_generated.BuiltinOperator.RESHAPE,
-     )
-     self.assertEqual(
-         subgraph.tensors[subgraph.operators[2].inputs[1]].name,
-         b"arith.constant_scale",
-     )
-     self.assertListEqual(
-         np.frombuffer(
-             model.buffers[
-                 subgraph.tensors[subgraph.operators[2].inputs[1]].buffer
-             ].data,
-             dtype=np.float32,
-         ).tolist(),
-         np.ones([1, 1, 1, 4]).flatten().tolist(),
-     )
-
-   def test_emulate_subchannel_with_bias_succeeds(self):
-     """Tests the emulated_subchannel function."""
-     self._model_path = os.path.join(
-         TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
-     )
-     self._model = tfl_flatbuffer_utils.read_model(self._model_path)
-     subgraph = self._model.subgraphs[0]
-     model = self._model
-     ret = emulated_subchannel.emulated_subchannel(
-         transformation_utils.TransformationInput(
-             tensor_id=1,
-             op_codes=model.operatorCodes,
-             buffers=model.buffers,
-             subgraph=subgraph,
-             producer=-1,
-             consumers=[0],
-             quant_params=self.params,
-         )
-     )
-     self.assertEqual(ret.op_id, 0)
-     self.assertEqual(ret.num_ops_added, 5)
-     self.assertEqual(ret.output_tensor_id, 3)
-     self.assertEqual(
-         model.operatorCodes[subgraph.operators[0].opcodeIndex].builtinCode,
-         schema_py_generated.BuiltinOperator.RESHAPE,
-     )
-     self.assertEqual(
-         model.operatorCodes[subgraph.operators[1].opcodeIndex].builtinCode,
-         schema_py_generated.BuiltinOperator.BATCH_MATMUL,
-     )
-     self.assertEqual(
-         model.operatorCodes[subgraph.operators[2].opcodeIndex].builtinCode,
-         schema_py_generated.BuiltinOperator.MUL,
-     )
-     self.assertEqual(
-         model.operatorCodes[subgraph.operators[3].opcodeIndex].builtinCode,
-         schema_py_generated.BuiltinOperator.SUM,
-     )
-     self.assertEqual(
-         model.operatorCodes[subgraph.operators[4].opcodeIndex].builtinCode,
-         schema_py_generated.BuiltinOperator.RESHAPE,
-     )
-     self.assertEqual(
-         model.operatorCodes[subgraph.operators[5].opcodeIndex].builtinCode,
-         schema_py_generated.BuiltinOperator.ADD,
-     )
-     self.assertEqual(
-         subgraph.tensors[subgraph.operators[2].inputs[1]].name,
-         b"arith.constant_scale",
-     )
-     self.assertListEqual(
-         np.frombuffer(
-             model.buffers[
-                 subgraph.tensors[subgraph.operators[2].inputs[1]].buffer
-             ].data,
-             dtype=np.float32,
-         ).tolist(),
-         np.ones([1, 1, 1, 4]).flatten().tolist(),
-     )
-
-   def test_emulated_subchannel_with_fused_relu_succeeds(self):
-     """Tests the emulated_subchannel function with fused relu."""
-     self._model_path = os.path.join(
-         TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias_relu.tflite"
-     )
-     self._model = tfl_flatbuffer_utils.read_model(self._model_path)
-     self._model = tfl_flatbuffer_utils.read_model(self._model_path)
-     subgraph = self._model.subgraphs[0]
-     model = self._model
-     ret = emulated_subchannel.emulated_subchannel(
-         transformation_utils.TransformationInput(
-             tensor_id=1,
-             op_codes=model.operatorCodes,
-             buffers=model.buffers,
-             subgraph=subgraph,
-             producer=-1,
-             consumers=[0],
-             quant_params=self.params,
-         )
-     )
-     self.assertEqual(ret.op_id, 0)
-     self.assertEqual(ret.num_ops_added, 6)
-     self.assertEqual(ret.output_tensor_id, 3)
-     self.assertEqual(
-         model.operatorCodes[subgraph.operators[6].opcodeIndex].builtinCode,
-         schema_py_generated.BuiltinOperator.RELU,
-     )
-
-   def test_emulated_subchannel_raises_when_unsupported_activation(self):
-     """Tests the emulated_subchannel function with unsupported activation."""
-     self._model_path = os.path.join(
-         TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias_relu6.tflite"
-     )
-     self._model = tfl_flatbuffer_utils.read_model(self._model_path)
-     subgraph = self._model.subgraphs[0]
-     model = self._model
-     with self.assertRaises(ValueError):
-       emulated_subchannel.emulated_subchannel(
-           transformation_utils.TransformationInput(
-               tensor_id=1,
-               op_codes=model.operatorCodes,
-               buffers=model.buffers,
-               subgraph=subgraph,
-               producer=-1,
-               consumers=[0],
-               quant_params=self.params,
-           )
-       )
-
-
- if __name__ == "__main__":
-   googletest.main()