mct-nightly 2.4.0.20250629.706__py3-none-any.whl → 2.4.0.20250701.185106__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/METADATA +16 -16
  2. {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/RECORD +75 -72
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
  5. model_compression_toolkit/core/common/framework_info.py +5 -32
  6. model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
  7. model_compression_toolkit/core/common/graph/base_graph.py +20 -37
  8. model_compression_toolkit/core/common/graph/base_node.py +13 -106
  9. model_compression_toolkit/core/common/graph/functional_node.py +1 -1
  10. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
  11. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
  12. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
  13. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
  14. model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
  15. model_compression_toolkit/core/common/network_editors/actions.py +4 -96
  16. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  17. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
  18. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
  19. model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
  20. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
  22. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
  23. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
  24. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
  25. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  26. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
  27. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
  28. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
  29. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
  30. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
  31. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
  32. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
  33. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
  34. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
  35. model_compression_toolkit/core/graph_prep_runner.py +31 -20
  36. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
  37. model_compression_toolkit/core/keras/default_framework_info.py +0 -11
  38. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
  39. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
  40. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
  41. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
  42. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
  43. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
  44. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
  45. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
  46. model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
  47. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  48. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
  49. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
  50. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
  51. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
  52. model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
  53. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
  54. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
  55. model_compression_toolkit/core/runner.py +1 -1
  56. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
  57. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  58. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
  59. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
  60. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
  61. model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
  62. model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
  63. model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
  64. model_compression_toolkit/quantization_preparation/__init__.py +14 -0
  65. model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
  66. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  67. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
  68. {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/WHEEL +0 -0
  69. {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/licenses/LICENSE.md +0 -0
  70. {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/top_level.txt +0 -0
  71. /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
  72. /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
  73. /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
  74. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
  75. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
  76. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
@@ -12,35 +12,38 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Dict, Any, Tuple
15
+ from functools import partial
16
+ from typing import Dict, Any, Tuple, Callable, TYPE_CHECKING
16
17
 
17
18
  import numpy as np
19
+ from mct_quantizers import QuantizationMethod
18
20
 
19
21
  from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
20
22
  from model_compression_toolkit.core.common.hessian import HessianInfoService
21
- from model_compression_toolkit.defaultdict import DefaultDict
22
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
23
- from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
24
- WeightsAttrQuantizationConfig
23
+ from model_compression_toolkit.core.common.quantization.quantization_params_generation import \
24
+ power_of_two_selection_tensor, lut_kmeans_tensor, symmetric_selection_tensor, uniform_selection_tensor
25
25
  from model_compression_toolkit.logger import Logger
26
26
 
27
+ if TYPE_CHECKING:
28
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
27
29
 
28
- def get_weights_qparams(weights_attr_values: np.ndarray,
29
- weights_quant_config: NodeWeightsQuantizationConfig,
30
- attr_quant_config: WeightsAttrQuantizationConfig,
31
- output_channels_axis: int,
32
- node=None,
33
- hessian_info_service: HessianInfoService = None,
34
- num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
30
+
31
+ def compute_weights_qparams(weights_attr_values: np.ndarray,
32
+ attr_quant_config: 'WeightsAttrQuantizationConfig',
33
+ output_channels_axis: int,
34
+ min_threshold: float,
35
+ node=None,
36
+ hessian_info_service: HessianInfoService = None,
37
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
35
38
  """
36
39
  Compute thresholds to quantize a kernel according to a NodeWeightsQuantizationConfig
37
40
  instance.
38
41
 
39
42
  Args:
40
43
  weights_attr_values: Weights attribute parameter to compute the quantization thresholds for.
41
- weights_quant_config: Weights quantization configuration to define how the thresholds are computed.
42
44
  attr_quant_config: A specific weights attribute quantization configuration to get its params.
43
45
  output_channels_axis: Index of the kernel output channels dimension.
46
+ min_threshold: Minimal threshold to use if threshold is too small.
44
47
  node: The node for which the quantization error is computed (used only with HMSE error method).
45
48
  hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (used only with HMSE error method).
46
49
  num_hessian_samples: Number of samples to approximate Hessian-based scores on (used only with HMSE error method).
@@ -49,22 +52,43 @@ def get_weights_qparams(weights_attr_values: np.ndarray,
49
52
  A dictionary with the quantization threshold of the kernel.
50
53
  Selected quantization channel axis.
51
54
  """
52
- if attr_quant_config.weights_quantization_params_fn is not None:
53
- weights_params, output_channels_axis = attr_quant_config.weights_quantization_params_fn(
54
- weights_attr_values,
55
- p=attr_quant_config.l_p_value,
56
- n_bits=attr_quant_config.weights_n_bits,
57
- per_channel=attr_quant_config.weights_per_channel_threshold,
58
- channel_axis=output_channels_axis,
59
- min_threshold=weights_quant_config.min_threshold,
60
- quant_error_method=attr_quant_config.weights_error_method,
61
- node=node,
62
- hessian_info_service=hessian_info_service,
63
- num_hessian_samples=num_hessian_samples)
64
- else: # pragma: no cover
65
- Logger.error(f"Requested weights quantization parameters computation for node {node.name} without providing a "
66
- f"weights_quantization_params_fn."
67
- f"Returning an empty dictionary since no quantization parameters were computed.")
68
- weights_params = {}
55
+ params_fn = _get_weights_quantization_params_fn(attr_quant_config.weights_quantization_method)
56
+ weights_params, output_channels_axis = params_fn(
57
+ weights_attr_values,
58
+ p=attr_quant_config.l_p_value,
59
+ n_bits=attr_quant_config.weights_n_bits,
60
+ per_channel=attr_quant_config.weights_per_channel_threshold,
61
+ channel_axis=output_channels_axis,
62
+ min_threshold=min_threshold,
63
+ quant_error_method=attr_quant_config.weights_error_method,
64
+ node=node,
65
+ hessian_info_service=hessian_info_service,
66
+ num_hessian_samples=num_hessian_samples)
69
67
 
70
68
  return weights_params, output_channels_axis
69
+
70
+
71
+ _weights_quant_params_fns = {
72
+ QuantizationMethod.POWER_OF_TWO: power_of_two_selection_tensor,
73
+ QuantizationMethod.SYMMETRIC: symmetric_selection_tensor,
74
+ QuantizationMethod.UNIFORM: uniform_selection_tensor,
75
+ QuantizationMethod.LUT_POT_QUANTIZER: partial(lut_kmeans_tensor, is_symmetric=False),
76
+ QuantizationMethod.LUT_SYM_QUANTIZER: partial(lut_kmeans_tensor, is_symmetric=True)
77
+ }
78
+
79
+
80
+ def _get_weights_quantization_params_fn(weights_quantization_method: QuantizationMethod) -> Callable:
81
+ """
82
+ Generate a function for finding weights quantization parameters.
83
+
84
+ Args:
85
+ weights_quantization_method: Which quantization method to use for weights.
86
+ Returns:
87
+ A function to find the quantization parameters.
88
+
89
+ """
90
+ params_fn = _weights_quant_params_fns.get(weights_quantization_method)
91
+ if not params_fn:
92
+ Logger.critical(
93
+ f"No parameter function found for the specified quantization method: {weights_quantization_method}") # pragma: no cover
94
+ return params_fn
@@ -12,8 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
16
-
15
+ from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
17
16
  from model_compression_toolkit.logger import Logger
18
17
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
19
18
  from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
@@ -47,11 +46,12 @@ def get_quantized_weights_attr_by_qc(attr_name: str,
47
46
  output_channels_axis = None
48
47
 
49
48
  Logger.debug(f'quantizing layer {n.name} attribute {attr_name} with {weights_qc.weights_n_bits} bits')
50
- quantized_kernel = weights_qc.weights_quantization_fn(n.get_weights_by_keys(attr_name),
51
- n_bits=weights_qc.weights_n_bits,
52
- signed=True,
53
- quantization_params=weights_qc.weights_quantization_params,
54
- per_channel=weights_qc.weights_per_channel_threshold,
55
- output_channels_axis=output_channels_axis)
49
+ weights_quantization_fn = get_weights_quantization_fn(weights_qc.weights_quantization_method)
50
+ quantized_kernel = weights_quantization_fn(n.get_weights_by_keys(attr_name),
51
+ n_bits=weights_qc.weights_n_bits,
52
+ signed=True,
53
+ quantization_params=weights_qc.weights_quantization_params,
54
+ per_channel=weights_qc.weights_per_channel_threshold,
55
+ output_channels_axis=output_channels_axis)
56
56
 
57
57
  return quantized_kernel, channels_axis