mct-nightly 2.4.0.20250630.629__py3-none-any.whl → 2.4.0.20250702.605__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/METADATA +16 -16
  2. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/RECORD +75 -72
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
  5. model_compression_toolkit/core/common/framework_info.py +5 -32
  6. model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
  7. model_compression_toolkit/core/common/graph/base_graph.py +20 -37
  8. model_compression_toolkit/core/common/graph/base_node.py +13 -106
  9. model_compression_toolkit/core/common/graph/functional_node.py +1 -1
  10. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
  11. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
  12. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
  13. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
  14. model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
  15. model_compression_toolkit/core/common/network_editors/actions.py +4 -96
  16. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  17. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
  18. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
  19. model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
  20. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
  22. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
  23. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
  24. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
  25. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  26. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
  27. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
  28. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
  29. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
  30. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
  31. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
  32. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
  33. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
  34. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
  35. model_compression_toolkit/core/graph_prep_runner.py +31 -20
  36. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
  37. model_compression_toolkit/core/keras/default_framework_info.py +0 -11
  38. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
  39. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
  40. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
  41. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
  42. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
  43. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
  44. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
  45. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
  46. model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
  47. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  48. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
  49. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
  50. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
  51. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
  52. model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
  53. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
  54. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
  55. model_compression_toolkit/core/runner.py +1 -1
  56. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
  57. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  58. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
  59. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
  60. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
  61. model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
  62. model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
  63. model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
  64. model_compression_toolkit/quantization_preparation/__init__.py +14 -0
  65. model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
  66. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  67. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
  68. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/WHEEL +0 -0
  69. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/licenses/LICENSE.md +0 -0
  70. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/top_level.txt +0 -0
  71. /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
  72. /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
  73. /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
  74. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
  75. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
  76. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
@@ -21,12 +21,8 @@ from torch.nn import Conv2d, ConvTranspose2d, Linear
  from torch import sigmoid

  from model_compression_toolkit.core.common.framework_info import FrameworkInfo, set_fw_info, ChannelAxisMapping
- from mct_quantizers import QuantizationMethod
  from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
  from model_compression_toolkit.core.pytorch.constants import KERNEL
- from model_compression_toolkit.core.pytorch.quantizer.fake_quant_builder import power_of_two_quantization, \
-     symmetric_quantization, uniform_quantization
- from model_compression_toolkit.core.pytorch.quantizer.lut_fake_quant import activation_lut_kmean_quantizer


  class PyTorchInfo(FrameworkInfo):
@@ -81,14 +77,6 @@ class PyTorchInfo(FrameworkInfo):
          SiLU: (-0.279, None),
      }

-     """
-     Mapping from a QuantizationMethod to an activation quantizer function.
-     """
-     activation_quantizer_mapping = {QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
-                                     QuantizationMethod.SYMMETRIC: symmetric_quantization,
-                                     QuantizationMethod.UNIFORM: uniform_quantization,
-                                     QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer}
-
      @classmethod
      def get_kernel_channels(cls, node_type: Any) -> ChannelAxisMapping:
          """
@@ -95,11 +95,11 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
              else:
                  return graph
          elif non_linear_node.is_match_type(hardtanh):
-             if (non_linear_node.framework_attr[HARDTANH_MIN_VAL] == 0.0) and not \
-                     (np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]).astype(int) -
-                      np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]) == 0):
-                 scale_factor = non_linear_node.framework_attr[HARDTANH_MAX_VAL] / self.threshold
-                 non_linear_node.functional_op.__defaults__ = (0.0, self.threshold, non_linear_node.framework_attr[INPLACE])
+             kwargs = non_linear_node.op_call_kwargs
+             if (kwargs[HARDTANH_MIN_VAL] == 0.0) and not \
+                     (np.log2(kwargs[HARDTANH_MAX_VAL]).astype(int) - np.log2(kwargs[HARDTANH_MAX_VAL]) == 0):
+                 scale_factor = kwargs[HARDTANH_MAX_VAL] / self.threshold
+                 non_linear_node.functional_op.__defaults__ = (0.0, self.threshold, kwargs[INPLACE])
              else:
                  return graph
          else:
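Note: the substitution now reads the hardtanh bounds from op_call_kwargs instead of framework_attr. A minimal illustrative sketch of the power-of-two check applied above (the values are examples, not taken from this diff):

    import numpy as np

    max_val = 6.0  # e.g. hardtanh(0.0, 6.0), as used for ReLU6
    # True only when max_val is an exact power of two (e.g. 8.0); False for 6.0
    is_pow2 = (np.log2(max_val).astype(int) - np.log2(max_val)) == 0
    if not is_pow2:
        scale_factor = max_val / 8.0  # assuming a power-of-two threshold of 8.0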
@@ -29,6 +29,7 @@ from model_compression_toolkit.core.common import BaseNode, Graph
  from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
  from model_compression_toolkit.core.common.substitutions.shift_negative_activation import apply_shift_negative_correction
+ from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
  from model_compression_toolkit.core.pytorch.constants import PAD, VALUE, PADDING, BIAS, USE_BIAS
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor

@@ -239,4 +240,5 @@ def pytorch_apply_shift_negative_correction(graph: Graph,
                                             PADDING,
                                             BIAS,
                                             USE_BIAS,
+                                            get_activation_quantization_fn_factory,
                                             params_search_quantization_fn=params_search_quantization_fn)
@@ -91,7 +91,7 @@ class WeightsHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
          for i, ipt_node in enumerate(self.hessian_request.target_nodes):  # Per Interest point weights tensor

              # Check if the target node's layer type is supported.
-             if not ipt_node.is_kernel_op:
+             if not ipt_node.kernel_attr:
                  Logger.critical(f"Hessian information with respect to weights is not supported for "
                                  f"{ipt_node.type} layers.")  # pragma: no cover

@@ -20,6 +20,7 @@ from model_compression_toolkit.core.common.mixed_precision.configurable_quantize
      verify_candidates_descending_order, init_activation_quantizers
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
      CandidateNodeQuantizationConfig
+ from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
  from model_compression_toolkit.logger import Logger
  from mct_quantizers import QuantizationMethod
  from mct_quantizers import QuantizationTarget
@@ -67,7 +68,7 @@ class ConfigurableActivationQuantizer(BasePyTorchInferableQuantizer):
              Logger.critical("Unsupported configuration: Mixing candidates with differing activation quantization states (enabled/disabled).")  # pragma: no cover

          # Setting layer's activation
-         self.activation_quantizers = init_activation_quantizers(self.node_q_cfg)
+         self.activation_quantizers = init_activation_quantizers(self.node_q_cfg, get_activation_quantization_fn_factory)
          self.active_quantization_config_index = max_candidate_idx  # initialize with first config as default

      def set_active_activation_quantizer(self, index: Optional[int]):
@@ -167,7 +167,7 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
          """

          attributes_with_axis = {}
-         if node.is_kernel_op:
+         if node.kernel_attr:
              attributes_with_axis[node.kernel_attr] = (node.channel_axis.output, node.channel_axis.input)

          # Bias is a vector at the length of the number of output channels.
@@ -26,7 +26,7 @@ from torch.nn import Module, Sigmoid, Softmax

  import model_compression_toolkit.core.pytorch.constants as pytorch_constants
  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
- from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig
+ from model_compression_toolkit.core import QuantizationConfig, CoreConfig
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import Graph, BaseNode
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -0,0 +1,45 @@
+ # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from collections.abc import Callable
+
+ from mct_quantizers import QuantizationMethod
+ from model_compression_toolkit.core.pytorch.quantization.fake_quant_builder import power_of_two_quantization, \
+     symmetric_quantization, uniform_quantization
+ from model_compression_toolkit.core.pytorch.quantization.lut_fake_quant import activation_lut_kmean_quantizer
+
+
+ """
+ Mapping from a QuantizationMethod to an activation quantizer function.
+ """
+ _activation_quantizer_factory_mapping = {
+     QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
+     QuantizationMethod.SYMMETRIC: symmetric_quantization,
+     QuantizationMethod.UNIFORM: uniform_quantization,
+     QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer
+ }
+
+
+ def get_activation_quantization_fn_factory(quantization_method: QuantizationMethod) -> Callable[[int, dict], Callable]:
+     """
+     Get factory for activation quantizer.
+
+     Args:
+         quantization_method: quantization method for activation.
+
+     Returns:
+         Factory that accepts activation bitwidth and a dict of quantization params, and returns the quantizer.
+     """
+     return _activation_quantizer_factory_mapping[quantization_method]
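Note: a hedged usage sketch of the new PyTorch factory added above. The keys of the quantization-params dict are an assumption for illustration only; they are not shown in this diff.

    from mct_quantizers import QuantizationMethod
    from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import \
        get_activation_quantization_fn_factory

    # Pick the quantizer builder for the requested quantization method.
    factory = get_activation_quantization_fn_factory(QuantizationMethod.POWER_OF_TWO)
    # Per the docstring above: activation bitwidth plus a params dict (key names assumed here).
    quantizer = factory(8, {'threshold': 1.0})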
@@ -18,7 +18,7 @@ from torch.nn import Conv2d, Linear, ConvTranspose2d
  from model_compression_toolkit.core import QuantizationConfig
  from model_compression_toolkit.core.common import Graph
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+ from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
  from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
      compute_activation_bias_correction_of_graph
@@ -50,5 +50,6 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
                                                           fw_impl=fw_impl,
                                                           activation_bias_correction_node_matchers=
                                                           activation_bias_correction_node_matchers,
-                                                          kernel_size=KERNEL_SIZE)
+                                                          kernel_size=KERNEL_SIZE,
+                                                          get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
      return graph
@@ -118,7 +118,7 @@ def core_runner(in_model: Any,
      if core_config.is_mixed_precision_enabled:
          if core_config.mixed_precision_config.configuration_overwrite is None:

-             filter_candidates_for_mixed_precision(graph, target_resource_utilization, fqc)
+             filter_candidates_for_mixed_precision(graph, target_resource_utilization)
              bit_widths_config = search_bit_width(tg,
                                                   fw_impl,
                                                   target_resource_utilization,
@@ -12,14 +12,13 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Callable
+ from typing import Callable, Optional, List
  from io import BytesIO

  import torch.nn

  from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper

- from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
  from model_compression_toolkit.verify_packages import FOUND_ONNX
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
@@ -65,11 +64,14 @@ if FOUND_ONNX:
              self._use_onnx_custom_quantizer_ops = use_onnx_custom_quantizer_ops
              self._onnx_opset_version = onnx_opset_version

-         def export(self, output_names=None) -> None:
+         def export(self, output_names: Optional[List[str]] = None) -> None:
              """
              Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
              (namely, weights that are in fake-quant format) and fake-quant layers for the activations.

+             Args:
+                 output_names (Optional[List[str]]): Optional list of output node names for export compatibility.
+
              Returns:
                  Fake-quant PyTorch model.
              """
@@ -131,6 +133,8 @@ if FOUND_ONNX:
                  output_names = ['output']
                  dynamic_axes.update({'output': {0: 'batch_size'}})
              else:
+                 assert isinstance(output_names, list), \
+                     f"`output_names` must be a list, but got {type(output_names).__name__}"
                  if isinstance(model_output, (list, tuple)):
                      num_of_outputs = len(model_output)
                  else:
@@ -49,7 +49,7 @@ class FakelyQuantTorchScriptPyTorchExporter(BasePyTorchExporter):
                           save_model_path,
                           repr_dataset)

-     def export(self) -> None:
+     def export(self, output_names=None) -> None:
          """
          Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
          (namely, weights that are in fake-quant format) and fake-quant layers for the activations.
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Callable
+ from typing import Callable, Optional, List
  from packaging import version

  from model_compression_toolkit.verify_packages import FOUND_TORCH
@@ -49,7 +49,8 @@ if FOUND_TORCH:
                               is_layer_exportable_fn: Callable = is_pytorch_layer_exportable,
                               serialization_format: PytorchExportSerializationFormat = PytorchExportSerializationFormat.ONNX,
                               quantization_format: QuantizationFormat = QuantizationFormat.MCTQ,
-                              onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION) -> None:
+                              onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION,
+                              output_names: Optional[List[str]] = None) -> None:
          """
          Export a PyTorch quantized model to a torchscript or onnx model.
          The model will be saved to the path in save_model_path.
@@ -67,11 +68,19 @@
                  PytorchExportSerializationFormat.ONNX).
              quantization_format: Format of how quantizers are exported (fakely-quant, int8, MCTQ quantizers).
              onnx_opset_version: ONNX opset version to use for exported ONNX model.
+             output_names (Optional[List[str]]): Optional list of output node names for export compatibility.
+                 This argument is relevant only when using PytorchExportSerializationFormat.ONNX.

          """
          # Ensure 'metadata' is available directly on the model, if present in submodules
          find_and_assign_metadata_attr(model)

+         if output_names is not None and serialization_format != PytorchExportSerializationFormat.ONNX:
+             Logger.warning(
+                 f'`output_names` is only applicable when exporting to ONNX. '
+                 f'Current serialization format is {serialization_format}, so `output_names` will be ignored.'
+             )  # pragma: no cover
+
          if serialization_format == PytorchExportSerializationFormat.TORCHSCRIPT:
              if quantization_format in supported_serialization_quantization_export_dict[serialization_format]:
                  exporter = FakelyQuantTorchScriptPyTorchExporter(model,
@@ -109,7 +118,7 @@ if FOUND_TORCH:
                                  f'Unsupported serialization {serialization_format} was used to export Pytorch model.'
                                  f' Please see API for supported formats.')  # pragma: no cover

-         exporter.export()
+         exporter.export(output_names=output_names)

  else:
      def pytorch_export_model(*args, **kwargs):
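Note: a hedged example of the new output_names argument on the export facade. The model, representative dataset, and save path below are placeholders; per the warning added above, the argument is honored only for ONNX serialization.

    import model_compression_toolkit as mct

    mct.exporter.pytorch_export_model(model=quantized_model,
                                      save_model_path='./qmodel.onnx',
                                      repr_dataset=repr_dataset,
                                      output_names=['logits'])  # forwarded to exporter.export(output_names=...)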
@@ -17,6 +17,7 @@ from typing import Callable, Tuple, Union

  from model_compression_toolkit import get_target_platform_capabilities
  from model_compression_toolkit.constants import TENSORFLOW
+ from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
  from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
  from model_compression_toolkit.verify_packages import FOUND_TF
@@ -24,10 +25,8 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
  from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
  from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
- from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
  from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
  from model_compression_toolkit.logger import Logger
- from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL

  if FOUND_TF:
@@ -117,20 +116,17 @@

      target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
      # Attach tpc model to framework
-     attach2keras = AttachTpcToKeras()
-     target_platform_capabilities = attach2keras.attach(target_platform_capabilities)
+     framework_platform_capabilities = AttachTpcToKeras().attach(target_platform_capabilities)

      # Convert the original Keras model to an internal graph representation.
      float_graph = read_model_to_graph(model,
                                        representative_data_gen,
-                                       target_platform_capabilities,
+                                       framework_platform_capabilities,
                                        fw_impl)

      # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
      # as it prepares the graph for the pruning process.
-     float_graph_with_compression_config = set_quantization_configuration_to_graph(float_graph,
-                                                                                   quant_config=DEFAULTCONFIG,
-                                                                                   mixed_precision_enable=False)
+     float_graph_with_compression_config = load_fqc_configuration(float_graph, framework_platform_capabilities)

      # Create a Pruner object with the graph and configuration.
      pruner = Pruner(float_graph_with_compression_config,
@@ -138,7 +134,7 @@ if FOUND_TF:
                      target_resource_utilization,
                      representative_data_gen,
                      pruning_config,
-                     target_platform_capabilities)
+                     framework_platform_capabilities)

      # Apply the pruning process.
      pruned_graph = pruner.prune_graph()
@@ -23,10 +23,9 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
  from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
  from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
- from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
+ from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration
  from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
  from model_compression_toolkit.logger import Logger
- from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL


@@ -134,9 +133,7 @@ if FOUND_TORCH:

      # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
      # as it prepares the graph for the pruning process.
-     float_graph_with_compression_config = set_quantization_configuration_to_graph(float_graph,
-                                                                                   quant_config=DEFAULTCONFIG,
-                                                                                   mixed_precision_enable=False)
+     float_graph_with_compression_config = load_fqc_configuration(float_graph, framework_platform_capabilities)

      # Create a Pruner object with the graph and configuration.
      pruner = Pruner(float_graph_with_compression_config,
@@ -122,7 +122,7 @@ if FOUND_TF:

      >>> quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(model, repr_datagen, ru, core_config=config)

-     For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
+     For more configuration options, please take a look at our `API documentation <https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.

      """

@@ -167,7 +167,7 @@ if FOUND_TF:

      >>> quantized_model = tf.keras.models.load_model(model_file, custom_objects=custom_objects)

-     For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
+     For more configuration options, please take a look at our `API documentation <https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.

      """

@@ -136,7 +136,7 @@ if FOUND_TORCH:

      >>> quantized_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(model, repr_datagen, core_config=config)

-     For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
+     For more configuration options, please take a look at our `API documentation <https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.

      """
      Logger.warning(
@@ -0,0 +1,14 @@
+ # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
@@ -0,0 +1,223 @@
+ # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from typing import List, Optional
+
+ from model_compression_toolkit.core.common import Graph, BaseNode
+ from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
+ from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
+ from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
+     CandidateNodeQuantizationConfig, NodeQuantizationConfig
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import \
+     NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig, ActivationQuantizationMode
+ from model_compression_toolkit.core.common.quantization.set_node_quantization_config import filter_node_qco_by_graph
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.target_platform_capabilities import FrameworkQuantizationCapabilities, \
+     QuantizationConfigOptions, OpQuantizationConfig
+
+
+ def load_fqc_configuration(graph: Graph, fqc: FrameworkQuantizationCapabilities):
+     """
+     Set-up graph for quantization per TPC.
+     Each node will contain quantization candidates for mixed precision and the base config for single precision.
+     The graph will contain the fusing info.
+
+     Args:
+         graph: graph.
+         fqc: framework quantization capabilities object.
+
+     Returns:
+         Updated graph.
+     """
+     graph = _set_nodes_quantization_configuration(graph, fqc)
+     graph = _set_fusion_info(graph, fqc)
+
+     return graph
+
+
+ def set_quantization_configs_to_node(node: BaseNode,
+                                      graph: Graph,
+                                      fqc: FrameworkQuantizationCapabilities):
+     """
+     Create and set quantization configurations to a node (for both weights and activation).
+
+     Args:
+         node (BaseNode): Node to set its quantization configurations.
+         graph (Graph): Model's internal representation graph.
+         fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to get default OpQuantizationConfig.
+     """
+     qc_options = fetch_qc_options_for_node(node, fqc)
+     base_config, candidates_qcs = filter_node_qco_by_graph(node, fqc, graph, qc_options)
+
+     node_attrs_list = node.get_node_weights_attributes()
+     mp_candidates = [_create_candidate(node.channel_axis, qc, node_attrs_list)
+                      for qc in candidates_qcs]
+     sp_cfg = _create_candidate(node.channel_axis, base_config, node_attrs_list)
+
+     node.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=sp_cfg,
+                                                    candidates_quantization_cfg=mp_candidates)
+
+     # TODO is not needed anymore as find min/max candidate look for a real max/min, but some tests still count on it
+     node.sort_node_candidates()
+
+     if not node.has_activation:
+         node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.NO_QUANT)
+
+     _disable_unsupported_quant_preserving(node, graph)
+
+
+ def fetch_qc_options_for_node(node: BaseNode,
+                               fqc: FrameworkQuantizationCapabilities,
+                               return_default=True) -> Optional[QuantizationConfigOptions]:
+     """
+     Get quantization configuration options for the node from TPC.
+
+     Args:
+         node: node for which to fetch quantization configuration.
+         fqc: framework quantization capabilities.
+         return_default: whether to return the default qco or None if node op is not in FQC.
+
+     Returns:
+         Quantization configuration options for the node.
+     """
+     # qcos by filters
+     filter_matches = [(fl, qco) for fl, qco in fqc.filterlayer2qco.items() if node.is_match_filter_params(fl)]
+     fls, filter_qcos = zip(*filter_matches) if filter_matches else (None, None)
+     if filter_qcos and any(qco != filter_qcos[0] for qco in filter_qcos[1:]):
+         raise ValueError(f'Cannot assign quantization configuration to {node} as it matches more than one filter with '
+                          f'conflicting configs: {fls}.')
+
+     # qco by opset
+     # must use is_match_type for functional op in TF2.15
+     matches = [(op_type, qco) for op_type, qco in fqc.layer2qco.items() if node.is_match_type(op_type)]
+     op_types, qcos = zip(*matches) if matches else (None, None)
+     if qcos and any(qco != qcos[0] for qco in qcos[1:]):
+         raise ValueError(f'Cannot assign quantization configuration to {node} as it matches more than one op type with '
+                          f'conflicting configs: {op_types}.')
+
+     # if node matches by both filter and opset, filter takes priority
+     if filter_qcos:
+         return filter_qcos[0]
+
+     if qcos:
+         return qcos[0]
+
+     return fqc.tpc.default_qco if return_default else None
+
+
+ def _set_nodes_quantization_configuration(graph: Graph,
+                                           fqc: FrameworkQuantizationCapabilities) -> Graph:
+     """
+     Set quantization configuration for each graph node.
+
+     Args:
+         graph: graph to set with quantization configuration.
+         fqc: framework quantization capabilities.
+
+     Returns:
+         Graph: The graph with quantization configurations attached to each node in it.
+     """
+     _validate_custom_ops_have_qco(graph, fqc)
+
+     for n in graph.get_topo_sorted_nodes():
+         set_quantization_configs_to_node(node=n,
+                                          graph=graph,
+                                          fqc=fqc)
+     return graph
+
+
+ def _set_fusion_info(graph: Graph, fqc: FrameworkQuantizationCapabilities) -> Graph:
+     """
+
+     Args:
+         graph: graph.
+         fqc: quantization capabilities with attached framework.
+
+     Returns:
+
+     """
+     # TODO fix the dict with const keys inside get_fusing_patterns. use named tuple or class
+     # TODO irena instead of storing fusion inside graph (including tpc objects) and then let graph convert tpc op config to
+     #  node config, do it here and only store in graph whatever is relevant after this stage.
+     fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(graph)
+     graph.fusing_info = fusing_info
+     graph.override_fused_node_activation_quantization_candidates()
+     return graph
+
+
+ def _disable_unsupported_quant_preserving(node: BaseNode, graph: Graph):
+     """
+     Disable quantization for quantization preserving ops in cases it cannot be supported
+     (multiple inputs or un-quantized previous node).
+
+     Args:
+         node: current node.
+         graph: graph.
+     """
+     if not node.quantization_cfg.get_activation_quant_mode() == ActivationQuantizationMode.PRESERVE_QUANT:
+         return
+
+     prev_nodes = graph.get_prev_nodes(node)
+     if len(prev_nodes) != 1:
+         Logger.info(f'Disabling Quantization-Preserving for node {node.name} with {len(prev_nodes)} inputs.')
+         node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.NO_QUANT)
+     elif prev_nodes[0].quantization_cfg.get_activation_quant_mode() == ActivationQuantizationMode.NO_QUANT:
+         Logger.info(f'Disabling Quantization-Preserving for node {node.name} since previous node activation '
+                     f'quantization is disabled.')
+         node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.NO_QUANT)
+
+
+ # TODO irena copied from graph.set_fqc as is. Why does it have Keras errors?
+ def _validate_custom_ops_have_qco(graph, fqc):
+     custom_nodes = [n for n in graph.nodes if n.is_custom]
+     for n in custom_nodes:
+         qco = fetch_qc_options_for_node(n, fqc, return_default=False)
+         if not qco:
+             Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
+                             ' Please add the custom layer to Framework Quantization Capabilities (FQC), or file a feature '
+                             'request or an issue if you believe this should be supported.')  # pragma: no cover
+         if any([qc.default_weight_attr_config.enable_weights_quantization for qc in qco.quantization_configurations]):
+             Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.')  # pragma: no cover
+
+
+ def _create_candidate(weight_channel_axis: ChannelAxisMapping,
+                       op_cfg: OpQuantizationConfig,
+                       node_attrs_list: List[str]) -> CandidateNodeQuantizationConfig:
+     """
+     Create quantization configuration candidate.
+
+     Args:
+         weight_channel_axis: channels axes of the node's kernel.
+         op_cfg: quantization config for the op.
+         node_attrs_list: A list of the node's weights attributes names.
+
+     Returns:
+         Candidate quantization config.
+     """
+
+     aqc = NodeActivationQuantizationConfig(op_cfg=op_cfg)
+
+     # TODO: remove this validation and warning once enabling all attributes quantization by default
+     attrs_with_enabled_quantization = [attr for attr, cfg in op_cfg.attr_weights_configs_mapping.items()
+                                        if cfg.enable_weights_quantization]
+     if len(attrs_with_enabled_quantization) > 1:
+         Logger.warning(f"Multiple weights attributes quantization is enabled via the provided FQC."
+                        f"Quantizing any attribute other than the kernel is experimental "
+                        f"and may be subject to unstable behavior."
+                        f"Attributes with enabled weights quantization: {attrs_with_enabled_quantization}.")
+     wqc = NodeWeightsQuantizationConfig(op_cfg=op_cfg,
+                                         weights_channels_axis=weight_channel_axis,
+                                         node_attrs_list=node_attrs_list)
+
+     return CandidateNodeQuantizationConfig(activation_quantization_cfg=aqc, weights_quantization_cfg=wqc)
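Note: a hedged sketch of the preparation flow the pruning facades switch to in this release, mirroring the Keras pruning_facade hunk above; the variable names are placeholders.

    from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration

    # Attach the TPC to the framework, read the model into MCT's internal graph,
    # then load the FQC-based quantization configuration onto the graph.
    fqc = AttachTpcToKeras().attach(target_platform_capabilities)
    float_graph = read_model_to_graph(model, representative_data_gen, fqc, fw_impl)
    # Replaces set_quantization_configuration_to_graph(...) used in previous nightlies.
    graph_with_cfg = load_fqc_configuration(float_graph, fqc)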
@@ -29,7 +29,7 @@ QNNPACK_TP_MODEL = 'qnnpack'
  # TP Attributes
  KERNEL_ATTR = "kernel_attr"
  BIAS_ATTR = "bias_attr"
- POS_ATTR = "pos_attr"
+ POSITIONAL_ATTR = "pos_attr"

  # TODO: this is duplicated from the core frameworks constants files, because the original consts can't be used here
  # duo to circular dependency. It might be best to extract the constants from the core file and put them here (in a