ai-edge-quantizer-nightly 0.1.0.dev20250428__py3-none-any.whl → 0.1.0.dev20250430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -104,8 +104,8 @@ class ModelModifier:
104
104
  instructions, quantized_model, tensor_processing_order
105
105
  )
106
106
  constant_buffer_size = self._process_constant_map(quantized_model)
107
- # we leave 64MB for the model architecture.
108
- if constant_buffer_size > 2**31 - 2**26:
107
+ # we leave 256MB for the model architecture.
108
+ if constant_buffer_size > 2**31 - 2**28:
109
109
  return self._serialize_large_model(quantized_model)
110
110
  else:
111
111
  return self._serialize_small_model(quantized_model)
@@ -51,21 +51,39 @@ class TransformationInput:
51
51
  def add_op_code(
52
52
  op_code: schema_py_generated.OperatorCodeT,
53
53
  model_op_codes: list[schema_py_generated.OperatorCodeT],
54
+ custom_op_name: Optional[str] = None,
54
55
  ) -> int:
55
56
  """Add an op code into a model if it's not present.
56
57
 
57
58
  Args:
58
59
  op_code: The op code to be added.
59
60
  model_op_codes: The op codes of the model.
61
+ custom_op_name: The custom string of the op code. If None, the op code will
62
+ be added as a builtin op code.
60
63
 
61
64
  Returns:
62
65
  The index of the op code in the model.
63
66
  """
67
+ if (
68
+ op_code == schema_py_generated.BuiltinOperator.CUSTOM
69
+ and custom_op_name is None
70
+ ):
71
+ raise ValueError('Custom string is required for custom op code.')
72
+
64
73
  for i, model_op_code in enumerate(model_op_codes):
74
+ # If the model already has the op code, just return the index.
65
75
  if model_op_code.builtinCode == op_code:
66
- return i
76
+ if custom_op_name is not None:
77
+ if model_op_code.customCode == custom_op_name:
78
+ return i
79
+ else:
80
+ # Built-in op
81
+ return i
82
+
67
83
  model_op_codes.append(schema_py_generated.OperatorCodeT())
68
84
  model_op_codes[-1].builtinCode = op_code
85
+ if custom_op_name is not None:
86
+ model_op_codes[-1].customCode = custom_op_name
69
87
  return len(model_op_codes) - 1
70
88
 
71
89
 
