ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +224 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
- ai_edge_quantizer/calibrator.py +58 -94
- ai_edge_quantizer/calibrator_test.py +5 -74
- ai_edge_quantizer/default_policy.py +108 -16
- ai_edge_quantizer/model_modifier.py +132 -8
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +38 -10
- ai_edge_quantizer/model_validator_test.py +2 -1
- ai_edge_quantizer/params_generator.py +230 -47
- ai_edge_quantizer/params_generator_test.py +366 -261
- ai_edge_quantizer/qtyping.py +92 -6
- ai_edge_quantizer/quantizer.py +167 -23
- ai_edge_quantizer/quantizer_test.py +288 -26
- ai_edge_quantizer/recipe.py +156 -21
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +313 -46
- ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
- ai_edge_quantizer/transformation_performer.py +112 -58
- ai_edge_quantizer/transformation_performer_test.py +176 -4
- ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
- ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
- ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
- ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
- ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
- ai_edge_quantizer/transformations/transformation_utils.py +157 -11
- ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +191 -58
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0

ai_edge_quantizer/params_generator.py
@@ -15,10 +15,12 @@
 
 """Generate model tensor level quantization config."""
 
+from collections.abc import Sequence
 import copy
 from typing import Any, Optional, Union
 
 from ai_edge_quantizer import algorithm_manager
+from ai_edge_quantizer import default_policy as policy
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import recipe_manager
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils

@@ -33,12 +35,12 @@ class ParamsGenerator:
   def __init__(self, float_tflite: Union[str, bytes]):
     self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
 
-    if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
-      raise ValueError(
-          'The input model for quantization parameters generation is not a'
-          ' float model. Please check the model (e.g., if it is already'
-          ' quantized).'
-      )
+    # if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
+    #   raise ValueError(
+    #       'The input model for quantization parameters generation is not a'
+    #       ' float model. Please check the model (e.g., if it is already'
+    #       ' quantized).'
+    #   )
     self._check_tensor_names_are_unique()
     self.buffer_to_tensors: dict[int, list[Any]] = (
         tfl_flatbuffer_utils.buffer_to_tensors(self.flatbuffer_model)

@@ -73,8 +75,10 @@ class ParamsGenerator:
     if model_qsvs is None:
       model_qsvs = {}
 
+    skip_subgraphs = set()
     op_codes = self.flatbuffer_model.operatorCodes
-    for subgraph in self.flatbuffer_model.subgraphs:
+    for sg_ind, subgraph in enumerate(self.flatbuffer_model.subgraphs):
+
       graph_info = qtyping.GraphInfo(
           subgraph.tensors, self.flatbuffer_model.buffers
       )

@@ -103,10 +107,22 @@ class ParamsGenerator:
         algorithm_name, op_quant_config = (
             model_recipe_manager.get_quantization_configs(op_key, op_scope)
         )
+
+        if sg_ind in skip_subgraphs or policy.is_non_quantizable_composite_op(
+            op
+        ):
+          algorithm_name = algorithm_manager.AlgorithmName.NO_QUANTIZE
+
         if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
+          side_effect_subgraphs = (
+              tfl_flatbuffer_utils.get_op_side_effect_subgraphs(op)
+          )
+          skip_subgraphs.update(side_effect_subgraphs)
+
           op_quant_results = self._get_params_for_no_quant_op(
               subgraph_op_id, op, subgraph.tensors
           )
+
         else:
           op_info = qtyping.OpInfo(op, op_key, subgraph_op_id, op_quant_config)
           # Step2: query algorithm_manager to get/call the related function.

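The two hunks above add control-flow-aware skipping: an op is forced to NO_QUANTIZE when its subgraph index is already in `skip_subgraphs` or when `default_policy` flags it as a non-quantizable composite, and every subgraph the op invokes (its side-effect subgraphs, e.g. the bodies of control-flow ops) is added to the skip set so their ops are also left unquantized. The sketch below only illustrates that propagation with simplified stand-ins; `Op`, `assign_algorithms`, and their field names are hypothetical and merely mimic the real flatbuffer objects and the `default_policy`/`tfl_flatbuffer_utils` helpers.

# Illustrative sketch only; not part of the package diff.
from dataclasses import dataclass, field

@dataclass
class Op:
  name: str
  non_quantizable_composite: bool = False
  side_effect_subgraphs: list = field(default_factory=list)  # e.g. cond/body subgraph indices

def assign_algorithms(subgraphs):
  """Mimics the skip_subgraphs propagation in generate_quantization_parameters."""
  results = {}
  skip_subgraphs = set()
  # Assumes side-effect subgraphs are visited after the op that invokes them,
  # which holds when subgraphs are walked in index order.
  for sg_ind, ops in enumerate(subgraphs):
    for op in ops:
      algorithm = 'min_max_uniform_quantize'  # whatever the recipe selected
      if sg_ind in skip_subgraphs or op.non_quantizable_composite:
        algorithm = 'no_quantize'
      if algorithm == 'no_quantize':
        # Anything invoked by a skipped op is skipped as well.
        skip_subgraphs.update(op.side_effect_subgraphs)
      results[(sg_ind, op.name)] = algorithm
  return results

main_graph = [Op('composite', non_quantizable_composite=True, side_effect_subgraphs=[1])]
decomposition = [Op('fully_connected')]
print(assign_algorithms([main_graph, decomposition]))
# {(0, 'composite'): 'no_quantize', (1, 'fully_connected'): 'no_quantize'}
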
@@ -146,7 +162,7 @@ class ParamsGenerator:
       RuntimeError: If the tensors sharing the same buffer have different
         quantization settings.
     """
-    self._check_buffer_sharing()
+    self._check_and_fix_buffer_sharing()
 
   def _update_model_quant_results(
       self,

@@ -252,57 +268,224 @@ class ParamsGenerator:
       tensor_params.append(output_tensor_params)
     return tensor_params
 
-  def _check_buffer_sharing(self) -> None:
-    """Check if tensors sharing the same buffer have the same quantization.
+  def _mark_tensors_requiring_buffer_duplication(
+      self, buffers_to_duplicate: Sequence[int]
+  ) -> None:
+    """Mark tensors that require buffer duplication.
+
+    Marking a tensor means adding a DUPLICATE_BUFFER transformation as the first
+    transformation to be applied for each consumer of the tensor. Need to do
+    that for each consumer to preserve a zero layer and not affect the
+    horizontal optimization later in the transformation instructions generator.
+
+    Marks all tensors within each of the provided buffers as requiring buffer
+    duplication, except for the last tensor. The order of tensors is assumed to
+    be the same during both the marking and transformation performer steps, as
+    determined by `self.buffer_to_tensors`. This allows the final tensor to
+    reuse the original buffer, as it is not marked for duplication.
+
+    Args:
+      buffers_to_duplicate: Indices of the buffers to duplicate.
+    """
+    for buffer_idx in buffers_to_duplicate:
+      for tensor in self.buffer_to_tensors[buffer_idx][:-1]:
+        tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
+        for consumer_params in self.model_quant_results[tensor_name].consumers:
+          consumer_params.transformations.insert(
+              0, _QuantTrans.DUPLICATE_BUFFER
+          )
+
+  def _mark_tensors_requiring_tensor_duplication(
+      self, tensor_names_to_duplicate
+  ) -> None:
+    """Mark tensors that require tensor duplication.
+
+    Marking a tensor means adding a DUPLICATE_TENSOR transformation as the first
+    transformation to be applied for each consumer of the tensor. Need to do
+    that for each consumer to preserve a zero layer and not affect the
+    horizontal optimization later in the transformation instructions generator.
+
+    Args:
+      tensor_names_to_duplicate: Names of tensors to duplicate.
+    """
+    for tensor_name in tensor_names_to_duplicate:
+      for consumer_params in self.model_quant_results[tensor_name].consumers:
+        consumer_params.transformations.insert(0, _QuantTrans.DUPLICATE_TENSOR)
+
+  def _check_buffer_sharing_for_tensor(self, tensor: Any) -> bool:
+    """Check buffer sharing for the tensor against itself.
+
+    Args:
+      tensor: The tensor to check.
+
+    Returns:
+      Whether the tensor has compatible quantization parameters.
+
+    Raises:
+      RuntimeError: If the tensor has incompatible quantization parameters
+        and the buffer is not constant.
+    """
+    tensor_params = self.model_quant_results.get(
+        tfl_flatbuffer_utils.get_tensor_name(tensor), None
+    )
+    if tensor_params is None:
+      return True
+
+    if _are_tensor_consumer_params_compatible(tensor_params):
+      return True
+    elif _is_constant_tensor(tensor, self.flatbuffer_model.buffers):
+      return False
+    else:
+      error_msg = (
+          f'The tensor {tensor.name} consumers do not have the same'
+          ' quantization parameters. Please modify your quantization recipe to'
+          ' make sure the two tensors have the same quantization settings.'
+      )
+      raise RuntimeError(error_msg)
+
+  def _check_buffer_sharing_for_self_compatible_tensors(
+      self, tensor1: Any, tensor2: Any
+  ) -> bool:
+    """Check a pair of self compatible tensors have the same quantization params.
+
+    Self compatible means that all tensor's consumers have the same quantization
+    parameters.
+
+    Args:
+      tensor1: The first tensor to check.
+      tensor2: The second tensor to check.
+
+    Returns:
+      Whether the tensors have compatible quantization parameters.
+
+    Raises:
+      RuntimeError: If the tensors have incompatible quantization parameters
+        and the buffer is not constant.
+    """
+    tensor1_params = self.model_quant_results.get(
+        tfl_flatbuffer_utils.get_tensor_name(tensor1), None
+    )
+    tensor2_params = self.model_quant_results.get(
+        tfl_flatbuffer_utils.get_tensor_name(tensor2), None
+    )
+
+    if tensor1_params is None or tensor2_params is None:
+      return True
+
+    if _are_self_compatible_tensors_compatible_to_each_other(
+        tensor1_params, tensor2_params
+    ):
+      return True
+    elif _is_constant_tensor(tensor1, self.flatbuffer_model.buffers):
+      return False
+    else:
+      error_msg = (
+          f'The tensors {tensor1.name} and {tensor2.name} do not have'
+          ' the same quantization parameters even though they share the'
+          ' same buffer. Please modify your quantization recipe to make'
+          ' sure the two tensors have the same quantization settings.'
+      )
+      raise RuntimeError(error_msg)
+
+  def _check_and_fix_buffer_sharing(self) -> None:
+    """Check and fix tensor/buffer sharing issues when possible.
+
+    This function checks if tensors sharing the same buffer have the same
+    quantization settings. If not, when it's possible, it will fix it by marking
+    such tensors or buffers to be duplicated. Otherwise, it will raise an error.
+
+    Possible cases that can be fixed by duplication:
+    1. A constant tensor recieves different quantization parameters from its
+      consumers. In this case, the tensor is marked for duplication.
+    2. Two or more tensors share the same constant buffer and have different
+      quantization parameters. In this case, the buffer is marked for
+      duplication.
 
     Raises:
       RuntimeError: If the tensors sharing the same buffer have different
-        quantization settings
+        quantization settings and it can't be resolved by duplicating the
+        buffer/tensor.
     """
-    for tensors in self.buffer_to_tensors.values():
-      if not tensors:
+    buffers_to_duplicate = []
+    tensor_names_to_duplicate = []
+    for buffer_idx, tensors in self.buffer_to_tensors.items():
+      # TODO: b/458797890 - Investigate if skipping buffer_idx == 0 is a
+      # correct fix, or if it just covers up a deeper issue. This is only
+      # required when statically quantizing models that have already been
+      # quantized dynamically.
+      if not tensors or buffer_idx == 0:
         continue
-      first_tensor = tensors[0]
-      first_tensor_params = self.model_quant_results[
-          tfl_flatbuffer_utils.get_tensor_name(first_tensor)
-      ]
-      for tensor in tensors[1:]:
-        tensor_params = self.model_quant_results[
-            tfl_flatbuffer_utils.get_tensor_name(tensor)
-        ]
-        if not _compatible_tensor_transformation_params(
-            first_tensor_params, tensor_params
-        ):
-          error_msg = (
-              f'The tensors {first_tensor.name} and {tensor.name} do not have'
-              ' the same quantization parameters even though they share the'
-              ' same buffer. Please modify your quantization recipe to make'
-              ' sure the two tensors have the same quantization settings.'
+      # Check if any of the tensors needs to be duplicated.
+      for tensor in tensors:
+        if not self._check_buffer_sharing_for_tensor(tensor):
+          tensor_names_to_duplicate.append(
+              tfl_flatbuffer_utils.get_tensor_name(tensor)
           )
-          raise RuntimeError(error_msg)
+      # Check if the buffer needs to be duplicated.
+      tensor_1 = tensors[0]
+      tensor_name_1 = tfl_flatbuffer_utils.get_tensor_name(tensor_1)
+      if tensor_name_1 in tensor_names_to_duplicate:
+        buffers_to_duplicate.append(buffer_idx)
+        continue
+      for tensor_2 in tensors[1:]:
+        tensor_name_2 = tfl_flatbuffer_utils.get_tensor_name(tensor_2)
+        if (
+            tensor_name_2 in tensor_names_to_duplicate
+            or not self._check_buffer_sharing_for_self_compatible_tensors(
+                tensor_1, tensor_2
+            )
+        ):
+          buffers_to_duplicate.append(buffer_idx)
+          break
+
+    # Fix the buffer sharing issues.
+    self._mark_tensors_requiring_buffer_duplication(buffers_to_duplicate)
+    self._mark_tensors_requiring_tensor_duplication(tensor_names_to_duplicate)
+
+
+def _are_tensor_consumer_params_compatible(
+    params: qtyping.TensorTransformationParams,
+) -> bool:
+  """Check if all tensor's consumers have the same quantization parameters."""
+  if params.consumers is None or len(params.consumers) < 2:
+    return True
+  consumer_1 = params.consumers[0]
+  for consumer in params.consumers[1:]:
+    if not _compatible_tensor_params(consumer, consumer_1):
+      return False
+  return True
 
 
-def _compatible_tensor_transformation_params(
+def _are_self_compatible_tensors_compatible_to_each_other(
     params1: qtyping.TensorTransformationParams,
     params2: qtyping.TensorTransformationParams,
 ) -> bool:
-  """Check if two tensor transformation params are compatible."""
+  """Check if two self compatible tensors are compatible to each other.
+
+  Self compatible means that all tensor's consumers have the same quantization
+  parameters.
+
+  Args:
+    params1: The first tensor transformation params.
+    params2: The second tensor transformation params.
+
+  Returns:
+    Whether the two tensors are compatible to each other.
+  """
+  # Check the producer.
   if params1.producer is None or params2.producer is None:
     if params1.producer != params2.producer:
       return False
   elif not _compatible_tensor_params(params1.producer, params2.producer):
     return False
+
+  # Check the consumers.
   if params1.consumers is None or params2.consumers is None:
     if params1.consumers != params2.consumers:
      return False
   else:
-    # Check all consumers within each params are compatible.
-    for params1_consumer in params1.consumers:
-      if not _compatible_tensor_params(params1_consumer, params1.consumers[0]):
-        return False
-    for params2_consumer in params2.consumers:
-      if not _compatible_tensor_params(params2_consumer, params2.consumers[0]):
-        return False
+    # Since all consumer params within each tensor are the same, it's enough to
+    # check only the first consumers.
     if not _compatible_tensor_params(
         params1.consumers[0], params2.consumers[0]
     ):

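The rewritten buffer-sharing pass above no longer simply raises on a mismatch: when the conflicting data is constant it schedules a repair instead, prepending DUPLICATE_TENSOR to every consumer of a constant tensor whose own consumers disagree, and DUPLICATE_BUFFER when tensors sharing a constant buffer want different parameters; only non-constant conflicts still raise. A rough sketch of that planning step under simplified assumptions (plain dicts and exact equality stand in for the qtyping params and `_compatible_tensor_params`, and the raising path is omitted):

# Illustrative sketch only; simplified data model, not the real qtyping structures.
def plan_duplications(buffer_to_tensors, consumer_params, buffer_is_constant):
  """Returns (tensor_names_to_duplicate, buffers_to_duplicate)."""
  tensor_names, buffers = [], []
  for buffer_idx, tensors in buffer_to_tensors.items():
    if not tensors or buffer_idx == 0:  # buffer 0 is the empty sentinel buffer
      continue
    # Case 1: a constant tensor whose own consumers disagree -> duplicate the tensor.
    for name in tensors:
      params = consumer_params[name]
      if any(p != params[0] for p in params[1:]) and buffer_is_constant[buffer_idx]:
        tensor_names.append(name)
    # Case 2: tensors sharing a constant buffer disagree with each other -> duplicate the buffer.
    first = tensors[0]
    if first in tensor_names:
      buffers.append(buffer_idx)
      continue
    for other in tensors[1:]:
      same = consumer_params[first][0] == consumer_params[other][0]
      if other in tensor_names or (not same and buffer_is_constant[buffer_idx]):
        buffers.append(buffer_idx)
        break
  return tensor_names, buffers

# Two weight tensors share constant buffer 3 but ask for different bit widths.
print(plan_duplications(
    {3: ['w_a', 'w_b']},
    {'w_a': [{'bits': 8}], 'w_b': [{'bits': 4}]},
    {3: True},
))  # ([], [3]) -> only the buffer is duplicated; each tensor keeps its own params.
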
@@ -330,6 +513,8 @@ def _compatible_tensor_params(
   float_source_transformations = [
       _QuantTrans.ADD_QUANTIZE,
       _QuantTrans.NO_QUANTIZE,
+      _QuantTrans.INSERT_HADAMARD_ROTATION,
+      _QuantTrans.INSERT_DECOMPOSED_HADAMARD_ROTATION,
   ]
   quantized_source_transformations = [
       _QuantTrans.QUANTIZE_TENSOR,

@@ -337,14 +522,6 @@ def _compatible_tensor_params(
   ]
   if _same_tensor_params_except_id(params1, params2):
     return True
-  if (
-      params1.transformations[0] != _QuantTrans.NO_QUANTIZE
-      and params2.transformations[0] != _QuantTrans.NO_QUANTIZE
-  ):
-    # NO_QUANTIZE has no parameters. So only if both params aren't NO_QUANTIZE
-    # do we expect the parameters to be the same.
-    if params1.parameters != params2.parameters:
-      return False
   # We only need to check the first transformation because transformations are
   # applied in order, and as long as the one that's immediately after the tensor
   # is the same, it's compatible.

@@ -356,6 +533,12 @@ def _compatible_tensor_params(
   if (
       params1.transformations[0] in quantized_source_transformations
       and params2.transformations[0] in quantized_source_transformations
+      and params1.parameters == params2.parameters
   ):
     return True
   return False
+
+
+def _is_constant_tensor(tensor: Any, buffers: Sequence[Any]) -> bool:
+  """Check if the tensor is a constant tensor."""
+  return buffers[tensor.buffer].data is not None
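
Two smaller changes close out the file: `_compatible_tensor_params` now treats two quantized-source params as compatible only when their quantization parameters also match (replacing the earlier NO_QUANTIZE-guarded equality check), and the new `_is_constant_tensor` helper simply asks whether the tensor's buffer carries data. A toy check of the helper's logic, using SimpleNamespace objects that only mirror the two flatbuffer fields it touches (`tensor.buffer` and `buffer.data`):

# Illustrative sketch only; toy objects mirroring the flatbuffer fields used.
from types import SimpleNamespace

buffers = [
    SimpleNamespace(data=None),        # buffer 0: the empty sentinel buffer
    SimpleNamespace(data=bytes(16)),   # buffer 1: constant weight data
    SimpleNamespace(data=None),        # buffer 2: activation, produced at runtime
]
weight = SimpleNamespace(name='w', buffer=1)
activation = SimpleNamespace(name='act', buffer=2)

def is_constant_tensor(tensor, buffers):
  # Same test as the new _is_constant_tensor: constant tensors own buffer data.
  return buffers[tensor.buffer].data is not None

print(is_constant_tensor(weight, buffers))      # True
print(is_constant_tensor(activation, buffers))  # False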