mct-nightly 2.1.0.20240609.524__py3-none-any.whl → 2.1.0.20240611.428__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240609.524.dist-info → mct_nightly-2.1.0.20240611.428.dist-info}/METADATA +1 -1
- {mct_nightly-2.1.0.20240609.524.dist-info → mct_nightly-2.1.0.20240611.428.dist-info}/RECORD +28 -20
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_node.py +1 -4
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +1 -1
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +10 -6
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +15 -7
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +30 -14
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +8 -7
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +108 -87
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +15 -13
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +29 -14
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +40 -14
- model_compression_toolkit/core/keras/reader/node_builder.py +3 -3
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +25 -23
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +10 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py +16 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +222 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +131 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +111 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py +16 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +219 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +131 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +110 -0
- {mct_nightly-2.1.0.20240609.524.dist-info → mct_nightly-2.1.0.20240611.428.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240609.524.dist-info → mct_nightly-2.1.0.20240611.428.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240609.524.dist-info → mct_nightly-2.1.0.20240611.428.dist-info}/top_level.txt +0 -0
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py
CHANGED
@@ -27,7 +27,7 @@ from model_compression_toolkit.constants import MIN_THRESHOLD, DEFAULT_TOL, DEFA
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor, \
     reshape_tensor_for_per_channel_search, uniform_quantize_tensor, get_output_shape
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import max_power_of_two, \
-    get_tensor_max
+    get_tensor_max, get_tensor_min
 
 
 def qparams_selection_tensor_search(error_function: Callable,
@@ -56,41 +56,49 @@ def qparams_selection_tensor_search(error_function: Callable,
         signed: a flag whether the tensor is signed.
 
     Returns:
-        Optimal constrained threshold to quantize the tensor
+        Optimal constrained threshold to quantize the tensor, and best channel axis if input channel_axis was None,
+        else return the input channel axis.
 
     """
 
-    output_shape = get_output_shape(tensor_data.shape, channel_axis)
-
-    # First threshold to check is the constrained threshold based on the tensor's maximal value.
-    tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
-    threshold = 2 * max_power_of_two(tensor_max, min_threshold)
+    search_axes = range(len(tensor_data.shape)) if channel_axis is None and per_channel else [channel_axis]
+    total_error_list = []
+    th_list = []
+    for _axis in search_axes:
+        output_shape = get_output_shape(tensor_data.shape, _axis)
 
-    # Rearrange the tensor such that each sub-tensor is flattened, and we iterate over each
-    # one of them when searching for the threshold.
-    tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, channel_axis)
+        # First threshold to check is the constrained threshold based on the tensor's maximal value.
+        tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
+        threshold = 2 * max_power_of_two(tensor_max, min_threshold)
 
-    error_list = []  # init an empty error list
-    # On each iteration a new constrained threshold which equal to half of the previous tested threshold
-    # is used for quantizing the tensor and computing the error. The error is appended to an error list, which
-    # eventually used to select the threshold with the minimal error.
-    for i in range(n_iter):
+        # Rearrange the tensor such that each sub-tensor is flattened, and we iterate over each
+        # one of them when searching for the threshold.
         if per_channel:
-            threshold_hat = (threshold / (2 ** i)).reshape([-1, 1])
-            qt = quantize_tensor(tensor_data_r, threshold_hat, n_bits, signed)
-            per_channel_error = _error_function_wrapper(error_function, tensor_data_r, qt, threshold_hat)
-            error_list.append(per_channel_error)
-        else:  # quantize per-tensor
-            qt = quantize_tensor(tensor_data, threshold / (2 ** i), n_bits, signed)
-            error = error_function(qt, tensor_data, threshold=threshold / (2 ** i))
-            error_list.append(error)
-
-    # Take the index of the minimal error, and use it compute the threshold which yielded it.
-    err_mat = np.stack(error_list, axis=-1)
-    i = np.argmin(err_mat, axis=-1)
-    return np.maximum(np.reshape(threshold.flatten() / np.power(2, i), output_shape), min_threshold)
+            tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, _axis)
+
+        error_list = []  # init an empty error list
+        # On each iteration a new constrained threshold which equal to half of the previous tested threshold
+        # is used for quantizing the tensor and computing the error. The error is appended to an error list, which
+        # eventually used to select the threshold with the minimal error.
+        for i in range(n_iter):
+            if per_channel:
+                threshold_hat = (threshold / (2 ** i)).reshape([-1, 1])
+                qt = quantize_tensor(tensor_data_r, threshold_hat, n_bits, signed)
+                per_channel_error = _error_function_wrapper(error_function, tensor_data_r, qt, threshold_hat)
+                error_list.append(per_channel_error)
+            else:  # quantize per-tensor
+                qt = quantize_tensor(tensor_data, threshold / (2 ** i), n_bits, signed)
+                error = error_function(qt, tensor_data, threshold=threshold / (2 ** i))
+                error_list.append(error)
+
+        # Take the index of the minimal error, and use it compute the threshold which yielded it.
+        err_mat = np.stack(error_list, axis=-1)
+        i = np.argmin(err_mat, axis=-1)
+        th_list.append(np.maximum(np.reshape(threshold.flatten() / np.power(2, i), output_shape), min_threshold))
+        total_error_list.append(err_mat.min(axis=-1).mean())
+
+    best_axis_index = np.argmin(total_error_list)
+    return th_list[best_axis_index], search_axes[best_axis_index]
 
 
 def qparams_selection_histogram_search(error_function: Callable,
@@ -390,13 +398,12 @@ def search_dynamic_range(base_range: np.ndarray, x: np.ndarray, scalers: np.ndar
 
 def qparams_symmetric_selection_tensor_search(error_function: Callable,
                                               tensor_data: np.ndarray,
-                                              tensor_max: np.ndarray,
                                               n_bits: int,
                                               per_channel: bool = False,
                                               channel_axis: int = 1,
                                               n_iter: int = SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER,
                                               min_threshold=MIN_THRESHOLD,
-                                              signed: bool = True) -> np.ndarray:
+                                              signed: bool = True) -> Tuple[np.ndarray, int]:
     """
     Search for optimal threshold (per-channel or per-tensor) for symmetric quantization of a tensor,
     using the iterative optimizer method.
@@ -404,7 +411,6 @@ def qparams_symmetric_selection_tensor_search(error_function: Callable,
 
     Args:
         error_function: Function to compute the error between the original and quantized tensors.
         tensor_data: Numpy array with tensor's content.
-        tensor_max: The max value of the tensor.
         n_bits: Number of bits to quantize the tensor.
         per_channel: Whether the tensor should be quantized per-channel or per-tensor.
         channel_axis: Index of output channels dimension.
@@ -417,46 +423,55 @@ def qparams_symmetric_selection_tensor_search(error_function: Callable,
 
     """
 
-    output_shape = get_output_shape(tensor_data.shape, channel_axis)
+    search_axes = range(len(tensor_data.shape)) if channel_axis is None and per_channel else [channel_axis]
+    total_error_list = []
+    th_list = []
+    for _axis in search_axes:
+        tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
+        output_shape = get_output_shape(tensor_data.shape, _axis)
 
-    if per_channel:
-        # Rearrange the tensor such that each sub-tensor is flattened, and we iterate
-        # over each one of them when searching for the threshold.
-        tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, channel_axis)
-        max_tensor = np.maximum(min_threshold, tensor_max)
-        res = qparams_symmetric_iterative_minimization(x0=max_tensor,
-                                                       x=tensor_data_r,
-                                                       loss_fn=error_function,  # gets float_tensor, fxp_tensor, threshold
-                                                       n_bits=n_bits,
-                                                       signed=signed,
-                                                       n_intervals=SYMMETRIC_TENSOR_PER_CHANNEL_N_INTERVALS,
-                                                       n_iter=SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER,
-                                                       dec_freq=SYMMETRIC_TENSOR_PER_CHANNEL_DEC_FREQ,
-                                                       per_channel=True)
-        return np.reshape(np.maximum(min_threshold, res['param']), output_shape)
-    else:
-        # quantize per-tensor
-        res = qparams_symmetric_iterative_minimization(x0=get_init_threshold(min_threshold, tensor_max),
-                                                       x=tensor_data,
-                                                       loss_fn=error_function,
-                                                       n_bits=n_bits,
-                                                       signed=signed,
-                                                       n_intervals=SYMMETRIC_TENSOR_N_INTERVALS,
-                                                       n_iter=SYMMETRIC_TENSOR_N_ITER,
-                                                       dec_freq=SYMMETRIC_TENSOR_DEC_FREQ,
-                                                       per_channel=False)
-        return max(min_threshold, res['param'])
+        if per_channel:
+            # Rearrange the tensor such that each sub-tensor is flattened, and we iterate
+            # over each one of them when searching for the threshold.
+            tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, _axis)
+            max_tensor = np.maximum(min_threshold, tensor_max)
+            res = qparams_symmetric_iterative_minimization(x0=max_tensor,
+                                                           x=tensor_data_r,
+                                                           loss_fn=error_function,  # gets float_tensor, fxp_tensor, threshold
+                                                           n_bits=n_bits,
+                                                           signed=signed,
+                                                           n_intervals=SYMMETRIC_TENSOR_PER_CHANNEL_N_INTERVALS,
+                                                           n_iter=SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER,
+                                                           dec_freq=SYMMETRIC_TENSOR_PER_CHANNEL_DEC_FREQ,
+                                                           per_channel=True)
+            th = np.reshape(np.maximum(min_threshold, res['param']), output_shape)
+        else:
+            # quantize per-tensor
+            res = qparams_symmetric_iterative_minimization(x0=get_init_threshold(min_threshold, tensor_max),
+                                                           x=tensor_data,
+                                                           loss_fn=error_function,
+                                                           n_bits=n_bits,
+                                                           signed=signed,
+                                                           n_intervals=SYMMETRIC_TENSOR_N_INTERVALS,
+                                                           n_iter=SYMMETRIC_TENSOR_N_ITER,
+                                                           dec_freq=SYMMETRIC_TENSOR_DEC_FREQ,
+                                                           per_channel=False)
+            th = max(min_threshold, res['param'])
+
+        total_error_list.append(res['loss'].mean())
+        th_list.append(th)
+
+    best_axis_index = np.argmin(total_error_list)
+    return th_list[best_axis_index], search_axes[best_axis_index]
 
 
 def qparams_uniform_selection_tensor_search(error_function: Callable,
                                             tensor_data: np.ndarray,
-                                            tensor_min: np.ndarray,
-                                            tensor_max: np.ndarray,
                                             n_bits: int,
                                             per_channel: bool = False,
                                             channel_axis: int = 1,
-                                            n_iter: int = UNIFORM_TENSOR_PER_CHANNEL_N_ITER
+                                            n_iter: int = UNIFORM_TENSOR_PER_CHANNEL_N_ITER,
+                                            ) -> Tuple[Tuple[np.ndarray, np.ndarray], int]:
     """
     Search for optimal quantization range (per-channel or per-tensor) for uniform quantization of a tensor,
     using the iterative optimizer method and built-in scale factors
@@ -465,8 +480,6 @@ def qparams_uniform_selection_tensor_search(error_function: Callable,
 
     Args:
         error_function: Function to compute the error between the original and quantized tensors.
         tensor_data: Numpy array with tensor's content.
-        tensor_min: The min value of the tensor.
-        tensor_max: The max value of the tensor.
         n_bits: Number of bits to quantize the tensor.
         per_channel: Whether the tensor should be quantized per-channel or per-tensor.
         channel_axis: Index of output channels dimension.
@@ -477,17 +490,22 @@ def qparams_uniform_selection_tensor_search(error_function: Callable,
 
     """
 
-    output_shape = get_output_shape(tensor_data.shape, channel_axis)
+    search_axes = range(len(tensor_data.shape)) if channel_axis is None and per_channel else [channel_axis]
+    total_error_list = []
+    th_list = []
+    for _axis in search_axes:
+        tensor_min = get_tensor_min(tensor_data, per_channel, _axis)
+        tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits, is_uniform_quantization=True)
+        output_shape = get_output_shape(tensor_data.shape, _axis)
 
-    alpha = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
-    beta = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
-    scalers = np.asarray(list(itertools.product(alpha, beta)))
+        alpha = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
+        beta = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
+        scalers = np.asarray(list(itertools.product(alpha, beta)))
 
-    # Rearrange the tensor such that each sub-tensor is flattened, and we iterate over
-    # each one of them when searching for the threshold.
+        # Rearrange the tensor such that each sub-tensor is flattened, and we iterate over
+        # each one of them when searching for the threshold.
         if per_channel:
-            tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, channel_axis)
+            tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, _axis)
             tensor_min_max = np.column_stack([tensor_min.flatten(), tensor_max.flatten()])
             res = iterative_uniform_dynamic_range_search(x0=tensor_min_max,
                                                          x=tensor_data_r,
@@ -496,18 +514,21 @@ def qparams_uniform_selection_tensor_search(error_function: Callable,
                                                          n_bits=n_bits,
                                                          n_iter=UNIFORM_TENSOR_PER_CHANNEL_N_ITER,
                                                          per_channel=True)
-        return np.reshape(res['param'][:, 0], output_shape), \
-               np.reshape(res['param'][:, 1], output_shape)
-    else:
-        # quantize per-tensor
-        res = iterative_uniform_dynamic_range_search(x0=np.array([tensor_min, tensor_max]),
-                                                     x=tensor_data,
-                                                     scalers=scalers,
-                                                     loss_fn=error_function,
-                                                     n_bits=n_bits,
-                                                     n_iter=UNIFORM_TENSOR_N_ITER,
-                                                     per_channel=False)
-        return tuple(np.split(res['param'], 2))
+            th_list.append((np.reshape(res['param'][:, 0], output_shape), np.reshape(res['param'][:, 1], output_shape)))
+        else:
+            # quantize per-tensor
+            res = iterative_uniform_dynamic_range_search(x0=np.array([tensor_min, tensor_max]),
+                                                         x=tensor_data,
+                                                         scalers=scalers,
+                                                         loss_fn=error_function,
+                                                         n_bits=n_bits,
+                                                         n_iter=UNIFORM_TENSOR_N_ITER,
+                                                         per_channel=False)
+            th_list.append(tuple(np.split(res['param'], 2)))
+        total_error_list.append(res['loss'].mean())
+
+    best_axis_index = np.argmin(total_error_list)
+    return th_list[best_axis_index], search_axes[best_axis_index]
 
 
 def qparams_symmetric_selection_histogram_search(error_function: Callable,
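
The refactor above folds the channel-axis choice into the threshold search itself: when channel_axis is None and quantization is per-channel, every tensor dimension is tried as the channel axis, the existing per-axis search runs unchanged, and the axis whose best threshold yields the smallest mean error wins. A minimal NumPy sketch of that pattern, using a toy power-of-two quantizer in place of MCT's quantize_tensor/max_power_of_two helpers (all names below are illustrative, not MCT's API):

import numpy as np

def toy_quantize(x, threshold, n_bits=8):
    # Simplified signed symmetric quantizer, standing in for quantize_tensor.
    delta = threshold / (2 ** (n_bits - 1))
    return delta * np.clip(np.round(x / delta), -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1)

def halving_search(x2d, n_bits=8, n_iter=10, min_threshold=2 ** -16):
    # Per-channel power-of-two starting threshold, then test halved candidates (t0 / 2**i).
    t0 = 2.0 ** np.ceil(np.log2(np.maximum(np.abs(x2d).max(axis=1), min_threshold)))
    errs = np.stack([((x2d - toy_quantize(x2d, (t0 / 2 ** i)[:, None], n_bits)) ** 2).mean(axis=1)
                     for i in range(n_iter)], axis=-1)
    best = np.argmin(errs, axis=-1)
    return np.maximum(t0 / 2 ** best, min_threshold), errs.min(axis=-1).mean()

def search_axis_and_threshold(tensor, n_bits=8):
    # Try every dimension as the channel axis; keep the axis with minimal mean error.
    th_list, err_list = [], []
    for axis in range(tensor.ndim):
        x2d = np.moveaxis(tensor, axis, 0).reshape(tensor.shape[axis], -1)
        th, err = halving_search(x2d, n_bits)
        th_list.append(th)
        err_list.append(err)
    best_axis = int(np.argmin(err_list))
    return th_list[best_axis], best_axis

th, axis = search_axis_and_threshold(np.random.randn(8, 3, 3, 16).astype(np.float32))
print(axis, th.shape)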
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Dict, Any
+from typing import Dict, Any, Tuple
 
 import numpy as np
 
@@ -34,7 +34,7 @@ def get_weights_qparams(weights_attr_values: np.ndarray,
                         output_channels_axis: int,
                         node=None,
                         hessian_info_service: HessianInfoService = None,
-                        num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Dict[Any, Any]:
+                        num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
     """
     Compute thresholds to quantize a kernel according to a NodeWeightsQuantizationConfig
     instance.
@@ -50,22 +50,24 @@ def get_weights_qparams(weights_attr_values: np.ndarray,
 
     Returns:
         A dictionary with the quantization threshold of the kernel.
+        Selected quantization channel axis.
     """
     if attr_quant_config.weights_quantization_params_fn is not None:
-        weights_params = attr_quant_config.weights_quantization_params_fn(weights_attr_values,
-                                                                          p=attr_quant_config.l_p_value,
-                                                                          n_bits=attr_quant_config.weights_n_bits,
-                                                                          per_channel=attr_quant_config.weights_per_channel_threshold,
-                                                                          channel_axis=output_channels_axis,
-                                                                          min_threshold=weights_quant_config.min_threshold,
-                                                                          quant_error_method=attr_quant_config.weights_error_method,
-                                                                          node=node,
-                                                                          hessian_info_service=hessian_info_service,
-                                                                          num_hessian_samples=num_hessian_samples)
+        weights_params, output_channels_axis = attr_quant_config.weights_quantization_params_fn(
+            weights_attr_values,
+            p=attr_quant_config.l_p_value,
+            n_bits=attr_quant_config.weights_n_bits,
+            per_channel=attr_quant_config.weights_per_channel_threshold,
+            channel_axis=output_channels_axis,
+            min_threshold=weights_quant_config.min_threshold,
+            quant_error_method=attr_quant_config.weights_error_method,
+            node=node,
+            hessian_info_service=hessian_info_service,
+            num_hessian_samples=num_hessian_samples)
     else:
        weights_params = {}
 
-    return weights_params
+    return weights_params, output_channels_axis
 
 
 def _get_kernel_channels_mapping(fw_info:FrameworkInfo,
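
With this change, get_weights_qparams propagates the axis chosen by the params function instead of treating the caller's output_channels_axis as final. A hedged sketch of the updated calling convention (toy_params_fn below is an illustrative stand-in, not MCT's weights_quantization_params_fn):

import numpy as np

THRESHOLD = 'threshold'  # stand-in for the THRESHOLD constant used in the diff

def toy_params_fn(weights, channel_axis=None):
    # Mimics the new contract: always return (params_dict, selected_axis).
    axis = 0 if channel_axis is None else channel_axis  # pretend axis 0 won the search
    per_ch = np.moveaxis(weights, axis, 0).reshape(weights.shape[axis], -1)
    return {THRESHOLD: np.abs(per_ch).max(axis=1)}, axis

# Callers must now unpack the pair; code doing `params = fn(...)` breaks.
params, out_axis = toy_params_fn(np.random.randn(4, 5), channel_axis=None)
print(params[THRESHOLD].shape, out_axis)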
model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py
CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import numpy as np
+from typing import Union, Tuple, Dict
 
 import model_compression_toolkit.core.common.quantization.quantization_config as qc
 from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES
@@ -25,6 +26,8 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import \
     get_tensor_max
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
+from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor
 
 
 def symmetric_selection_tensor(tensor_data: np.ndarray,
@@ -37,7 +40,8 @@ def symmetric_selection_tensor(tensor_data: np.ndarray,
                                quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
                                node=None,
                                hessian_info_service: HessianInfoService = None,
-                               num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES
+                               num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
+                               ) -> Tuple[Dict[str, np.ndarray], int]:
     """
     Compute the optimal threshold based on the provided QuantizationErrorMethod to quantize the tensor.
     Different search is applied, depends on the value of the selected QuantizationErrorMethod.
@@ -47,7 +51,7 @@ def symmetric_selection_tensor(tensor_data: np.ndarray,
         p: p-norm to use for the Lp-norm distance.
         n_bits: Number of bits to quantize the tensor.
         per_channel: Whether the quantization should be per-channel or not.
-        channel_axis: Output channel index.
+        channel_axis: Output channel index. if None, search for best axis.
         n_iter: Number of iterations to search for the optimal threshold (not used for this method).
         min_threshold: Minimal threshold to use if threshold is too small (not used for this method).
         quant_error_method: an error function to optimize the parameters' selection accordingly.
@@ -57,12 +61,24 @@ def symmetric_selection_tensor(tensor_data: np.ndarray,
 
     Returns:
         Optimal threshold to quantize the tensor in a symmetric manner.
+        Selected quantization channel axis.
     """
 
-    tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
-
     if quant_error_method == qc.QuantizationErrorMethod.NOCLIPPING:
-
+        if channel_axis is None and per_channel:
+            total_error_list = []
+            th_list = []
+            for _axis in range(len(tensor_data.shape)):
+                tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
+                threshold = get_init_threshold(min_threshold, tensor_max, per_channel)
+                q_tensor_data = quantize_tensor(tensor_data, threshold, n_bits, True)
+                total_error_list.append(compute_mse(tensor_data, q_tensor_data, norm=True))
+                th_list.append(threshold)
+            channel_axis = np.argmin(total_error_list)
+            threshold = th_list[channel_axis]
+        else:
+            tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
+            threshold = get_init_threshold(min_threshold, tensor_max, per_channel)
     else:
         signed = True  # weights are always signed
         axis = -1 if per_channel else None
@@ -71,15 +87,14 @@ def symmetric_selection_tensor(tensor_data: np.ndarray,
                                                                        signed=signed, node=node,
                                                                        hessian_info_service=hessian_info_service,
                                                                        num_hessian_samples=num_hessian_samples)
-        threshold = qparams_symmetric_selection_tensor_search(error_function,
-                                                              tensor_data,
-                                                              tensor_max,
-                                                              n_bits,
-                                                              per_channel,
-                                                              channel_axis,
-                                                              min_threshold=min_threshold,
-                                                              signed=signed)
-    return {THRESHOLD: threshold}
+        threshold, channel_axis = qparams_symmetric_selection_tensor_search(error_function,
+                                                                            tensor_data,
+                                                                            n_bits,
+                                                                            per_channel,
+                                                                            channel_axis,
+                                                                            min_threshold=min_threshold,
+                                                                            signed=signed)
+    return {THRESHOLD: threshold}, channel_axis
 
 
 def symmetric_selection_histogram(bins: np.ndarray,
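
For the NOCLIPPING error method there is no iterative search, so the new axis selection quantizes once per candidate axis with the max-derived threshold and scores each axis by normalized MSE (hence the new compute_mse/quantize_tensor imports). A simplified stand-alone version of that branch, with NumPy stand-ins for MCT's helpers (names and defaults below are assumptions for illustration):

import numpy as np

def noclipping_axis_search(tensor, n_bits=8, min_threshold=2 ** -16):
    errors, th_list = [], []
    for axis in range(tensor.ndim):
        # Per-channel |max| along the candidate axis -> no-clipping threshold.
        reduce_dims = tuple(d for d in range(tensor.ndim) if d != axis)
        th = np.maximum(np.abs(tensor).max(axis=reduce_dims, keepdims=True), min_threshold)
        delta = th / (2 ** (n_bits - 1))
        q = delta * np.clip(np.round(tensor / delta), -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1)
        errors.append(((tensor - q) ** 2).mean() / (tensor ** 2).mean())  # normalized MSE
        th_list.append(th)
    best_axis = int(np.argmin(errors))
    return th_list[best_axis], best_axis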
model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py
CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import numpy as np
+from typing import Union, Tuple, Dict
 
 import model_compression_toolkit.core.common.quantization.quantization_config as qc
 from model_compression_toolkit.constants import MIN_THRESHOLD, RANGE_MIN, RANGE_MAX, NUM_QPARAM_HESSIAN_SAMPLES
@@ -24,6 +25,9 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import get_tensor_max, \
     get_tensor_min
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
+from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor
+
 
 def uniform_selection_tensor(tensor_data: np.ndarray,
                              p: int,
@@ -35,7 +39,8 @@ def uniform_selection_tensor(tensor_data: np.ndarray,
                              quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
                              node=None,
                              hessian_info_service: HessianInfoService = None,
-                             num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES
+                             num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
+                             ) -> Tuple[Dict[str, np.ndarray], int]:
     """
     Compute the optimal quantization range based on the provided QuantizationErrorMethod
     to uniformly quantize the tensor.
@@ -46,7 +51,7 @@ def uniform_selection_tensor(tensor_data: np.ndarray,
         p: p-norm to use for the Lp-norm distance.
         n_bits: Number of bits to quantize the tensor.
         per_channel: Whether the quantization should be per-channel or not.
-        channel_axis: Output channel index.
+        channel_axis: Output channel index. if None, search for best axis.
         n_iter: Number of iterations to search for the optimal threshold (not used for this method).
         min_threshold: Minimal threshold to use if threshold is too small (not used for this method).
         quant_error_method: an error function to optimize the range parameters' selection accordingly.
@@ -56,27 +61,48 @@ def uniform_selection_tensor(tensor_data: np.ndarray,
 
     Returns:
         Optimal quantization range to quantize the tensor uniformly.
+        Selected quantization channel axis.
     """
-    tensor_min = get_tensor_min(tensor_data, per_channel, channel_axis)
-    tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits, is_uniform_quantization=True)
-
     if quant_error_method == qc.QuantizationErrorMethod.NOCLIPPING:
-        mm = tensor_min, tensor_max
+        if channel_axis is None and per_channel:
+            total_error_list = []
+            th_list = []
+            for _axis in range(len(tensor_data.shape)):
+                tensor_min = get_tensor_min(tensor_data, per_channel, _axis)
+                tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits, is_uniform_quantization=True)
+                q_tensor_data = uniform_quantize_tensor(tensor_data, tensor_min, tensor_max, n_bits)
+                total_error_list.append(compute_mse(tensor_data, q_tensor_data, norm=True))
+                th_list.append((tensor_min, tensor_max))
+            channel_axis = np.argmin(total_error_list)
+            mm = th_list[channel_axis]
+        else:
+            tensor_min = get_tensor_min(tensor_data, per_channel, channel_axis)
+            tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits, is_uniform_quantization=True)
+            mm = tensor_min, tensor_max
     else:
         axis = -1 if per_channel else None
         error_function = get_threshold_selection_tensor_error_function(QuantizationMethod.UNIFORM, quant_error_method,
                                                                        p, axis=axis, norm=False, node=node,
                                                                        hessian_info_service=hessian_info_service,
                                                                        num_hessian_samples=num_hessian_samples)
-        mm = qparams_uniform_selection_tensor_search(error_function,
-                                                     tensor_data,
-                                                     tensor_min,
-                                                     tensor_max,
-                                                     n_bits,
-                                                     per_channel,
-                                                     channel_axis)
+        mm, channel_axis = qparams_uniform_selection_tensor_search(error_function,
+                                                                   tensor_data,
+                                                                   n_bits,
+                                                                   per_channel,
+                                                                   channel_axis)
+    # In case the tensor\axis has a single value, then min==max, so need to adjust either min or max to zero.
+    if not isinstance(mm[0], np.ndarray):
+        if mm[0] > 0:
+            mm = (np.float32(0).astype(mm[0].dtype), mm[1])
+        if mm[1] < 0:
+            mm = (mm[0], np.float32(0).astype(mm[1].dtype))
+    else:
+        adj_min_to_zero = np.logical_and(mm[1] == mm[0], mm[0] > 0)
+        adj_max_to_zero = np.logical_and(mm[1] == mm[0], mm[1] < 0)
+        mm[0][adj_min_to_zero] = 0
+        mm[1][adj_max_to_zero] = 0
     return {RANGE_MIN: mm[0],
-            RANGE_MAX: mm[1]}
+            RANGE_MAX: mm[1]}, channel_axis
 
 
 def uniform_selection_histogram(bins: np.ndarray,
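
The new tail of uniform_selection_tensor guards a corner case: if a channel (or the whole tensor) holds a single repeated value, min equals max, and a uniform grid over [min, max] lying strictly above or below zero cannot represent 0 exactly. Snapping one end of the range to zero fixes that. A self-contained sketch of the same adjustment (mask computed before mutation to avoid order dependence; plain NumPy, not MCT code):

import numpy as np

def adjust_degenerate_range(range_min: np.ndarray, range_max: np.ndarray):
    degenerate = range_min == range_max
    new_min = np.where(degenerate & (range_min > 0), 0.0, range_min)
    new_max = np.where(degenerate & (range_max < 0), 0.0, range_max)
    return new_min, new_max

# A constant-positive channel gets min -> 0; a constant-negative one gets max -> 0.
print(adjust_degenerate_range(np.array([1.5, -2.0, -1.0]), np.array([1.5, -2.0, 3.0])))
# (array([ 0., -2., -1.]), array([1.5, 0., 3. ]))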
model_compression_toolkit/core/keras/reader/node_builder.py
CHANGED
@@ -158,7 +158,8 @@ def build_node(node: KerasNode,
         if is_const(arg) or (
                 keras_layer.symbol in tf_function_symbols and
                 isinstance(arg, (tuple, list))):
-            weights.update({i: to_numpy(arg, is_single_tensor=True)})
+            if inputs_as_list or i in kwarg2index.values():
+                weights.update({i: to_numpy(arg, is_single_tensor=True)})
     # remove weights and KerasTensors and weights from op_call_args
     if inputs_as_list:
         op_call_args = tuple(op_call_args[1:])
@@ -169,8 +170,7 @@ def build_node(node: KerasNode,
     # read weights from call kwargs
     weight_keys = []
     for k, v in op_call_kwargs.items():
-        if is_const(v) or (keras_layer.
-                tf.matmul] and
+        if is_const(v) or (keras_layer.symbol in tf_function_symbols and
                 isinstance(v, (tuple, list))):
             if k in kwarg2index:
                 weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})
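
The added guard means a constant positional argument is only recorded as a node weight when the op consumes its inputs as a list or the position is one the function's kwargs can map to, so stray constants no longer leak into weights. An illustrative reduction of that logic (toy is_const check and mapping, not the Keras reader's internals):

def collect_positional_weights(op_call_args, inputs_as_list, kwarg2index):
    is_const = lambda a: isinstance(a, (int, float))  # toy constant check
    weights = {}
    for i, arg in enumerate(op_call_args):
        if is_const(arg) and (inputs_as_list or i in kwarg2index.values()):
            weights[i] = arg
    return weights

# Position 1 maps to a known kwarg, position 2 does not -> only index 1 is kept.
print(collect_positional_weights(['x', 3.0, 7.0], False, {'y': 1}))  # {1: 3.0}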
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
CHANGED
@@ -39,20 +39,16 @@ from mct_quantizers import PytorchQuantizationWrapper
 def _build_input_tensors_list(node: BaseNode,
                               graph: Graph,
                               inputs: Tuple[Any],
-                              node_to_output_tensors_dict: Dict[BaseNode, List],
-                              is_op_quantize_wrapper: bool) -> List[List]:
+                              node_to_output_tensors_dict: Dict[BaseNode, List]) -> List[List]:
     """
     Given a node, build a list of input tensors the node gets. The list is built based on the
-    node's incoming edges, previous nodes' output tensors
-    Positional weights aren't used if the node's op is PytorchQuantizationWrapper, since it's
-    positional weights are already in the wrapper.
+    node's incoming edges, previous nodes' output tensors.
 
     Args:
         node: Node to build its input tensors list.
         graph: Graph the node is in.
         inputs: list of input tensors to model.
         node_to_output_tensors_dict: A dictionary from a node to its output tensors.
-        is_op_quantize_wrapper: Whether the func_op is a PytorchQuantizationWrapper or not.
 
     Returns:
         A list of the node's input tensors.
@@ -67,35 +63,30 @@ def _build_input_tensors_list(node: BaseNode,
             _input_tensors = node_to_output_tensors_dict[ie.source_node]
             input_tensors.append(_input_tensors)
     input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list]  # flat list of lists
-    input_tensors = node.insert_positional_weights_to_input_list(input_tensors)
-    # convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
-    # list separately, because in FX the tensors are FX objects and fail to_torch_tensor
-    input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
-                     for t in input_tensors]
     return input_tensors
 
 
 def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
-
+                  tensor_input_indices: List = None) -> List:
     """
-    Merge input tensors list with op_call_args, according to correct order.
+    Merge input tensors list with positional weights and op_call_args, according to correct order.
 
     Args:
         _node: The node the inputs are for.
         input_tensors: activation input tensors to node.
         op_call_args: framework node call args.
-
+
     Returns:
         Combined list of input_tensors and op_call_args.
     """
     if isinstance(_node, FunctionalNode) and _node.tensor_input_indices:
         _input_list = op_call_args.copy()
-        if
-
-
-
-
-
+        if tensor_input_indices is None:
+            tensor_input_indices = _node.tensor_input_indices
+        assert len(tensor_input_indices) == len(input_tensors), \
+            f'Mismatch between input tensors ({len(tensor_input_indices)}) and indices {len(input_tensors)}'
+        for i, t in zip(tensor_input_indices, input_tensors):
+            _input_list.insert(i, t)
     else:
         _input_list = input_tensors + op_call_args
 
@@ -126,10 +117,22 @@ def _run_operation(n: BaseNode,
     op_call_args = n.op_call_args if isinstance(n, FunctionalNode) else []
     functional_kwargs = n.op_call_kwargs if isinstance(n, FunctionalNode) else {}
 
+    if not (isinstance(n, FunctionalNode) and isinstance(op_func, PytorchQuantizationWrapper)):
+        # Insert positional weights only when not a quantized functional node, because quantized functional nodes
+        # insert the quantized weights in the wrapper.
+        input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
+        # convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
+        # list separately, because in FX the tensors are FX objects and fail to_torch_tensor
+        input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
+                         for t in input_tensors]
+        _tensor_input_indices = None
+    else:
+        _tensor_input_indices = [i for i in n.tensor_input_indices if i not in n.weights]
+
     if isinstance(n, FunctionalNode) and n.inputs_as_list:
         out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
     else:
-        merged_inputs = _merge_inputs(n, input_tensors, op_call_args,
+        merged_inputs = _merge_inputs(n, input_tensors, op_call_args, tensor_input_indices=_tensor_input_indices)
         out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)
 
     # Add a fake quant node if the node has an activation threshold.
@@ -295,8 +298,7 @@ class PytorchModel(torch.nn.Module):
                 input_tensors = _build_input_tensors_list(node,
                                                           self.graph,
                                                           args,
-                                                          node_to_output_tensors_dict,
-                                                          isinstance(op_func, PytorchQuantizationWrapper))
+                                                          node_to_output_tensors_dict)
                 use_activation_quantization, activation_quantization_fn = self._get_activation_quantization_fn(node)
 
                 # Run node operation and fetch outputs