@@ -146,7 +164,14 @@ def add_new_activation_tensor(
146
164
  The index of the new tensor in the subgraph.
147
165
  """
148
166
  new_tensor = schema_py_generated.TensorT()
149
- new_tensor.shape = shape
167
+ # If there's a dynamic shape, we need to read from the shapeSignature field
168
+ # instead of shape. Shape should contain just 1 for the dynamic dimension but
169
+ # shapeSignature should contain the true shape.
170
+ if -1 in shape:
171
+ new_tensor.shapeSignature = shape
172
+ new_tensor.shape = [1 if i == -1 else i for i in shape]
173
+ else:
174
+ new_tensor.shape = shape
150
175
  new_tensor.type = tensor_type
151
176
  new_tensor.name = tensor_name
152
177
  new_tensor.buffer = 0
@@ -41,19 +41,62 @@ class TransformationUtilsTest(parameterized.TestCase):
41
41
  testcase_name="add_new_op_code",
42
42
  op_code=schema_py_generated.BuiltinOperator.LOGISTIC,
43
43
  expected=1,
44
+ custom_op_name=None,
44
45
  ),
45
46
  dict(
46
47
  testcase_name="add_existing_op_code",
47
48
  op_code=schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
48
49
  expected=0,
50
+ custom_op_name=None,
51
+ ),
52
+ dict(
53
+ testcase_name="add_new_custom_op_code",
54
+ op_code=schema_py_generated.BuiltinOperator.CUSTOM,
55
+ expected=1,
56
+ custom_op_name="random_new_custom_op",
49
57
  ),
50
58
  )
51
- def test_add_op_code(self, op_code, expected):
59
+ def test_add_op_code(self, op_code, expected, custom_op_name):
52
60
  """Tests if the op code is added to the model."""
53
61
  got = transformation_utils.add_op_code(
54
- op_code=op_code, model_op_codes=self.model.operatorCodes
62
+ op_code=op_code,
63
+ model_op_codes=self.model.operatorCodes,
64
+ custom_op_name=custom_op_name,
55
65
  )
56
66
  self.assertEqual(expected, got)
67
+ if custom_op_name is not None:
68
+ self.assertEqual(self.model.operatorCodes[got].customCode, custom_op_name)
69
+
70
+ def test_add_custom_op_code_without_op_string_raises_error(self):
71
+ with self.assertRaisesRegex(ValueError, "Custom string is required"):
72
+ transformation_utils.add_op_code(
73
+ op_code=schema_py_generated.BuiltinOperator.CUSTOM,
74
+ model_op_codes=self.model.operatorCodes,
75
+ custom_op_name=None,
76
+ )
77
+
78
+ def test_add_two_custom_op_codes(self):
79
+ custom_op_name = "random_new_custom_op"
80
+ added_index = transformation_utils.add_op_code(
81
+ op_code=schema_py_generated.BuiltinOperator.CUSTOM,
82
+ model_op_codes=self.model.operatorCodes,
83
+ custom_op_name=custom_op_name,
84
+ )
85
+ self.assertEqual(1, added_index)
86
+ self.assertEqual(
87
+ self.model.operatorCodes[added_index].customCode, custom_op_name
88
+ )
89
+
90
+ custom_op_name_2 = "random_new_custom_op_2"
91
+ added_index = transformation_utils.add_op_code(
92
+ op_code=schema_py_generated.BuiltinOperator.CUSTOM,
93
+ model_op_codes=self.model.operatorCodes,
94
+ custom_op_name=custom_op_name_2,
95
+ )
96
+ self.assertEqual(2, added_index)
97
+ self.assertEqual(
98
+ self.model.operatorCodes[added_index].customCode, custom_op_name_2
99
+ )
57
100
 
58
101
  @parameterized.named_parameters(
59
102
  dict(
@@ -189,6 +232,25 @@ class TransformationUtilsTest(parameterized.TestCase):
189
232
  self.model.subgraphs[0].tensors[-1].shape,
190
233
  )
191
234
 
235
+ def test_add_new_activation_tensor_with_dynamic_shape(self):
236
+ """Tests adding an activation tensor with dynamic shape."""
237
+ subgraph = self.model.subgraphs[0]
238
+ new_id = transformation_utils.add_new_activation_tensor(
239
+ tensor_name="test_tensor",
240
+ shape=[1, -1, -1, 1],
241
+ tensor_type=schema_py_generated.TensorType.FLOAT32,
242
+ subgraph=subgraph,
243
+ )
244
+ # Originally had 4 tensors, new tensor is added at index 4.
245
+ self.assertEqual(new_id, 4)
246
+ self.assertLen(subgraph.tensors, 5)
247
+ self.assertEqual(subgraph.tensors[-1].name, "test_tensor")
248
+ self.assertEqual(
249
+ subgraph.tensors[-1].type, schema_py_generated.TensorType.FLOAT32
250
+ )
251
+ self.assertEqual(subgraph.tensors[-1].shape, [1, 1, 1, 1])
252
+ self.assertEqual(subgraph.tensors[-1].shapeSignature, [1, -1, -1, 1])
253
+
192
254
 
193
255
  if __name__ == "__main__":
194
256
  googletest.main()
@@ -319,7 +319,27 @@ def get_signature_main_subgraph_index(
319
319
  return signature_runner._subgraph_index # pylint:disable=protected-access
320
320
 
321
321
 
322
- def create_random_normal_dataset(
322
+ def _create_random_normal(
323
+ rng: np.random.Generator,
324
+ shape: tuple[int, ...],
325
+ dtype: np.dtype,
326
+ ) -> dict[str, Any]:
327
+ """Creates a random normal dataset sample for given input details."""
328
+ return rng.normal(size=shape).astype(dtype)
329
+
330
+
331
+ def _create_random_integers(
332
+ rng: np.random.Generator,
333
+ shape: tuple[int, ...],
334
+ dtype: np.dtype,
335
+ min_value: int = 0,
336
+ max_value: int = 1024,
337
+ ) -> dict[str, Any]:
338
+ """Creates a random integer dataset sample for given input details."""
339
+ return rng.integers(min_value, max_value, size=shape, dtype=dtype)
340
+
341
+
342
+ def create_random_dataset(
323
343
  input_details: dict[str, Any],
324
344
  num_samples: int,
325
345
  random_seed: Union[int, np._typing.ArrayLike],
@@ -340,9 +360,14 @@ def create_random_normal_dataset(
340
360
  for _ in range(num_samples):
341
361
  input_data = {}
342
362
  for arg_name, input_tensor in input_details.items():
343
- new_data = rng.normal(size=input_tensor["shape"]).astype(
344
- input_tensor["dtype"]
345
- )
363
+ dtype = input_tensor["dtype"]
364
+ shape = input_tensor["shape"]
365
+ if dtype in (np.int32, np.int64):
366
+ new_data = _create_random_integers(rng, shape, dtype)
367
+ elif dtype == np.float32:
368
+ new_data = _create_random_normal(rng, shape, dtype)
369
+ else:
370
+ raise ValueError(f"Unsupported dtype: {input_tensor['dtype']}")
346
371
  input_data[arg_name] = new_data
347
372
  dataset.append(input_data)
348
373
  return dataset
@@ -372,7 +397,7 @@ def create_random_normal_input_data(
372
397
  for signature_key in signature_keys:
373
398
  signature_runner = tfl_interpreter.get_signature_runner(signature_key)
374
399
  input_details = signature_runner.get_input_details()
375
- test_data[signature_key] = create_random_normal_dataset(
400
+ test_data[signature_key] = create_random_dataset(
376
401
  input_details, num_samples, random_seed
377
402
  )
378
403
  return test_data
@@ -19,7 +19,6 @@ from tensorflow.python.platform import googletest
19
19
  from ai_edge_quantizer.utils import test_utils
20
20
  from ai_edge_quantizer.utils import tfl_interpreter_utils
21
21
 
22
-
23
22
  TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../tests/models")
24
23
 
25
24
 
@@ -159,7 +158,6 @@ class TflUtilsQuantizedModelTest(googletest.TestCase):
159
158
  signature_output = tfl_interpreter_utils.invoke_interpreter_signature(
160
159
  tfl_interpreter, self._signature_input_data
161
160
  )
162
- print(signature_output)
163
161
  self.assertEqual(tuple(signature_output["dense_1"].shape), (1, 10))
164
162
 
165
163
  # Assert the input data is not modified in-place b/353340272.
@@ -328,5 +326,24 @@ class TflUtilsMultiSignatureModelTest(googletest.TestCase):
328
326
  self.assertEqual(multiply_output_content, [20.0])
329
327
 
330
328
 
329
+ class TflUtilsIntegerInputModelTest(googletest.TestCase):
330
+
331
+ def setUp(self):
332
+ super().setUp()
333
+ np.random.seed(0)
334
+ self._test_model_path = os.path.join(
335
+ TEST_DATA_PREFIX_PATH, "toy_model_with_kv_cache_multi_signature.tflite"
336
+ )
337
+
338
+ def test_random_integer_input_data(self):
339
+ test_data = tfl_interpreter_utils.create_random_normal_input_data(
340
+ self._test_model_path
341
+ )
342
+ self.assertEqual(test_data["signature_1"][0]["cache_0"].dtype, np.float32)
343
+ self.assertEqual(test_data["signature_1"][0]["cache_1"].dtype, np.float32)
344
+ self.assertEqual(test_data["signature_1"][0]["positions"].dtype, np.int32)
345
+ self.assertEqual(test_data["signature_1"][0]["tokens"].dtype, np.int32)
346
+
347
+
331
348
  if __name__ == "__main__":
332
349
  googletest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.1.0.dev20250428
3
+ Version: 0.1.0.dev20250430
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -6,7 +6,7 @@ ai_edge_quantizer/calibrator.py,sha256=n7AD9j7UScR-CieoI6DQRMeiG_fhLBfSLRiM4460x
6
6
  ai_edge_quantizer/calibrator_test.py,sha256=C_oWOaRugPKYX74jF-eRFH-k6nGOdA8I9_uPiocaOuE,11900
7
7
  ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
8
8
  ai_edge_quantizer/default_policy.py,sha256=81z4cruBK7mGFt8xFRZK5LKya65axuZwo2zpbcYSicc,11099
9
- ai_edge_quantizer/model_modifier.py,sha256=SPt9X-xBzRvcd4xIS24zLHt3aUS2QwsNDqweFqitCAo,7109
9
+ ai_edge_quantizer/model_modifier.py,sha256=teGa8I6kGvn6TQY6Xv53YFIc_pQEhNvM9Zb4bvhezyw,7110
10
10
  ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
11
11
  ai_edge_quantizer/model_validator.py,sha256=fRNz0jO54cthPTibsCuViUXUuFRHl_fbvEiCukIVy20,13030
12
12
  ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
@@ -54,20 +54,20 @@ ai_edge_quantizer/transformations/quant_insert.py,sha256=jn6HsJaV-sqBiFPY-Aqbd64
54
54
  ai_edge_quantizer/transformations/quant_insert_test.py,sha256=X9ptPDvJCFkR5tejKnD1SlHFGPazQTW-wNNMV9MEAuw,10107
55
55
  ai_edge_quantizer/transformations/quantize_tensor.py,sha256=kjaNrw9mnrn0t8u0vey9S_uPz3iVUicwy4rluxVqV3E,7617
56
56
  ai_edge_quantizer/transformations/quantize_tensor_test.py,sha256=mHLO3_MRt36A8-ZN8ADn5tBBJlqjTWa7ZUN8Mmu5Rcw,9116
57
- ai_edge_quantizer/transformations/transformation_utils.py,sha256=5w0fG6TP362elTHs-JZokl24fuK4Gv6DGyIpybQYb3g,4885
58
- ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=xH64SF3UHDh84vYbt-WvmXNjM-Jg-mefES1ACO1tkqw,6269
57
+ ai_edge_quantizer/transformations/transformation_utils.py,sha256=Hc1jrY3cEUooiTu9qOh4jxyZp58vrokKxzTmzx6V70c,5853
58
+ ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=E90O4PYSjzGdHhaNvm3ii0Xom3cyFfcqQyYjOhYzG-c,8702
59
59
  ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
60
60
  ai_edge_quantizer/utils/calibration_utils.py,sha256=1Fj9MIO6aLZIRgyd4axvZN4S_O64nB_-Miu1WP664js,2536
61
61
  ai_edge_quantizer/utils/calibration_utils_test.py,sha256=Z-AcdTieesWFKyKBb08ZXm4Mgu6cvJ4bg2-MJ7hLD10,2856
62
62
  ai_edge_quantizer/utils/test_utils.py,sha256=HwZCIpO9fJRAhuN6t6voXKOYQtcioFtt_tpkAlDsAYk,6205
63
63
  ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=NKtw60BJAjIE6Yww8B1vJpxXwp4MSERmpKajXJWm5rI,10568
64
64
  ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
65
- ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=x2xA2CFPpe_2trcV8v5xGaBETvVCfwAcJuq6yieGJ0Y,12687
66
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=Op3JxtOqlrjzmYF18jnnstL1k9xiY9kKJ8S2vklKGkc,11327
65
+ ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=WoewyiZpaua80oP0tpgyrw5Ws1v7f4vl88vdzS0UjDE,13490
66
+ ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
67
67
  ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
68
68
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
69
- ai_edge_quantizer_nightly-0.1.0.dev20250428.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
70
- ai_edge_quantizer_nightly-0.1.0.dev20250428.dist-info/METADATA,sha256=yyhIbc-7ZiZ6-UFWCpYx1LgbsoYfTxl7pnqkCiTGbA8,1527
71
- ai_edge_quantizer_nightly-0.1.0.dev20250428.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
72
- ai_edge_quantizer_nightly-0.1.0.dev20250428.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
73
- ai_edge_quantizer_nightly-0.1.0.dev20250428.dist-info/RECORD,,
69
+ ai_edge_quantizer_nightly-0.1.0.dev20250430.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
70
+ ai_edge_quantizer_nightly-0.1.0.dev20250430.dist-info/METADATA,sha256=_q1njPlZxzBVNALKPM-yvI1dmjSspbXKq8wWHLYitR4,1527
71
+ ai_edge_quantizer_nightly-0.1.0.dev20250430.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
72
+ ai_edge_quantizer_nightly-0.1.0.dev20250430.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
73
+ ai_edge_quantizer_nightly-0.1.0.dev20250430.dist-info/RECORD,,