mct-nightly 2.1.0.20240609.524__py3-none-any.whl → 2.1.0.20240611.428__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {mct_nightly-2.1.0.20240609.524.dist-info → mct_nightly-2.1.0.20240611.428.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240609.524.dist-info → mct_nightly-2.1.0.20240611.428.dist-info}/RECORD +28 -20
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/base_node.py +1 -4
  5. model_compression_toolkit/core/common/hessian/hessian_info_service.py +1 -1
  6. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +1 -1
  7. model_compression_toolkit/core/common/quantization/node_quantization_config.py +10 -6
  8. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +15 -7
  9. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +30 -14
  10. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +8 -7
  11. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +108 -87
  12. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +15 -13
  13. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +29 -14
  14. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +40 -14
  15. model_compression_toolkit/core/keras/reader/node_builder.py +3 -3
  16. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +25 -23
  17. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +10 -0
  18. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py +16 -0
  19. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +222 -0
  20. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +131 -0
  21. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +111 -0
  22. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py +16 -0
  23. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +219 -0
  24. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +131 -0
  25. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +110 -0
  26. {mct_nightly-2.1.0.20240609.524.dist-info → mct_nightly-2.1.0.20240611.428.dist-info}/LICENSE.md +0 -0
  27. {mct_nightly-2.1.0.20240609.524.dist-info → mct_nightly-2.1.0.20240611.428.dist-info}/WHEEL +0 -0
  28. {mct_nightly-2.1.0.20240609.524.dist-info → mct_nightly-2.1.0.20240611.428.dist-info}/top_level.txt +0 -0
@@ -27,7 +27,7 @@ from model_compression_toolkit.constants import MIN_THRESHOLD, DEFAULT_TOL, DEFA
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor, \
  reshape_tensor_for_per_channel_search, uniform_quantize_tensor, get_output_shape
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import max_power_of_two, \
- get_tensor_max
+ get_tensor_max, get_tensor_min


  def qparams_selection_tensor_search(error_function: Callable,
@@ -56,41 +56,49 @@ def qparams_selection_tensor_search(error_function: Callable,
  signed: a flag whether the tensor is signed.

  Returns:
- Optimal constrained threshold to quantize the tensor.
+ Optimal constrained threshold to quantize the tensor, and best channel axis if input channel_axis was None,
+ else return the input channel axis.

  """

- output_shape = get_output_shape(tensor_data.shape, channel_axis)
-
- # First threshold to check is the constrained threshold based on the tensor's maximal value.
- tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
- threshold = 2 * max_power_of_two(tensor_max, min_threshold)
+ search_axes = range(len(tensor_data.shape)) if channel_axis is None and per_channel else [channel_axis]
+ total_error_list = []
+ th_list = []
+ for _axis in search_axes:
+ output_shape = get_output_shape(tensor_data.shape, _axis)

- # If the threshold is computed per-channel, we rearrange the tensor such that each sub-tensor
- # is flattened, and we iterate over each one of them when searching for the threshold.
- if per_channel:
- tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, channel_axis)
+ # First threshold to check is the constrained threshold based on the tensor's maximal value.
+ tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
+ threshold = 2 * max_power_of_two(tensor_max, min_threshold)

- error_list = [] # init an empty error list
- # On each iteration a new constrained threshold which equal to half of the previous tested threshold
- # is used for quantizing the tensor and computing the error. The error is appended to an error list, which
- # eventually used to select the threshold with the minimal error.
- for i in range(n_iter):
+ # Rearrange the tensor such that each sub-tensor is flattened, and we iterate over each
+ # one of them when searching for the threshold.
  if per_channel:
- threshold_hat = (threshold / (2 ** i)).reshape([-1, 1])
- qt = quantize_tensor(tensor_data_r, threshold_hat, n_bits, signed)
- per_channel_error = _error_function_wrapper(error_function, tensor_data_r, qt, threshold_hat)
-
- error_list.append(per_channel_error)
- else: # quantize per-tensor
- qt = quantize_tensor(tensor_data, threshold / (2 ** i), n_bits, signed)
- error = error_function(qt, tensor_data, threshold=threshold / (2 ** i))
- error_list.append(error)
-
- # Take the index of the minimal error, and use it compute the threshold which yielded it.
- i = np.argmin(np.stack(error_list, axis=-1), axis=-1)
-
- return np.maximum(np.reshape(threshold.flatten() / np.power(2, i), output_shape), min_threshold)
+ tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, _axis)
+
+ error_list = [] # init an empty error list
+ # On each iteration a new constrained threshold which equal to half of the previous tested threshold
+ # is used for quantizing the tensor and computing the error. The error is appended to an error list, which
+ # eventually used to select the threshold with the minimal error.
+ for i in range(n_iter):
+ if per_channel:
+ threshold_hat = (threshold / (2 ** i)).reshape([-1, 1])
+ qt = quantize_tensor(tensor_data_r, threshold_hat, n_bits, signed)
+ per_channel_error = _error_function_wrapper(error_function, tensor_data_r, qt, threshold_hat)
+ error_list.append(per_channel_error)
+ else: # quantize per-tensor
+ qt = quantize_tensor(tensor_data, threshold / (2 ** i), n_bits, signed)
+ error = error_function(qt, tensor_data, threshold=threshold / (2 ** i))
+ error_list.append(error)
+
+ # Take the index of the minimal error, and use it compute the threshold which yielded it.
+ err_mat = np.stack(error_list, axis=-1)
+ i = np.argmin(err_mat, axis=-1)
+ th_list.append(np.maximum(np.reshape(threshold.flatten() / np.power(2, i), output_shape), min_threshold))
+ total_error_list.append(err_mat.min(axis=-1).mean())
+
+ best_axis_index = np.argmin(total_error_list)
+ return th_list[best_axis_index], search_axes[best_axis_index]


  def qparams_selection_histogram_search(error_function: Callable,
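The rewritten qparams_selection_tensor_search above no longer assumes a fixed channel axis: when channel_axis is None and per-channel quantization is requested, it repeats the power-of-two threshold search for every axis of the tensor and returns the thresholds of the axis with the lowest mean error, together with that axis. For readers outside the MCT codebase, a minimal standalone sketch of that idea in plain numpy follows; the helper names and the simple MSE criterion are illustrative assumptions, not the toolkit's internal API:

    import numpy as np

    def _symmetric_quantize(x, threshold, n_bits):
        # Signed symmetric quantization; `threshold` broadcasts over x.
        scale = threshold / (2 ** (n_bits - 1))
        return np.clip(np.round(x / scale), -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1) * scale

    def pick_best_channel_axis(x, n_bits=8, min_threshold=2 ** -16):
        # Try every axis as the channel axis, build a power-of-two threshold per channel,
        # and keep the axis whose quantization MSE is smallest (mirrors the search above).
        errors, thresholds = [], []
        for axis in range(x.ndim):
            reduce_axes = tuple(i for i in range(x.ndim) if i != axis)
            ch_max = np.maximum(np.abs(x).max(axis=reduce_axes, keepdims=True), min_threshold)
            th = 2.0 ** np.ceil(np.log2(ch_max))  # smallest power of two covering each channel
            errors.append(np.mean((x - _symmetric_quantize(x, th, n_bits)) ** 2))
            thresholds.append(th)
        best = int(np.argmin(errors))
        return thresholds[best], best

    w = np.random.randn(3, 3, 16, 32).astype(np.float32)  # e.g. an HWIO conv kernel
    th, axis = pick_best_channel_axis(w)
    print(axis, th.shape)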
@@ -390,13 +398,12 @@ def search_dynamic_range(base_range: np.ndarray, x: np.ndarray, scalers: np.ndar

  def qparams_symmetric_selection_tensor_search(error_function: Callable,
  tensor_data: np.ndarray,
- tensor_max: np.ndarray,
  n_bits: int,
  per_channel: bool = False,
  channel_axis: int = 1,
  n_iter: int = SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER,
  min_threshold=MIN_THRESHOLD,
- signed: bool = True) -> Any:
+ signed: bool = True) -> Tuple[np.ndarray, int]:
  """
  Search for optimal threshold (per-channel or per-tensor) for symmetric quantization of a tensor,
  using the iterative optimizer method.
@@ -404,7 +411,6 @@ def qparams_symmetric_selection_tensor_search(error_function: Callable,
  Args:
  error_function: Function to compute the error between the original and quantized tensors.
  tensor_data: Numpy array with tensor's content.
- tensor_max: The max value of the tensor.
  n_bits: Number of bits to quantize the tensor.
  per_channel: Whether the tensor should be quantized per-channel or per-tensor.
  channel_axis: Index of output channels dimension.
@@ -417,46 +423,55 @@ def qparams_symmetric_selection_tensor_search(error_function: Callable,

  """

- output_shape = get_output_shape(tensor_data.shape, channel_axis)
+ search_axes = range(len(tensor_data.shape)) if channel_axis is None and per_channel else [channel_axis]
+ total_error_list = []
+ th_list = []
+ for _axis in search_axes:
+ tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
+ output_shape = get_output_shape(tensor_data.shape, _axis)

- # If the threshold is computed per-channel, we rearrange the tensor such that each sub-tensor
- # is flattened, and we iterate over each one of them when searching for the threshold.
- if per_channel:
- tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, channel_axis)
- max_tensor = np.maximum(min_threshold, tensor_max)
- res = qparams_symmetric_iterative_minimization(x0=max_tensor,
- x=tensor_data_r,
- loss_fn=error_function, # gets float_tensor, fxp_tensor, threshold
- n_bits=n_bits,
- signed=signed,
- n_intervals=SYMMETRIC_TENSOR_PER_CHANNEL_N_INTERVALS,
- n_iter=SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER,
- dec_freq=SYMMETRIC_TENSOR_PER_CHANNEL_DEC_FREQ,
- per_channel=True)
- return np.reshape(np.maximum(min_threshold, res['param']), output_shape)
- else:
- # quantize per-tensor
- res = qparams_symmetric_iterative_minimization(x0=get_init_threshold(min_threshold, tensor_max),
- x=tensor_data,
- loss_fn=error_function,
- n_bits=n_bits,
- signed=signed,
- n_intervals=SYMMETRIC_TENSOR_N_INTERVALS,
- n_iter=SYMMETRIC_TENSOR_N_ITER,
- dec_freq=SYMMETRIC_TENSOR_DEC_FREQ,
- per_channel=False)
-
- return max(min_threshold, res['param'])
+ if per_channel:
+ # Rearrange the tensor such that each sub-tensor is flattened, and we iterate
+ # over each one of them when searching for the threshold.
+ tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, _axis)
+ max_tensor = np.maximum(min_threshold, tensor_max)
+ res = qparams_symmetric_iterative_minimization(x0=max_tensor,
+ x=tensor_data_r,
+ loss_fn=error_function, # gets float_tensor, fxp_tensor, threshold
+ n_bits=n_bits,
+ signed=signed,
+ n_intervals=SYMMETRIC_TENSOR_PER_CHANNEL_N_INTERVALS,
+ n_iter=SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER,
+ dec_freq=SYMMETRIC_TENSOR_PER_CHANNEL_DEC_FREQ,
+ per_channel=True)
+ th = np.reshape(np.maximum(min_threshold, res['param']), output_shape)
+ else:
+ # quantize per-tensor
+ res = qparams_symmetric_iterative_minimization(x0=get_init_threshold(min_threshold, tensor_max),
+ x=tensor_data,
+ loss_fn=error_function,
+ n_bits=n_bits,
+ signed=signed,
+ n_intervals=SYMMETRIC_TENSOR_N_INTERVALS,
+ n_iter=SYMMETRIC_TENSOR_N_ITER,
+ dec_freq=SYMMETRIC_TENSOR_DEC_FREQ,
+ per_channel=False)
+ th = max(min_threshold, res['param'])
+
+ total_error_list.append(res['loss'].mean())
+ th_list.append(th)
+
+ best_axis_index = np.argmin(total_error_list)
+ return th_list[best_axis_index], search_axes[best_axis_index]


  def qparams_uniform_selection_tensor_search(error_function: Callable,
  tensor_data: np.ndarray,
- tensor_min: np.ndarray,
- tensor_max: np.ndarray,
  n_bits: int,
  per_channel: bool = False,
  channel_axis: int = 1,
- n_iter: int = UNIFORM_TENSOR_PER_CHANNEL_N_ITER) -> Any:
+ n_iter: int = UNIFORM_TENSOR_PER_CHANNEL_N_ITER,
+ ) -> Tuple[Tuple[np.ndarray, np.ndarray], int]:
  """
  Search for optimal quantization range (per-channel or per-tensor) for uniform quantization of a tensor,
  using the iterative optimizer method and built-in scale factors
@@ -465,8 +480,6 @@ def qparams_uniform_selection_tensor_search(error_function: Callable,
  Args:
  error_function: Function to compute the error between the original and quantized tensors.
  tensor_data: Numpy array with tensor's content.
- tensor_min: The min value of the tensor.
- tensor_max: The max value of the tensor.
  n_bits: Number of bits to quantize the tensor.
  per_channel: Whether the tensor should be quantized per-channel or per-tensor.
  channel_axis: Index of output channels dimension.
@@ -477,17 +490,22 @@ def qparams_uniform_selection_tensor_search(error_function: Callable,

  """

- output_shape = get_output_shape(tensor_data.shape, channel_axis)
+ search_axes = range(len(tensor_data.shape)) if channel_axis is None and per_channel else [channel_axis]
+ total_error_list = []
+ th_list = []
+ for _axis in search_axes:
+ tensor_min = get_tensor_min(tensor_data, per_channel, _axis)
+ tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits, is_uniform_quantization=True)
+ output_shape = get_output_shape(tensor_data.shape, _axis)

- alpha = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
- beta = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
- scalers = np.asarray(list(itertools.product(alpha, beta)))
+ alpha = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
+ beta = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
+ scalers = np.asarray(list(itertools.product(alpha, beta)))

- # If the threshold is computed per-channel, we rearrange the tensor such that each sub-tensor
- # is flattened, and we iterate over each one of them when searching for the threshold.
- if per_channel:
+ # Rearrange the tensor such that each sub-tensor is flattened, and we iterate over
+ # each one of them when searching for the threshold.
  if per_channel:
- tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, channel_axis)
+ tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, _axis)
  tensor_min_max = np.column_stack([tensor_min.flatten(), tensor_max.flatten()])
  res = iterative_uniform_dynamic_range_search(x0=tensor_min_max,
  x=tensor_data_r,
@@ -496,18 +514,21 @@ def qparams_uniform_selection_tensor_search(error_function: Callable,
  n_bits=n_bits,
  n_iter=UNIFORM_TENSOR_PER_CHANNEL_N_ITER,
  per_channel=True)
- return np.reshape(res['param'][:, 0], output_shape), np.reshape(res['param'][:, 1], output_shape)
- else:
- # quantize per-tensor
- pass
- res = iterative_uniform_dynamic_range_search(x0=np.array([tensor_min, tensor_max]),
- x=tensor_data,
- scalers=scalers,
- loss_fn=error_function,
- n_bits=n_bits,
- n_iter=UNIFORM_TENSOR_N_ITER,
- per_channel=False)
- return res['param']
+ th_list.append((np.reshape(res['param'][:, 0], output_shape), np.reshape(res['param'][:, 1], output_shape)))
+ else:
+ # quantize per-tensor
+ res = iterative_uniform_dynamic_range_search(x0=np.array([tensor_min, tensor_max]),
+ x=tensor_data,
+ scalers=scalers,
+ loss_fn=error_function,
+ n_bits=n_bits,
+ n_iter=UNIFORM_TENSOR_N_ITER,
+ per_channel=False)
+ th_list.append(tuple(np.split(res['param'], 2)))
+ total_error_list.append(res['loss'].mean())
+
+ best_axis_index = np.argmin(total_error_list)
+ return th_list[best_axis_index], search_axes[best_axis_index]


  def qparams_symmetric_selection_histogram_search(error_function: Callable,
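One small detail in the per-tensor branch above: res['param'] holds the two range edges in a single array, and tuple(np.split(res['param'], 2)) turns it into a (min, max) pair, so both branches append the same kind of pair to th_list. A quick illustration with made-up values:

    import numpy as np

    param = np.array([-0.8, 1.2], dtype=np.float32)   # [range_min, range_max] as one array
    range_min, range_max = tuple(np.split(param, 2))  # -> (array([-0.8]), array([1.2]))
    print(range_min, range_max)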
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Dict, Any
+ from typing import Dict, Any, Tuple

  import numpy as np

@@ -34,7 +34,7 @@ def get_weights_qparams(weights_attr_values: np.ndarray,
  output_channels_axis: int,
  node=None,
  hessian_info_service: HessianInfoService = None,
- num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Dict[Any, Any]:
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
  """
  Compute thresholds to quantize a kernel according to a NodeWeightsQuantizationConfig
  instance.
@@ -50,22 +50,24 @@ def get_weights_qparams(weights_attr_values: np.ndarray,

  Returns:
  A dictionary with the quantization threshold of the kernel.
+ Selected quantization channel axis.
  """
  if attr_quant_config.weights_quantization_params_fn is not None:
- weights_params = attr_quant_config.weights_quantization_params_fn(weights_attr_values,
- p=attr_quant_config.l_p_value,
- n_bits=attr_quant_config.weights_n_bits,
- per_channel=attr_quant_config.weights_per_channel_threshold and output_channels_axis is not None,
- channel_axis=output_channels_axis,
- min_threshold=weights_quant_config.min_threshold,
- quant_error_method=attr_quant_config.weights_error_method,
- node=node,
- hessian_info_service=hessian_info_service,
- num_hessian_samples=num_hessian_samples)
+ weights_params, output_channels_axis = attr_quant_config.weights_quantization_params_fn(
+ weights_attr_values,
+ p=attr_quant_config.l_p_value,
+ n_bits=attr_quant_config.weights_n_bits,
+ per_channel=attr_quant_config.weights_per_channel_threshold,
+ channel_axis=output_channels_axis,
+ min_threshold=weights_quant_config.min_threshold,
+ quant_error_method=attr_quant_config.weights_error_method,
+ node=node,
+ hessian_info_service=hessian_info_service,
+ num_hessian_samples=num_hessian_samples)
  else:
  weights_params = {}

- return weights_params
+ return weights_params, output_channels_axis


  def _get_kernel_channels_mapping(fw_info:FrameworkInfo,
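With this change get_weights_qparams (and the weights_quantization_params_fn it delegates to) returns a pair rather than a bare dictionary, so call sites must unpack both the parameter dict and the channel axis the search settled on. A purely hypothetical sketch of the new calling convention; the fake params function below is not the toolkit's code:

    def fake_params_fn(weights, channel_axis=None, **kwargs):
        # Hypothetical stand-in for a weights_quantization_params_fn: when no axis is given,
        # it "searches" and reports the axis it ended up using alongside the params dict.
        selected_axis = 3 if channel_axis is None else channel_axis
        return {'threshold': 1.0}, selected_axis

    weights_params, output_channels_axis = fake_params_fn(None)   # callers now unpack a 2-tuple
    print(weights_params, output_channels_axis)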
@@ -13,6 +13,7 @@
  # limitations under the License.
  # ==============================================================================
  import numpy as np
+ from typing import Union, Tuple, Dict

  import model_compression_toolkit.core.common.quantization.quantization_config as qc
  from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES
@@ -25,6 +26,8 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import \
  get_tensor_max
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+ from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
+ from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor


  def symmetric_selection_tensor(tensor_data: np.ndarray,
@@ -37,7 +40,8 @@ def symmetric_selection_tensor(tensor_data: np.ndarray,
  quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
  node=None,
  hessian_info_service: HessianInfoService = None,
- num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> dict:
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
+ ) -> Tuple[Dict[str, np.ndarray], int]:
  """
  Compute the optimal threshold based on the provided QuantizationErrorMethod to quantize the tensor.
  Different search is applied, depends on the value of the selected QuantizationErrorMethod.
@@ -47,7 +51,7 @@ def symmetric_selection_tensor(tensor_data: np.ndarray,
  p: p-norm to use for the Lp-norm distance.
  n_bits: Number of bits to quantize the tensor.
  per_channel: Whether the quantization should be per-channel or not.
- channel_axis: Output channel index.
+ channel_axis: Output channel index. if None, search for best axis.
  n_iter: Number of iterations to search for the optimal threshold (not used for this method).
  min_threshold: Minimal threshold to use if threshold is too small (not used for this method).
  quant_error_method: an error function to optimize the parameters' selection accordingly.
@@ -57,12 +61,24 @@ def symmetric_selection_tensor(tensor_data: np.ndarray,

  Returns:
  Optimal threshold to quantize the tensor in a symmetric manner.
+ Selected quantization channel axis.
  """

- tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
-
  if quant_error_method == qc.QuantizationErrorMethod.NOCLIPPING:
- threshold = get_init_threshold(min_threshold, tensor_max, per_channel)
+ if channel_axis is None and per_channel:
+ total_error_list = []
+ th_list = []
+ for _axis in range(len(tensor_data.shape)):
+ tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
+ threshold = get_init_threshold(min_threshold, tensor_max, per_channel)
+ q_tensor_data = quantize_tensor(tensor_data, threshold, n_bits, True)
+ total_error_list.append(compute_mse(tensor_data, q_tensor_data, norm=True))
+ th_list.append(threshold)
+ channel_axis = np.argmin(total_error_list)
+ threshold = th_list[channel_axis]
+ else:
+ tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
+ threshold = get_init_threshold(min_threshold, tensor_max, per_channel)
  else:
  signed = True # weights are always signed
  axis = -1 if per_channel else None
@@ -71,15 +87,14 @@ def symmetric_selection_tensor(tensor_data: np.ndarray,
  signed=signed, node=node,
  hessian_info_service=hessian_info_service,
  num_hessian_samples=num_hessian_samples)
- threshold = qparams_symmetric_selection_tensor_search(error_function,
- tensor_data,
- tensor_max,
- n_bits,
- per_channel,
- channel_axis,
- min_threshold=min_threshold,
- signed=signed)
- return {THRESHOLD: threshold}
+ threshold, channel_axis = qparams_symmetric_selection_tensor_search(error_function,
+ tensor_data,
+ n_bits,
+ per_channel,
+ channel_axis,
+ min_threshold=min_threshold,
+ signed=signed)
+ return {THRESHOLD: threshold}, channel_axis


  def symmetric_selection_histogram(bins: np.ndarray,
@@ -13,6 +13,7 @@
  # limitations under the License.
  # ==============================================================================
  import numpy as np
+ from typing import Union, Tuple, Dict

  import model_compression_toolkit.core.common.quantization.quantization_config as qc
  from model_compression_toolkit.constants import MIN_THRESHOLD, RANGE_MIN, RANGE_MAX, NUM_QPARAM_HESSIAN_SAMPLES
@@ -24,6 +25,9 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import get_tensor_max, \
  get_tensor_min
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+ from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
+ from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor
+

  def uniform_selection_tensor(tensor_data: np.ndarray,
  p: int,
@@ -35,7 +39,8 @@ def uniform_selection_tensor(tensor_data: np.ndarray,
  quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
  node=None,
  hessian_info_service: HessianInfoService = None,
- num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> dict:
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
+ ) -> Tuple[Dict[str, np.ndarray], int]:
  """
  Compute the optimal quantization range based on the provided QuantizationErrorMethod
  to uniformly quantize the tensor.
@@ -46,7 +51,7 @@ def uniform_selection_tensor(tensor_data: np.ndarray,
  p: p-norm to use for the Lp-norm distance.
  n_bits: Number of bits to quantize the tensor.
  per_channel: Whether the quantization should be per-channel or not.
- channel_axis: Output channel index.
+ channel_axis: Output channel index. if None, search for best axis.
  n_iter: Number of iterations to search for the optimal threshold (not used for this method).
  min_threshold: Minimal threshold to use if threshold is too small (not used for this method).
  quant_error_method: an error function to optimize the range parameters' selection accordingly.
@@ -56,27 +61,48 @@ def uniform_selection_tensor(tensor_data: np.ndarray,

  Returns:
  Optimal quantization range to quantize the tensor uniformly.
+ Selected quantization channel axis.
  """
- tensor_min = get_tensor_min(tensor_data, per_channel, channel_axis)
- tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits, is_uniform_quantization=True)
-
  if quant_error_method == qc.QuantizationErrorMethod.NOCLIPPING:
- mm = tensor_min, tensor_max
+ if channel_axis is None and per_channel:
+ total_error_list = []
+ th_list = []
+ for _axis in range(len(tensor_data.shape)):
+ tensor_min = get_tensor_min(tensor_data, per_channel, _axis)
+ tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits, is_uniform_quantization=True)
+ q_tensor_data = uniform_quantize_tensor(tensor_data, tensor_min, tensor_max, n_bits)
+ total_error_list.append(compute_mse(tensor_data, q_tensor_data, norm=True))
+ th_list.append((tensor_min, tensor_max))
+ channel_axis = np.argmin(total_error_list)
+ mm = th_list[channel_axis]
+ else:
+ tensor_min = get_tensor_min(tensor_data, per_channel, channel_axis)
+ tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits, is_uniform_quantization=True)
+ mm = tensor_min, tensor_max
  else:
  axis = -1 if per_channel else None
  error_function = get_threshold_selection_tensor_error_function(QuantizationMethod.UNIFORM, quant_error_method,
  p, axis=axis, norm=False, node=node,
  hessian_info_service=hessian_info_service,
  num_hessian_samples=num_hessian_samples)
- mm = qparams_uniform_selection_tensor_search(error_function,
- tensor_data,
- tensor_min,
- tensor_max,
- n_bits,
- per_channel,
- channel_axis)
+ mm, channel_axis = qparams_uniform_selection_tensor_search(error_function,
+ tensor_data,
+ n_bits,
+ per_channel,
+ channel_axis)
+ # In case the tensor\axis has a single value, then min==max, so need to adjust either min or max to zero.
+ if not isinstance(mm[0], np.ndarray):
+ if mm[0] > 0:
+ mm = (np.float32(0).astype(mm[0].dtype), mm[1])
+ if mm[1] < 0:
+ mm = (mm[0], np.float32(0).astype(mm[1].dtype))
+ else:
+ adj_min_to_zero = np.logical_and(mm[1] == mm[0], mm[0] > 0)
+ adj_max_to_zero = np.logical_and(mm[1] == mm[0], mm[1] < 0)
+ mm[0][adj_min_to_zero] = 0
+ mm[1][adj_max_to_zero] = 0
  return {RANGE_MIN: mm[0],
- RANGE_MAX: mm[1]}
+ RANGE_MAX: mm[1]}, channel_axis


  def uniform_selection_histogram(bins: np.ndarray,
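The post-processing block added above guards against degenerate channels: if every value along the selected axis is identical, the searched range collapses to min == max, so one bound is snapped to zero to keep a non-empty uniform grid that contains zero. The same logic in isolation, with made-up per-channel values:

    import numpy as np

    range_min = np.array([0.5, -1.0, -0.3], dtype=np.float32)  # per-channel minima
    range_max = np.array([0.5, -1.0,  0.7], dtype=np.float32)  # per-channel maxima

    adj_min_to_zero = np.logical_and(range_max == range_min, range_min > 0)
    adj_max_to_zero = np.logical_and(range_max == range_min, range_max < 0)
    range_min[adj_min_to_zero] = 0   # channel 0: [0.5, 0.5]   -> [0.0, 0.5]
    range_max[adj_max_to_zero] = 0   # channel 1: [-1.0, -1.0] -> [-1.0, 0.0]
    print(range_min, range_max)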
@@ -158,7 +158,8 @@ def build_node(node: KerasNode,
  if is_const(arg) or (
  keras_layer.symbol in tf_function_symbols and
  isinstance(arg, (tuple, list))):
- weights.update({i: to_numpy(arg, is_single_tensor=True)})
+ if inputs_as_list or i in kwarg2index.values():
+ weights.update({i: to_numpy(arg, is_single_tensor=True)})
  # remove weights and KerasTensors and weights from op_call_args
  if inputs_as_list:
  op_call_args = tuple(op_call_args[1:])
@@ -169,8 +170,7 @@ def build_node(node: KerasNode,
  # read weights from call kwargs
  weight_keys = []
  for k, v in op_call_kwargs.items():
- if is_const(v) or (keras_layer.function in [tf.add, tf.multiply, tf.subtract, tf.divide, tf.truediv, tf.pow,
- tf.matmul] and
+ if is_const(v) or (keras_layer.symbol in tf_function_symbols and
  isinstance(v, (tuple, list))):
  if k in kwarg2index:
  weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})
@@ -39,20 +39,16 @@ from mct_quantizers import PytorchQuantizationWrapper
  def _build_input_tensors_list(node: BaseNode,
  graph: Graph,
  inputs: Tuple[Any],
- node_to_output_tensors_dict: Dict[BaseNode, List],
- is_op_quantize_wrapper: bool) -> List[List]:
+ node_to_output_tensors_dict: Dict[BaseNode, List]) -> List[List]:
  """
  Given a node, build a list of input tensors the node gets. The list is built based on the
- node's incoming edges, previous nodes' output tensors and the node's positional weights.
- Positional weights aren't used if the node's op is PytorchQuantizationWrapper, since it's
- positional weights are already in the wrapper.
+ node's incoming edges, previous nodes' output tensors.

  Args:
  node: Node to build its input tensors list.
  graph: Graph the node is in.
  inputs: list of input tensors to model.
  node_to_output_tensors_dict: A dictionary from a node to its output tensors.
- is_op_quantize_wrapper: Whether the func_op is a PytorchQuantizationWrapper or not.

  Returns:
  A list of the node's input tensors.
@@ -67,35 +63,30 @@ def _build_input_tensors_list(node: BaseNode,
  _input_tensors = node_to_output_tensors_dict[ie.source_node]
  input_tensors.append(_input_tensors)
  input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists
- input_tensors = node.insert_positional_weights_to_input_list(input_tensors)
- # convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
- # list separately, because in FX the tensors are FX objects and fail to_torch_tensor
- input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
- for t in input_tensors]
  return input_tensors


  def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
- is_op_quantize_wrapper: bool) -> List:
+ tensor_input_indices: List = None) -> List:
  """
- Merge input tensors list with op_call_args, according to correct order.
+ Merge input tensors list with positional weights and op_call_args, according to correct order.

  Args:
  _node: The node the inputs are for.
  input_tensors: activation input tensors to node.
  op_call_args: framework node call args.
- is_op_quantize_wrapper: Whether the func_op is a PytorchQuantizationWrapper or not.
+
  Returns:
  Combined list of input_tensors and op_call_args.
  """
  if isinstance(_node, FunctionalNode) and _node.tensor_input_indices:
  _input_list = op_call_args.copy()
- if is_op_quantize_wrapper:
- _input_list = input_tensors + _input_list
- else:
- assert len(_node.tensor_input_indices) == len(input_tensors), 'Mismatch between input tensors and indices'
- for i, t in zip(_node.tensor_input_indices, input_tensors):
- _input_list.insert(i, t)
+ if tensor_input_indices is None:
+ tensor_input_indices = _node.tensor_input_indices
+ assert len(tensor_input_indices) == len(input_tensors), \
+ f'Mismatch between input tensors ({len(tensor_input_indices)}) and indices {len(input_tensors)}'
+ for i, t in zip(tensor_input_indices, input_tensors):
+ _input_list.insert(i, t)
  else:
  _input_list = input_tensors + op_call_args

@@ -126,10 +117,22 @@ def _run_operation(n: BaseNode,
  op_call_args = n.op_call_args if isinstance(n, FunctionalNode) else []
  functional_kwargs = n.op_call_kwargs if isinstance(n, FunctionalNode) else {}

+ if not (isinstance(n, FunctionalNode) and isinstance(op_func, PytorchQuantizationWrapper)):
+ # Insert positional weights only when not a quantized functional node, because quantized functional nodes
+ # insert the quantized weights in the wrapper.
+ input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
+ # convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
+ # list separately, because in FX the tensors are FX objects and fail to_torch_tensor
+ input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
+ for t in input_tensors]
+ _tensor_input_indices = None
+ else:
+ _tensor_input_indices = [i for i in n.tensor_input_indices if i not in n.weights]
+
  if isinstance(n, FunctionalNode) and n.inputs_as_list:
  out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
  else:
- merged_inputs = _merge_inputs(n, input_tensors, op_call_args, isinstance(op_func, PytorchQuantizationWrapper))
+ merged_inputs = _merge_inputs(n, input_tensors, op_call_args, tensor_input_indices=_tensor_input_indices)
  out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)

  # Add a fake quant node if the node has an activation threshold.
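Taken together, the wrapped-functional path in _run_operation now works in two small steps: indices that point at positional weights already stored inside the PytorchQuantizationWrapper are filtered out, and the remaining tensors are spliced into the op's call args at those indices by _merge_inputs. A toy, framework-free illustration (the values are made up):

    # Step 1: drop indices whose positional weights already live in the wrapper.
    tensor_input_indices = [0, 1, 2]
    weights = {1: 'quantized_const'}          # index 1 is a positional weight in the wrapper
    _tensor_input_indices = [i for i in tensor_input_indices if i not in weights]

    # Step 2: splice the remaining tensors into the call args at those indices.
    op_call_args = ['arg_a', 'arg_b']
    input_tensors = ['t0', 't2']
    merged = op_call_args.copy()
    for i, t in zip(_tensor_input_indices, input_tensors):
        merged.insert(i, t)
    print(merged)                             # ['t0', 'arg_a', 't2', 'arg_b']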
@@ -295,8 +298,7 @@ class PytorchModel(torch.nn.Module):
  input_tensors = _build_input_tensors_list(node,
  self.graph,
  args,
- node_to_output_tensors_dict,
- isinstance(op_func, PytorchQuantizationWrapper))
+ node_to_output_tensors_dict)
  use_activation_quantization, activation_quantization_fn = self._get_activation_quantization_fn(node)

  # Run node operation and fetch outputs