mct-nightly 2.4.0.20250630.629__py3-none-any.whl → 2.4.0.20250702.605__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/METADATA +16 -16
  2. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/RECORD +75 -72
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
  5. model_compression_toolkit/core/common/framework_info.py +5 -32
  6. model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
  7. model_compression_toolkit/core/common/graph/base_graph.py +20 -37
  8. model_compression_toolkit/core/common/graph/base_node.py +13 -106
  9. model_compression_toolkit/core/common/graph/functional_node.py +1 -1
  10. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
  11. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
  12. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
  13. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
  14. model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
  15. model_compression_toolkit/core/common/network_editors/actions.py +4 -96
  16. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  17. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
  18. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
  19. model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
  20. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
  22. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
  23. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
  24. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
  25. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  26. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
  27. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
  28. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
  29. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
  30. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
  31. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
  32. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
  33. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
  34. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
  35. model_compression_toolkit/core/graph_prep_runner.py +31 -20
  36. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
  37. model_compression_toolkit/core/keras/default_framework_info.py +0 -11
  38. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
  39. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
  40. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
  41. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
  42. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
  43. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
  44. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
  45. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
  46. model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
  47. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  48. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
  49. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
  50. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
  51. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
  52. model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
  53. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
  54. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
  55. model_compression_toolkit/core/runner.py +1 -1
  56. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
  57. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  58. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
  59. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
  60. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
  61. model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
  62. model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
  63. model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
  64. model_compression_toolkit/quantization_preparation/__init__.py +14 -0
  65. model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
  66. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  67. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
  68. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/WHEEL +0 -0
  69. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/licenses/LICENSE.md +0 -0
  70. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/top_level.txt +0 -0
  71. /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
  72. /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
  73. /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
  74. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
  75. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
  76. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
--- a/model_compression_toolkit/core/common/fusion/graph_fuser.py
+++ b/model_compression_toolkit/core/common/fusion/graph_fuser.py
@@ -14,12 +14,12 @@
 # ==============================================================================

 import copy
-from typing import List, Tuple
+from typing import Tuple

 from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
 from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
-from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
-from itertools import product
+from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
+    CandidateNodeQuantizationConfig, NodeQuantizationConfig


 class FusedLayerType:
@@ -30,6 +30,7 @@ class FusedLayerType:
     def __init__(self):
         self.__name__ = 'FusedLayer'

+
 class GraphFuser:
     def apply_node_fusion(self, graph: Graph) -> Graph:
         """
@@ -64,7 +65,6 @@ class GraphFuser:

         return graph_copy

-
     @staticmethod
     def _create_fused_node(fused_node_id: str, nodes: Tuple[BaseNode]) -> BaseNode:
         """
@@ -86,10 +86,15 @@ class GraphFuser:
                              weights={},
                              layer_class=FusedLayerType)

+        base_cfg = CandidateNodeQuantizationConfig(
+            activation_quantization_cfg=nodes[-1].quantization_cfg.base_quantization_cfg.activation_quantization_cfg,
+            weights_quantization_cfg=None
+        )
         activation_cfgs = [c.activation_quantization_cfg for c in nodes[-1].candidates_quantization_cfg]
-        fused_node.candidates_quantization_cfg = [
-            CandidateNodeQuantizationConfig(weights_quantization_cfg=None, activation_quantization_cfg=a) for a in
-            activation_cfgs]
+        candidates = [CandidateNodeQuantizationConfig(weights_quantization_cfg=None, activation_quantization_cfg=a)
+                      for a in activation_cfgs]
+        fused_node.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=base_cfg,
+                                                             candidates_quantization_cfg=candidates)

         # Keep the final configurations if they were set already.
         fused_node.final_weights_quantization_cfg = nodes[0].final_weights_quantization_cfg
@@ -158,5 +163,3 @@ class GraphFuser:

         # Finally, add the new fused node to the graph
         graph.add_node(fused_node)
-
-
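The substantive change in graph_fuser.py above: a fused node no longer stores a bare list in candidates_quantization_cfg, but a NodeQuantizationConfig that bundles a base configuration (derived from the last inner node's activation config) with the candidate list. A minimal sketch of that construction pattern, using simplified stand-in classes rather than MCT's real config types:

from dataclasses import dataclass
from typing import List, Optional

# Simplified stand-ins for MCT's config classes (illustration only).
@dataclass
class CandidateCfg:
    activation_cfg: str
    weights_cfg: Optional[str] = None

@dataclass
class NodeQuantCfg:
    base_quantization_cfg: CandidateCfg
    candidates_quantization_cfg: List[CandidateCfg]

def build_fused_node_cfg(last_node_base_act_cfg: str,
                         last_node_act_cfgs: List[str]) -> NodeQuantCfg:
    # A fused node quantizes only its output activation, so it inherits the
    # activation configs of the last inner node and carries no weights config.
    base = CandidateCfg(activation_cfg=last_node_base_act_cfg, weights_cfg=None)
    candidates = [CandidateCfg(activation_cfg=a, weights_cfg=None)
                  for a in last_node_act_cfgs]
    return NodeQuantCfg(base_quantization_cfg=base,
                        candidates_quantization_cfg=candidates)

cfg = build_fused_node_cfg("act-8bit", ["act-8bit", "act-4bit"])
assert cfg.base_quantization_cfg.weights_cfg is None
assert len(cfg.candidates_quantization_cfg) == 2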
--- a/model_compression_toolkit/core/common/graph/base_graph.py
+++ b/model_compression_toolkit/core/common/graph/base_graph.py
@@ -39,6 +39,7 @@
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
     FrameworkQuantizationCapabilities

+
 def validate_graph_after_change(method: Callable) -> Callable:
     """
     Decorator for graph-mutating methods. After the decorated method executes,
@@ -120,28 +121,13 @@ class Graph(nx.MultiDiGraph, GraphSearches):
     def fusing_info(self, fusing_info: FusingInfo):
         self._fusing_info = fusing_info

-    def set_fqc(self,
-                fqc: FrameworkQuantizationCapabilities):
+    def set_fqc(self, fqc: FrameworkQuantizationCapabilities):
         """
         Set the graph's FQC.
         Args:
             fqc: FrameworkQuantizationCapabilities object.
         """
-        # validate graph nodes are either from the framework or a custom layer defined in the FQC
-        # Validate graph nodes are either built-in layers from the framework or custom layers defined in the FQC
-        fqc_layers = fqc.op_sets_to_layers.get_layers()
-        fqc_filtered_layers = [layer for layer in fqc_layers if isinstance(layer, LayerFilterParams)]
-        for n in self.nodes:
-            is_node_in_fqc = any([n.is_match_type(_type) for _type in fqc_layers]) or \
-                             any([n.is_match_filter_params(filtered_layer) for filtered_layer in fqc_filtered_layers])
-            if n.is_custom:
-                if not is_node_in_fqc:
-                    Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
-                                    ' Please add the custom layer to Framework Quantization Capabilities (FQC), or file a feature '
-                                    'request or an issue if you believe this should be supported.')  # pragma: no cover
-                if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(fqc).quantization_configurations]):
-                    Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.')  # pragma: no cover
-
+        # TODO irena: this is only passed for negative shift activation.
         self.fqc = fqc

     def get_topo_sorted_nodes(self):
@@ -578,7 +564,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             A list of nodes that their weights can be configured (namely, has one or more weight qc candidate).
         """
         # configurability is only relevant for kernel attribute quantization
-        potential_conf_nodes = [n for n in list(self) if n.is_kernel_op]
+        potential_conf_nodes = [n for n in self.nodes if n.kernel_attr]

         def is_configurable(n):
             return n.is_configurable_weight(n.kernel_attr) and (not n.reuse or include_reused_nodes)
@@ -693,10 +679,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         """
         Gets the final number of bits for quantization of each weights' configurable layer.

-        Args:
-            fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
-
-        Returns: A list of pairs of (node type, node's weights quantization bitwidth).
+        Returns:
+            A list of pairs of (node type, node's weights quantization bitwidth).

         """
         sorted_conf_weights = self.get_sorted_weights_configurable_nodes()
@@ -876,32 +860,31 @@ class Graph(nx.MultiDiGraph, GraphSearches):

         return intermediate_nodes, next_node

+    # TODO irena move to load_fqc and clean up tests (currently tests_pytest/common_tests/unit_tests/core/graph/test_base_graph.py)
     def override_fused_node_activation_quantization_candidates(self):
         """
         Override fused node activation quantization candidates for all nodes in fused operations,
         except for the last node in each fused group.
         Update the value of quantization_config with the value of op_quaitization_cfg from FusingInfo.
         """
-        from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
-
         nodes_in_fln = self.fusing_info.get_inner_fln_nodes()
         for node in nodes_in_fln:
             fused_node_op_id = self.fusing_info.get_fused_op_id_for_node(node.name)
-            fusiong_op_quaitization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
-            org_candidate = node.candidates_quantization_cfg[0]
-            if fusiong_op_quaitization_cfg is not None and fusiong_op_quaitization_cfg.enable_activation_quantization:
-                # Set ActivationQuantizationMode to FLN_QUANT and update the value of quantization_config
-                activation_quantization_cfg = NodeActivationQuantizationConfig(qc=org_candidate,
-                                                                               op_cfg=fusiong_op_quaitization_cfg,
-                                                                               activation_quantization_fn=org_candidate.activation_quantization_cfg.activation_quantization_fn,
-                                                                               activation_quantization_params_fn=org_candidate.activation_quantization_cfg.activation_quantization_params_fn)
-                activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
-                for qc in node.candidates_quantization_cfg:
-                    qc.activation_quantization_cfg = activation_quantization_cfg
+            fusing_op_quantization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
+            if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization:
+                def update(qc):
+                    qc.activation_quantization_cfg = NodeActivationQuantizationConfig(fusing_op_quantization_cfg)
+                    qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
+                node.quantization_cfg.update_all(update, remove_duplicates=True)
             else:
-                # Set ActivationQuantizationMode to FLN_NO_QUANT
+                node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.FLN_NO_QUANT)
+                # Remove duplicate candidates. We cannot compare whole candidates since activation configs might not
+                # be identical, but we do want to treat them as such. So we only check duplication by weight configs.
+                uniq_qcs = []
                 for qc in node.candidates_quantization_cfg:
-                    qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_NO_QUANT
+                    if not any(qc.weights_quantization_cfg == uqc.weights_quantization_cfg for uqc in uniq_qcs):
+                        uniq_qcs.append(qc)
+                node.quantization_cfg.candidates_quantization_cfg = uniq_qcs

     def validate(self):
         """
--- a/model_compression_toolkit/core/common/graph/base_node.py
+++ b/model_compression_toolkit/core/common/graph/base_node.py
@@ -21,15 +21,11 @@
 from model_compression_toolkit.core.common.framework_info import get_fw_info, ChannelAxisMapping
 from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \
     ACTIVATION_N_BITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
+from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import NodeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
     ActivationQuantizationMode
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
-    OpQuantizationConfig
-from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
-from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
-    FrameworkQuantizationCapabilities


 WeightAttrT = Union[str, int]
@@ -43,7 +39,6 @@ class NodeFrameworkInfo(NamedTuple):
     out_channel_axis: int
     minmax: Tuple[float, float]
     kernel_attr: str
-    is_kernel_op: bool


 class BaseNode:
@@ -95,7 +90,7 @@ class BaseNode:
         self.inputs_as_list = inputs_as_list
         self.final_weights_quantization_cfg = None
         self.final_activation_quantization_cfg = None
-        self.candidates_quantization_cfg = None
+        self.quantization_cfg: NodeQuantizationConfig = None
         self.prior_info = None
         self.has_activation = has_activation
         self.is_custom = is_custom
@@ -108,7 +103,6 @@ class BaseNode:
             fw_info.get_out_channel_axis(node_type),
             fw_info.get_layer_min_max(node_type, framework_attr),
             fw_info.get_kernel_op_attribute(node_type),
-            fw_info.is_kernel_op(node_type)
         )

     def _assert_fw_info_exists(self):
@@ -162,15 +156,9 @@ class BaseNode:
         return self.node_fw_info.kernel_attr

     @property
-    def is_kernel_op(self) -> bool:
-        """
-        Check if kernel exists for the node.
-
-        Returns:
-            Whether the node has a kernel or not.
-        """
-        self._assert_fw_info_exists()
-        return self.node_fw_info.is_kernel_op
+    def candidates_quantization_cfg(self):
+        assert self.quantization_cfg
+        return self.quantization_cfg.candidates_quantization_cfg

     @property
     def type(self):
@@ -181,15 +169,6 @@ class BaseNode:
         """
         return self.layer_class

-    def get_has_activation(self):
-        """
-        Returns has_activation attribute.
-
-        Returns: Whether the node has activation to quantize.
-
-        """
-        return self.has_activation
-
     @property
     def has_positional_weights(self):
         """
@@ -646,8 +625,9 @@ class BaseNode:
         Returns: True if the node has at list one quantization configuration candidate with activation quantization enabled.
         """

-        return len(self.candidates_quantization_cfg) > 0 and \
-               any([c.activation_quantization_cfg.enable_activation_quantization for c in self.candidates_quantization_cfg])
+        return (len(self.candidates_quantization_cfg) > 0 and
+                any([c.activation_quantization_cfg.enable_activation_quantization
+                     for c in self.candidates_quantization_cfg]))

     def get_all_weights_attr_candidates(self, attr: str) -> List[WeightsAttrQuantizationConfig]:
         """
@@ -663,79 +643,6 @@ class BaseNode:
         # the inner method would log an exception.
         return [c.weights_quantization_cfg.get_attr_config(attr) for c in self.candidates_quantization_cfg]

-    def get_qco(self, fqc: FrameworkQuantizationCapabilities) -> QuantizationConfigOptions:
-        """
-        Get the QuantizationConfigOptions of the node according
-        to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformCapabilities.
-
-        Args:
-            fqc: FQC to extract the QuantizationConfigOptions for the node.
-
-        Returns:
-            QuantizationConfigOptions of the node.
-        """
-
-        if fqc is None:
-            Logger.critical(f'Can not retrieve QC options for None FQC')  # pragma: no cover
-
-        for fl, qco in fqc.filterlayer2qco.items():
-            if self.is_match_filter_params(fl):
-                return qco
-        # Extract qco with is_match_type to overcome mismatch of function types in TF 2.15
-        matching_qcos = [_qco for _type, _qco in fqc.layer2qco.items() if self.is_match_type(_type)]
-        if matching_qcos:
-            if all([_qco == matching_qcos[0] for _qco in matching_qcos]):
-                return matching_qcos[0]
-            else:
-                Logger.critical(f"Found duplicate qco types for node '{self.name}' of type '{self.type}'!")  # pragma: no cover
-        return fqc.tpc.default_qco
-
-    def filter_node_qco_by_graph(self, fqc: FrameworkQuantizationCapabilities,
-                                 next_nodes: List, node_qc_options: QuantizationConfigOptions
-                                 ) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
-        """
-        Filter quantization config options that don't match the graph.
-        A node may have several quantization config options with 'activation_n_bits' values, and
-        the next nodes in the graph may support different bit-width as input activation. This function
-        filters out quantization config that don't comply to these attributes.
-
-        Args:
-            fqc: FQC to extract the QuantizationConfigOptions for the next nodes.
-            next_nodes: Output nodes of current node.
-            node_qc_options: Node's QuantizationConfigOptions.
-
-        Returns:
-
-        """
-        # Filter quantization config options that don't match the graph.
-        _base_config = node_qc_options.base_config
-        _node_qc_options = node_qc_options.quantization_configurations
-        if len(next_nodes):
-            next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
-            next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
-                                                       for qc_opts in next_nodes_qc_options
-                                                       for op_cfg in qc_opts.quantization_configurations])
-
-            # Filter node's QC options that match next nodes input bit-width.
-            _node_qc_options = [_option for _option in _node_qc_options
-                                if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
-            if len(_node_qc_options) == 0:
-                Logger.critical(f"Graph doesn't match FQC bit configurations: {self} -> {next_nodes}.")  # pragma: no cover
-
-            # Verify base config match
-            if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
-                    for qc_opt in next_nodes_qc_options]):
-                # base_config activation bits doesn't match next node supported input bit-width -> replace with
-                # a qco from quantization_configurations with maximum activation bit-width.
-                if len(_node_qc_options) > 0:
-                    output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
-                    _base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
-                    Logger.warning(f"Node {self} base quantization config changed to match Graph and FQC configuration.\nCause: {self} -> {next_nodes}.")
-                else:
-                    Logger.critical(f"Graph doesn't match FQC bit configurations: {self} -> {next_nodes}.")  # pragma: no cover
-
-        return _base_config, _node_qc_options
-
     def is_match_type(self, _type: Type) -> bool:
         """
         Check if input type matches the node type, either in instance type or in type name.
@@ -768,7 +675,7 @@ class BaseNode:
             return False

         # Get attributes from node to filter
-        layer_config = self.framework_attr
+        layer_config = self.framework_attr.copy()
         if hasattr(self, "op_call_kwargs"):
             layer_config.update(self.op_call_kwargs)

@@ -812,11 +719,11 @@ class BaseNode:
        the candidates in descending order.
        The operation is done inplace.
        """
-        if self.candidates_quantization_cfg is not None:
+        if self.quantization_cfg.candidates_quantization_cfg is not None:
            if self.kernel_attr is not None:
-                self.candidates_quantization_cfg.sort(
+                self.quantization_cfg.candidates_quantization_cfg.sort(
                    key=lambda c: (c.weights_quantization_cfg.get_attr_config(self.kernel_attr).weights_n_bits,
                                   c.activation_quantization_cfg.activation_n_bits), reverse=True)
            else:
-                self.candidates_quantization_cfg.sort(key=lambda c: c.activation_quantization_cfg.activation_n_bits,
-                                                      reverse=True)
+                self.quantization_cfg.candidates_quantization_cfg.sort(
+                    key=lambda c: c.activation_quantization_cfg.activation_n_bits, reverse=True)
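Note that base_node.py softens the removal of the stored candidates_quantization_cfg attribute by re-exposing it as a read-only property over the new quantization_cfg container, so existing call sites keep working. A minimal sketch of this delegation pattern (class names here are stand-ins):

class QuantCfgContainer:
    # Hypothetical stand-in for NodeQuantizationConfig.
    def __init__(self, candidates):
        self.candidates_quantization_cfg = candidates

class Node:
    def __init__(self):
        # Populated later (by load_fqc in the new flow).
        self.quantization_cfg = None

    @property
    def candidates_quantization_cfg(self):
        # Legacy attribute access keeps working, but now fails fast if the
        # container was never populated.
        assert self.quantization_cfg
        return self.quantization_cfg.candidates_quantization_cfg

n = Node()
n.quantization_cfg = QuantCfgContainer(["qc_a", "qc_b"])
assert n.candidates_quantization_cfg == ["qc_a", "qc_b"]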
--- a/model_compression_toolkit/core/common/graph/functional_node.py
+++ b/model_compression_toolkit/core/common/graph/functional_node.py
@@ -103,4 +103,4 @@ class FunctionalNode(BaseNode):

        """
        names_match = _type.__name__ == self.type.__name__
-        return super().is_match_type(_type) or names_match
+        return names_match or super().is_match_type(_type)
--- a/model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py
+++ b/model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py
@@ -19,9 +19,8 @@
     VIRTUAL_WEIGHTS_SUFFIX, VIRTUAL_ACTIVATION_SUFFIX, FLOAT_BITWIDTH
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
-    CandidateNodeQuantizationConfig
+    CandidateNodeQuantizationConfig, NodeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
-from model_compression_toolkit.core.common.framework_info import DEFAULT_KERNEL_ATTRIBUTE


 class VirtualNode(BaseNode, abc.ABC):
@@ -76,8 +75,11 @@ class VirtualSplitWeightsNode(VirtualSplitNode):

         self.name = origin_node.name + VIRTUAL_WEIGHTS_SUFFIX

-        self.candidates_quantization_cfg = origin_node.get_unique_weights_candidates(kernel_attr)
-        for c in self.candidates_quantization_cfg:
+        self.quantization_cfg = NodeQuantizationConfig(
+            candidates_quantization_cfg=origin_node.get_unique_weights_candidates(kernel_attr),
+            base_quantization_cfg=None, validate=False
+        )
+        for c in self.quantization_cfg.candidates_quantization_cfg:
             c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
             c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH

@@ -106,10 +108,9 @@ class VirtualSplitActivationNode(VirtualSplitNode):
         self.weights = {}
         self.layer_class = activation_class

-        self.candidates_quantization_cfg = origin_node.get_unique_activation_candidates()
-        for c in self.candidates_quantization_cfg:
-            c.weights_quantization_cfg.enable_weights_quantization = False
-            c.weights_quantization_cfg.weights_n_bits = FLOAT_BITWIDTH
+        self.quantization_cfg = NodeQuantizationConfig(candidates_quantization_cfg=origin_node.get_unique_activation_candidates(),
+                                                       base_quantization_cfg=None, validate=False)
+        self.quantization_cfg.disable_weights_quantization()


 class VirtualActivationWeightsNode(VirtualNode):
@@ -143,7 +144,7 @@ class VirtualActivationWeightsNode(VirtualNode):
         weights = weights_node.weights.copy()
         act_node_w_rename = {}
         if act_node.weights:
-            if act_node.kernel_attr != DEFAULT_KERNEL_ATTRIBUTE:
+            if act_node.kernel_attr:
                 raise NotImplementedError(f'Node {act_node} with kernel cannot be used as activation for '
                                           f'VirtualActivationWeightsNode.')
             if act_node.has_any_configurable_weight():
@@ -200,4 +201,5 @@ class VirtualActivationWeightsNode(VirtualNode):
         v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(weights_node.kernel_attr).weights_n_bits,
                                          c.activation_quantization_cfg.activation_n_bits), reverse=True)

-        self.candidates_quantization_cfg = v_candidates
+        self.quantization_cfg = NodeQuantizationConfig(candidates_quantization_cfg=v_candidates,
+                                                       base_quantization_cfg=None, validate=False)
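All three virtual node types above construct their NodeQuantizationConfig with base_quantization_cfg=None and validate=False, which suggests these synthetic mixed-precision nodes carry candidates only and opt out of whatever validation regular nodes get. A sketch of why such a flag is useful (stand-in class; the real validation semantics are not shown in this diff and are assumed here):

from typing import List, Optional

class QuantCfgContainer:
    # Hypothetical stand-in for NodeQuantizationConfig with assumed semantics.
    def __init__(self, candidates_quantization_cfg: List[str],
                 base_quantization_cfg: Optional[str] = None,
                 validate: bool = True):
        if validate and base_quantization_cfg is None:
            raise ValueError('base_quantization_cfg is required for regular nodes')
        self.base_quantization_cfg = base_quantization_cfg
        self.candidates_quantization_cfg = candidates_quantization_cfg

# A regular node without a base config would be rejected...
try:
    QuantCfgContainer(['qc'], base_quantization_cfg=None)
except ValueError:
    pass

# ...while virtual nodes opt out of validation, as in the validate=False calls above.
virtual_cfg = QuantCfgContainer(['qc'], base_quantization_cfg=None, validate=False)
assert virtual_cfg.base_quantization_cfg is None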
--- a/model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py
+++ b/model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py
@@ -18,6 +18,8 @@

 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
+from model_compression_toolkit.core.common.quantization.quantization_fn_selection import (get_activation_quantization_fn,
+                                                                                          get_weights_quantization_fn)


 def verify_candidates_descending_order(node_q_cfg: List[CandidateNodeQuantizationConfig],
@@ -77,20 +79,21 @@ def init_quantized_weights(node_q_cfg: List[CandidateNodeQuantizationConfig],
     quantized_weights = []
     for qc in node_q_cfg:
         qc_weights_attr = qc.weights_quantization_cfg.get_attr_config(kernel_attr)
-        q_weight = qc_weights_attr.weights_quantization_fn(float_weights,
-                                                           qc_weights_attr.weights_n_bits,
-                                                           True,
-                                                           qc_weights_attr.weights_quantization_params,
-                                                           qc_weights_attr.weights_per_channel_threshold,
-                                                           qc_weights_attr.weights_channels_axis[
-                                                               0])  # output channel axis
+        weights_quantization_fn = get_weights_quantization_fn(qc_weights_attr.weights_quantization_method)
+        q_weight = weights_quantization_fn(float_weights,
+                                           qc_weights_attr.weights_n_bits,
+                                           True,
+                                           qc_weights_attr.weights_quantization_params,
+                                           qc_weights_attr.weights_per_channel_threshold,
+                                           qc_weights_attr.weights_channels_axis[0])  # output channel axis

         quantized_weights.append(fw_tensor_convert_func(q_weight))

     return quantized_weights


-def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]) -> List:
+def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig],
+                               get_activation_quantization_fn_factory: Callable) -> List:
     """
     Builds a list of quantizers for each of the bitwidth candidates for activation quantization,
     to be stored and used during MP search.
@@ -98,6 +101,7 @@ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]
     Args:
         node_q_cfg: Quantization configuration candidates of the node that generated the layer that will
             use this quantizer.
+        get_activation_quantization_fn_factory: activation quantization functions factory.

     Returns: a list of activation quantizers - for each bitwidth and layer's attribute to be quantized.
     """
@@ -105,6 +109,7 @@ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]
     activation_quantizers = []
     for index, qc in enumerate(node_q_cfg):
         q_activation = node_q_cfg[index].activation_quantization_cfg
-        activation_quantizers.append(q_activation.quantize_node_output)
+        quantizer = get_activation_quantization_fn(q_activation, get_activation_quantization_fn_factory)
+        activation_quantizers.append(quantizer)

     return activation_quantizers
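The pattern in this file (and in the new keras/pytorch activation_quantization_fn_factory modules listed above) is to resolve quantization functions from the quantization method rather than reading a function object stored on the config. A minimal registry sketch of that lookup style (all names below are stand-ins, not MCT's actual enums or quantizers):

from enum import Enum
from typing import Callable, Dict

class QuantizationMethod(Enum):
    # Stand-in for a quantization method enum.
    POWER_OF_TWO = 'pot'
    SYMMETRIC = 'symmetric'

def _pot_weights_quantizer(w, n_bits, *args):
    return f'pot({w}, {n_bits})'

def _symmetric_weights_quantizer(w, n_bits, *args):
    return f'symmetric({w}, {n_bits})'

_WEIGHTS_QUANT_FNS: Dict[QuantizationMethod, Callable] = {
    QuantizationMethod.POWER_OF_TWO: _pot_weights_quantizer,
    QuantizationMethod.SYMMETRIC: _symmetric_weights_quantizer,
}

def get_weights_quantization_fn(method: QuantizationMethod) -> Callable:
    # Resolve the quantizer from the method, instead of reading a function
    # object stored on the config itself.
    return _WEIGHTS_QUANT_FNS[method]

fn = get_weights_quantization_fn(QuantizationMethod.SYMMETRIC)
assert fn('w', 8) == 'symmetric(w, 8)'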
--- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py
+++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py
@@ -12,17 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import numpy as np
-
-from model_compression_toolkit.core import ResourceUtilization, FrameworkInfo
+from model_compression_toolkit.core import ResourceUtilization
 from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
-    FrameworkQuantizationCapabilities


 def filter_candidates_for_mixed_precision(graph: Graph,
-                                          target_resource_utilization: ResourceUtilization,
-                                          fqc: FrameworkQuantizationCapabilities):
+                                          target_resource_utilization: ResourceUtilization):
     """
     Filters out candidates in case of mixed precision search for only weights or activation compression.
     For instance, if running only weights compression - filters out candidates of activation configurable nodes
@@ -34,8 +29,6 @@ def filter_candidates_for_mixed_precision(graph: Graph,
     Args:
         graph: A graph representation of the model to be quantized.
         target_resource_utilization: The resource utilization of the target device.
-        fqc: FrameworkQuantizationCapabilities object that describes the desired inference target platform.
-
     """

     tru = target_resource_utilization
@@ -47,20 +40,21 @@ def filter_candidates_for_mixed_precision(graph: Graph,
         # filter out candidates activation only configurable node
         activation_configurable_nodes = [n for n in graph.get_activation_configurable_nodes()]
         for n in activation_configurable_nodes:
-            base_cfg_nbits = n.get_qco(fqc).base_config.activation_n_bits
-            filtered_conf = [c for c in n.candidates_quantization_cfg if
+            base_cfg_nbits = n.quantization_cfg.base_quantization_cfg.activation_quantization_cfg.activation_n_bits
+            filtered_cfgs = [c for c in n.candidates_quantization_cfg if
                              c.activation_quantization_cfg.enable_activation_quantization and
                              c.activation_quantization_cfg.activation_n_bits == base_cfg_nbits]

-            n.candidates_quantization_cfg = filtered_conf
+            n.quantization_cfg.candidates_quantization_cfg = filtered_cfgs

     elif tru.activation_restricted() and not tru.weight_restricted():
         # Running mixed precision for activation compression only -
         # filter out candidates weights only configurable node
         weight_configurable_nodes = [n for n in graph.get_weights_configurable_nodes()]
         for n in weight_configurable_nodes:
-            base_cfg_nbits = n.get_qco(fqc).base_config.attr_weights_configs_mapping[n.kernel_attr].weights_n_bits
-            filtered_conf = [c for c in n.candidates_quantization_cfg if
+            base_cfg_nbits = (n.quantization_cfg.base_quantization_cfg.weights_quantization_cfg.
+                              get_attr_config(n.kernel_attr).weights_n_bits)
+            filtered_cfgs = [c for c in n.candidates_quantization_cfg if
                              c.weights_quantization_cfg.get_attr_config(n.kernel_attr).enable_weights_quantization and
                              c.weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_n_bits == base_cfg_nbits]
-            n.candidates_quantization_cfg = filtered_conf
+            n.quantization_cfg.candidates_quantization_cfg = filtered_cfgs
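This filter now reads the reference bit-width from the node's own base_quantization_cfg instead of re-querying the FQC, which is why the fqc parameter could be dropped. A compact sketch of the activation-side filter over stand-in objects:

from types import SimpleNamespace as NS

def filter_activation_candidates(node):
    # Keep only enabled candidates whose activation bit-width equals the
    # bit-width of the node's own base config.
    base_nbits = node.quantization_cfg.base_quantization_cfg.activation_quantization_cfg.activation_n_bits
    return [c for c in node.candidates_quantization_cfg
            if c.activation_quantization_cfg.enable_activation_quantization
            and c.activation_quantization_cfg.activation_n_bits == base_nbits]

act = lambda nbits, enabled=True: NS(activation_quantization_cfg=NS(
    activation_n_bits=nbits, enable_activation_quantization=enabled))
node = NS(quantization_cfg=NS(base_quantization_cfg=act(8)),
          candidates_quantization_cfg=[act(8), act(4), act(8, enabled=False)])
assert [c.activation_quantization_cfg.activation_n_bits
        for c in filter_activation_candidates(node)] == [8]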
--- a/model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py
+++ b/model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py
@@ -392,9 +392,8 @@ class DistanceMetricCalculator(MetricCalculator):
         """

         return [n.node for n in graph.get_outputs()
-                if (n.node.is_kernel_op and
-                    n.node.is_weights_quantization_enabled(n.node.kernel_attr)) or
-                n.node.is_activation_quantization_enabled()]
+                if (n.node.kernel_attr and n.node.is_weights_quantization_enabled(n.node.kernel_attr))
+                or n.node.is_activation_quantization_enabled()]

     @staticmethod
     def bound_num_interest_points(sorted_ip_list: List[BaseNode], num_ip_factor: float) -> List[BaseNode]:
--- a/model_compression_toolkit/core/common/network_editors/__init__.py
+++ b/model_compression_toolkit/core/common/network_editors/__init__.py
@@ -13,7 +13,14 @@
 # limitations under the License.
 # ==============================================================================

-from model_compression_toolkit.core.common.network_editors.actions import ChangeCandidatesWeightsQuantConfigAttr, ChangeFinalWeightsQuantConfigAttr, ChangeCandidatesActivationQuantConfigAttr, ChangeQuantizationParamFunction, ChangeCandidatesActivationQuantizationMethod, ChangeFinalWeightsQuantizationMethod, ChangeCandidatesWeightsQuantizationMethod, ChangeFinalActivationQuantConfigAttr
+from model_compression_toolkit.core.common.network_editors.actions import (
+    ChangeCandidatesWeightsQuantConfigAttr,
+    ChangeFinalWeightsQuantConfigAttr,
+    ChangeCandidatesActivationQuantConfigAttr,
+    ChangeCandidatesActivationQuantizationMethod,
+    ChangeFinalWeightsQuantizationMethod,
+    ChangeCandidatesWeightsQuantizationMethod,
+    ChangeFinalActivationQuantConfigAttr)
 from model_compression_toolkit.core.common.network_editors.actions import EditRule
 from model_compression_toolkit.core.common.network_editors.node_filters import NodeTypeFilter, NodeNameScopeFilter, \
     NodeNameFilter