mct-nightly 2.4.0.20250630.629__py3-none-any.whl → 2.4.0.20250701.185106__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/METADATA +16 -16
  2. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/RECORD +75 -72
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
  5. model_compression_toolkit/core/common/framework_info.py +5 -32
  6. model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
  7. model_compression_toolkit/core/common/graph/base_graph.py +20 -37
  8. model_compression_toolkit/core/common/graph/base_node.py +13 -106
  9. model_compression_toolkit/core/common/graph/functional_node.py +1 -1
  10. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
  11. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
  12. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
  13. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
  14. model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
  15. model_compression_toolkit/core/common/network_editors/actions.py +4 -96
  16. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  17. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
  18. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
  19. model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
  20. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
  21. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
  22. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
  23. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
  24. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
  25. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  26. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
  27. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
  28. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
  29. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
  30. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
  31. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
  32. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
  33. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
  34. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
  35. model_compression_toolkit/core/graph_prep_runner.py +31 -20
  36. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
  37. model_compression_toolkit/core/keras/default_framework_info.py +0 -11
  38. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
  39. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
  40. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
  41. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
  42. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
  43. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
  44. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
  45. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
  46. model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
  47. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  48. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
  49. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
  50. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
  51. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
  52. model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
  53. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
  54. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
  55. model_compression_toolkit/core/runner.py +1 -1
  56. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
  57. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  58. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
  59. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
  60. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
  61. model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
  62. model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
  63. model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
  64. model_compression_toolkit/quantization_preparation/__init__.py +14 -0
  65. model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
  66. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  67. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
  68. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/WHEEL +0 -0
  69. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/licenses/LICENSE.md +0 -0
  70. {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/top_level.txt +0 -0
  71. /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
  72. /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
  73. /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
  74. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
  75. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
  76. /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
@@ -12,73 +12,22 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- import copy
16
- from typing import List, Tuple, Dict, Optional
15
+ from typing import List, Tuple, Dict
17
16
 
18
- from mct_quantizers.common.constants import WEIGHTS_N_BITS, ACTIVATION_N_BITS
19
- from model_compression_toolkit.constants import WEIGHTS, ACTIVATION
20
17
  from model_compression_toolkit.core.common import BaseNode
18
+ from model_compression_toolkit.core.common.graph.base_graph import Graph
21
19
  from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
20
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
22
21
  from model_compression_toolkit.logger import Logger
23
- from model_compression_toolkit.core.common.framework_info import get_fw_info, ChannelAxisMapping
24
- from model_compression_toolkit.core.common.graph.base_graph import Graph
25
- from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
26
- CandidateNodeQuantizationConfig
27
- from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig, \
28
- ActivationQuantizationMode
29
- from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
30
- QuantizationErrorMethod
31
- from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
32
- get_activation_quantization_params_fn
33
- from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
22
+ from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
34
23
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
35
24
  QuantizationConfigOptions
25
+ from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
36
26
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
37
27
  FrameworkQuantizationCapabilities
38
- from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
39
-
40
-
41
- def set_quantization_configuration_to_graph(graph: Graph,
42
- quant_config: QuantizationConfig,
43
- bit_width_config: BitWidthConfig = None,
44
- mixed_precision_enable: bool = False,
45
- running_gptq: bool = False) -> Graph:
46
- """
47
- Add quantization configuration for each graph node.
48
-
49
- Args:
50
- graph (Graph): Graph for which to add quantization info to each node.
51
- quant_config (QuantizationConfig): Quantization configuration containing parameters for how the graph should be quantized.
52
- bit_width_config (BitWidthConfig): Configuration for manual bit width selection. Defaults to None.
53
- mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
54
- running_gptq (bool): Whether or not a GPTQ optimization is planned to run after the PTQ process. Defaults to False.
55
-
56
- Returns:
57
- Graph: The graph with quantization configurations attached to each node in it.
58
- """
59
-
60
- if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
61
- if not running_gptq:
62
- raise ValueError(f"The HMSE error method for parameters selection is only supported when running GPTQ "
63
- f"optimization due to long execution time that is not suitable for basic PTQ.")
64
- Logger.warning("Using the HMSE error method for weights quantization parameters search. "
65
- "Note: This method may significantly increase runtime during the parameter search process.")
66
-
67
- nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
68
- nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_weights_bit_widths(graph)
69
-
70
- for n in graph.get_topo_sorted_nodes():
71
- manual_bit_width_override = {ACTIVATION: nodes_to_manipulate_activation_bit_widths.get(n),
72
- WEIGHTS: nodes_to_manipulate_weights_bit_widths.get(n)}
73
- set_quantization_configs_to_node(node=n,
74
- graph=graph,
75
- quant_config=quant_config,
76
- fqc=graph.fqc,
77
- mixed_precision_enable=mixed_precision_enable,
78
- manual_bit_width_override=manual_bit_width_override)
79
- return graph
80
28
 
81
29
 
30
+ # TODO irena refactor (if needed) and move to load_fqc
82
31
  def filter_node_qco_by_graph(node: BaseNode,
83
32
  fqc: FrameworkQuantizationCapabilities,
84
33
  graph: Graph,
@@ -101,6 +50,8 @@ def filter_node_qco_by_graph(node: BaseNode,
101
50
  that are compatible with next nodes supported input bit-widths.
102
51
 
103
52
  """
53
+ from model_compression_toolkit.quantization_preparation.load_fqc import fetch_qc_options_for_node
54
+
104
55
  # Filter quantization config options that don't match the graph.
105
56
  _base_config = node_qc_options.base_config
106
57
  _node_qc_options = node_qc_options.quantization_configurations
@@ -110,7 +61,7 @@ def filter_node_qco_by_graph(node: BaseNode,
110
61
  next_nodes = []
111
62
  while len(_next_nodes):
112
63
  n = _next_nodes.pop(0)
113
- qco = n.get_qco(fqc)
64
+ qco = fetch_qc_options_for_node(n, fqc)
114
65
  qp = [qc.quantization_preserving for qc in qco.quantization_configurations]
115
66
  if not all(qp) and any(qp):
116
67
  Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizaionConfigOptions in {n}.')
@@ -120,7 +71,8 @@ def filter_node_qco_by_graph(node: BaseNode,
120
71
 
121
72
  if len(next_nodes) == 0:
122
73
  return _base_config, _node_qc_options
123
- next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
74
+
75
+ next_nodes_qc_options = [fetch_qc_options_for_node(_node, fqc) for _node in next_nodes]
124
76
  all_next_nodes_supported_input_bitwidth = [max_input_activation_n_bits(op_cfg)
125
77
  for qc_opts in next_nodes_qc_options
126
78
  for op_cfg in qc_opts.quantization_configurations
@@ -150,355 +102,98 @@ def filter_node_qco_by_graph(node: BaseNode,
150
102
  return _base_config, _node_qc_options
151
103
 
152
104
 
153
- def set_quantization_configs_to_node(node: BaseNode,
154
- graph: Graph,
155
- quant_config: QuantizationConfig,
156
- fqc: FrameworkQuantizationCapabilities,
157
- mixed_precision_enable: bool = False,
158
- manual_bit_width_override: Optional[Dict] = None):
105
+ def set_manual_bitwidth_config(graph, bit_width_config: BitWidthConfig):
159
106
  """
160
- Create and set quantization configurations to a node (for both weights and activation).
107
+ Filters candidates per manual bit-width config.
161
108
 
162
109
  Args:
163
- node (BaseNode): Node to set its quantization configurations.
164
- graph (Graph): Model's internal representation graph.
165
- quant_config (QuantizationConfig): Quantization configuration to generate the node's configurations from.
166
- fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to get default OpQuantizationConfig.
167
- mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
168
- manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
110
+ graph: graph after candidates have been set on nodes.
111
+ bit_width_config: bit-width config.
169
112
  """
170
- node_qc_options = node.get_qco(fqc)
171
- base_config, node_qc_options_list = filter_node_qco_by_graph(node, fqc, graph, node_qc_options)
113
+ manual_activation_bitwidths = bit_width_config.get_nodes_activation_bit_widths(graph)
114
+ manual_weights_bitwidths = bit_width_config.get_nodes_weights_bit_widths(graph)
172
115
 
173
- # If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation and weights bits equal to manual_bit_width_override,
174
- # and update base_config accordingly.
175
- if manual_bit_width_override is None:
176
- manual_bit_width_override = {ACTIVATION: None, WEIGHTS: None}
177
-
178
- base_config, node_qc_options_list = filter_qc_options_with_manual_bit_width(
179
- node=node,
180
- node_qc_options_list=node_qc_options_list,
181
- base_config=base_config,
182
- manual_bit_width_override=manual_bit_width_override,
183
- mixed_precision_enable=mixed_precision_enable)
116
+ if manual_activation_bitwidths:
117
+ _set_manual_activation_bitwidths(manual_activation_bitwidths)
184
118
 
185
- # Create QC candidates for weights and activation combined
186
- node.candidates_quantization_cfg = _create_node_candidates_qc(quant_config,
187
- node.channel_axis,
188
- node_qc_options_list,
189
- base_config,
190
- node,
191
- mixed_precision_enable=mixed_precision_enable)
119
+ if manual_weights_bitwidths:
120
+ _set_manual_weights_bitwidths(manual_weights_bitwidths)
192
121
 
193
- # sorting the candidates by kernel attribute weights number of bits first and then by activation number of bits
194
- # (in reversed order). since only kernel attribute is quantized in weights mixed precision,
195
- # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
196
- node.sort_node_candidates()
197
122
 
198
- for candidate_qc in node.candidates_quantization_cfg:
199
- if candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.QUANT and \
200
- not node.get_has_activation():
201
- candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
202
- elif candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.PRESERVE_QUANT:
203
- prev_nodes = graph.get_prev_nodes(node)
204
- if len(prev_nodes) != 1:
205
- # Preserving the quantization of more than 1 previous node is ambiguous, so disable it.
206
- Logger.info(f"Disabling Quantization-Preserving for node {node.name} because it has more than 1 input activations.")
207
- candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
208
- elif not prev_nodes[0].is_quantization_preserving() and not prev_nodes[0].is_activation_quantization_enabled():
209
- # Preserving the quantization of an unquantized node isn't possible, so disable it.
210
- Logger.info(f"Disabling Quantization-Preserving for node {node.name} because previous node activation quantization is disabled.")
211
- candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
212
-
213
-
214
- def create_node_activation_qc(qc: QuantizationConfig,
215
- op_cfg: OpQuantizationConfig) -> NodeActivationQuantizationConfig:
123
+ # TODO irena: check coverage and add missing tests
124
+ def _set_manual_activation_bitwidths(manual_activation_bitwidths: Dict[BaseNode, int]):
216
125
  """
217
- Create an activation quantization configuration from a QuantizationConfig object.
126
+ Filters out candidates that don't match the requested manual activation bitwidths, and updates the
127
+ activation bitwidth in the base quantization config.
218
128
 
219
129
  Args:
220
- qc: QuantizationConfig to create the node's config from.
221
- weights/activations should be quantized)
222
- op_cfg: OpQuantizationConfig with quantizers types to set in node quantization configuration.
223
-
224
- Returns:
225
- Activation quantization configuration of a node.
226
- """
227
-
228
- activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
229
- if activation_quantization_fn is None:
230
- Logger.critical('Unknown activation quantization method specified.') # pragma: no cover
231
-
232
- activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
233
-
234
- return NodeActivationQuantizationConfig(qc,
235
- op_cfg,
236
- activation_quantization_fn,
237
- activation_quantization_params_fn)
238
-
239
-
240
- def _create_node_single_candidate_qc(qc: QuantizationConfig,
241
- weight_channel_axis: ChannelAxisMapping,
242
- op_cfg: OpQuantizationConfig,
243
- node_attrs_list: List[str]) -> CandidateNodeQuantizationConfig:
244
- """
245
- Create quantization configuration candidate from a QuantizationConfig object.
246
- Creates both weights and activation quantization configurations
247
- and initialize a candidate object that encapsulates both.
130
+ manual_activation_bitwidths: nodes' manual activation bitwidth.
131
+
132
+ Raises:
133
+ ValueError: if the manual bitwidth is requested for un-quantized node.
134
+ if the manual bitwidth is not compatible with any candidate.
135
+ """
136
+ for n, a_nbits in manual_activation_bitwidths.items():
137
+ quant_mode = n.quantization_cfg.get_activation_quant_mode()
138
+ # TODO irena: for FLN I think it should be ignored with warning for layer filter, and error for name filter
139
+ if quant_mode != ActivationQuantizationMode.QUANT:
140
+ raise ValueError(f'Cannot apply manual activation bit-width for node {n} with activation quantization mode'
141
+ f'{quant_mode}, as it does not have its own quantization configuration.')
142
+ candidates = [qc for qc in n.candidates_quantization_cfg
143
+ if qc.activation_quantization_cfg.activation_n_bits == a_nbits]
144
+ if not candidates:
145
+ valid_nbits = sorted(list({qc.activation_quantization_cfg.activation_n_bits
146
+ for qc in n.candidates_quantization_cfg}))
147
+ raise ValueError(
148
+ f'Manually selected activation bit-width {a_nbits} is invalid for node {n}. '
149
+ f'Valid bit-widths: {valid_nbits}.')
150
+ n.quantization_cfg.candidates_quantization_cfg = candidates
151
+ n.quantization_cfg.base_quantization_cfg.activation_quantization_cfg.activation_n_bits = a_nbits
152
+
153
+
154
+ # TODO irena: check coverage
155
+ def _set_manual_weights_bitwidths(manual_weights_bitwidths: Dict[BaseNode, Dict[str, int]]):
156
+ """
157
+ Filters out candidates that don't match the requested weight attributes manual bitwidths, and updates the bitwidths
158
+ in the base quantization config.
248
159
 
249
160
  Args:
250
- qc: QuantizationConfig to create the node's config from.
251
- weight_channel_axis: (Output, Input) channel index of the node's kernel.
252
- op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
253
- node_attrs_list: A list of the node's weights attributes names.
254
-
255
- Returns: a CandidateNodeQuantizationConfig object with both weights and activation quantization config objects.
256
-
257
- """
258
-
259
- # parameters for weights attributes quantization are set within CandidateNodeQuantizationConfig initialization
260
-
261
- # get parameters for activation quantization
262
- activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
263
- if activation_quantization_fn is None:
264
- Logger.critical('Unknown activation quantization method specified.') # pragma: no cover
265
-
266
- activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
267
-
268
- # TODO: remove this validation and warning once enabling all attributes quantization by default
269
- attrs_with_enabled_quantization = [attr for attr, cfg in op_cfg.attr_weights_configs_mapping.items()
270
- if cfg.enable_weights_quantization]
271
- if len(attrs_with_enabled_quantization) > 1:
272
- Logger.warning(f"Multiple weights attributes quantization is enabled via the provided FQC."
273
- f"Quantizing any attribute other than the kernel is experimental "
274
- f"and may be subject to unstable behavior."
275
- f"Attributes with enabled weights quantization: {attrs_with_enabled_quantization}.")
276
-
277
- return CandidateNodeQuantizationConfig(qc=qc,
278
- op_cfg=op_cfg,
279
- activation_quantization_fn=activation_quantization_fn,
280
- activation_quantization_params_fn=activation_quantization_params_fn,
281
- weights_channels_axis=weight_channel_axis,
282
- node_attrs_list=node_attrs_list)
283
-
284
-
285
- def _create_node_candidates_qc(qc: QuantizationConfig,
286
- weight_channel_axis: ChannelAxisMapping,
287
- node_qc_options_list: List[OpQuantizationConfig],
288
- base_config: OpQuantizationConfig,
289
- node: BaseNode,
290
- mixed_precision_enable: bool = False) -> List[CandidateNodeQuantizationConfig]:
291
- """
292
- Create a list of candidates of weights and activation quantization configurations for a node.
293
-
294
- Args:
295
- qc (QuantizationConfig): Quantization configuration the quantization process should follow.
296
- weight_channel_axis (ChannelAxisMapping): (Output, Input) channel index of the node's kernel.
297
- node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs of node.
298
- base_config (OpQuantizationConfig): Base quantization config for node.
299
- node (BaseNode): A node to set quantization configuration candidates to.
300
- mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
301
-
302
- Returns:
303
- List[CandidateNodeQuantizationConfig]: List of candidates of weights quantization configurations to set for a node.
304
- """
305
-
306
- candidates = []
307
- node_attrs_list = node.get_node_weights_attributes()
308
-
309
- if mixed_precision_enable:
310
- for op_cfg in node_qc_options_list:
311
- candidate_qc = copy.deepcopy(qc)
312
- candidates.append(_create_node_single_candidate_qc(candidate_qc,
313
- weight_channel_axis,
314
- op_cfg,
315
- node_attrs_list))
316
-
317
- else:
318
- candidates.append(_create_node_single_candidate_qc(qc,
319
- weight_channel_axis,
320
- base_config,
321
- node_attrs_list))
322
-
323
- return candidates
324
-
325
-
326
- def filter_qc_options_with_manual_bit_width(
327
- node: BaseNode,
328
- node_qc_options_list: List[OpQuantizationConfig],
329
- base_config: OpQuantizationConfig,
330
- manual_bit_width_override: Optional[Dict],
331
- mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
332
- """
333
- Update the quantization configurations for a node, allowing manual bit-width overrides if specified.
334
-
335
- Args:
336
- node (BaseNode): A node to set quantization configuration candidates to.
337
- node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
338
- base_config (OpQuantizationConfig): Base quantization config for the node.
339
- manual_bit_width_override (Optional[Dict]): Specifies a custom bit-width to override the node's activation and weights bit-width.
340
- mixed_precision_enable (bool): Whether mixed precision is enabled.
341
-
342
- Returns:
343
- Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
344
- """
345
- base_config, node_qc_options_list = filter_activation_qc_options_with_manual_bit_width(node,
346
- node_qc_options_list,
347
- base_config,
348
- manual_bit_width_override.get(ACTIVATION),
349
- mixed_precision_enable)
350
-
351
- base_config, node_qc_options_list = filter_weights_qc_options_with_manual_bit_width(node,
352
- node_qc_options_list,
353
- base_config,
354
- manual_bit_width_override.get(WEIGHTS),
355
- mixed_precision_enable)
356
- return base_config, node_qc_options_list
357
-
358
-
359
- def filter_activation_qc_options_with_manual_bit_width(
360
- node: BaseNode,
361
- node_qc_options_list: List[OpQuantizationConfig],
362
- base_config: OpQuantizationConfig,
363
- activation_manual_bit_width_override: Optional[int],
364
- mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
365
- """
366
- Update the activation quantization configurations for a node, allowing manual bit-width overrides if specified.
367
-
368
- Args:
369
- node (BaseNode): A node to set quantization configuration candidates to.
370
- node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
371
- base_config (OpQuantizationConfig): Base quantization config for the node.
372
- activation_manual_bit_width_override (Optional[Dict]): Specifies a custom bit-width to override the node's activation bit-width.
373
- mixed_precision_enable (bool): Whether mixed precision is enabled.
374
-
375
- Returns:
376
- Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
377
- """
378
- if activation_manual_bit_width_override is None:
379
- return base_config, node_qc_options_list
380
-
381
- # Filter node_qc_options_list to retain only the options with activation bits equal to activation_manual_bit_width_override.
382
- node_qc_options_list = [op_cfg for op_cfg in node_qc_options_list if
383
- activation_manual_bit_width_override == op_cfg.activation_n_bits]
384
- if len(node_qc_options_list) == 0:
385
- Logger.critical(f"Manually selected activation bit-width {activation_manual_bit_width_override} is invalid for node {node}.")
386
- else:
387
- # Update the base_config to one of the values from the filtered node_qc_options_list.
388
- # First, check if a configuration similar to the original base_config but with activation bits equal to activation_manual_bit_width_override exists.
389
- # If it does, use it as the base_config. If not, choose a different configuration from node_qc_options_list.
390
- Logger.info(f"Setting node {node} bit-width to manually selected bit-width: {activation_manual_bit_width_override} bits.")
391
- updated_base_config = base_config.clone_and_edit({ACTIVATION_N_BITS, activation_manual_bit_width_override})
392
- if updated_base_config in node_qc_options_list:
393
- # If a base_config with the specified activation_manual_bit_width_override exists in the node_qc_options_list,
394
- # point the base_config to this option.
395
- base_config = node_qc_options_list[node_qc_options_list.index(updated_base_config)]
396
- else:
397
- # Choose a different configuration from node_qc_options_list. If multiple options exist, issue a warning.
398
- base_config = node_qc_options_list[0]
399
- if len(node_qc_options_list) > 0 and not mixed_precision_enable:
400
- Logger.info(
401
- f"Request received to select {activation_manual_bit_width_override} activation bits. However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
402
- f" Overriding base_config with an option that uses {activation_manual_bit_width_override} bit activations.") # pragma: no cover
403
-
404
- return base_config, node_qc_options_list
405
-
406
-
407
- def filter_weights_qc_options_with_manual_bit_width(
408
- node: BaseNode,
409
- node_qc_options_list: List[OpQuantizationConfig],
410
- base_config: OpQuantizationConfig,
411
- weights_manual_bit_width_override: Optional[Tuple[int, WeightAttrT]],
412
- mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
413
- """
414
- Update the weights quantization configurations for a node, allowing manual bit-width overrides if specified.
415
-
416
- Args:
417
- node (BaseNode): A node to set quantization configuration candidates to.
418
- node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
419
- base_config (OpQuantizationConfig): Base quantization config for the node.
420
- weights_manual_bit_width_override (Optional[[int, WeightAttrT]]): Specifies a custom bit-width to override the node's weights bit-width.
421
- mixed_precision_enable (bool): Whether mixed precision is enabled.
422
-
423
- Returns:
424
- Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
425
- """
426
- if not weights_manual_bit_width_override:
427
- return base_config, node_qc_options_list
428
-
429
- # Filter node_qc_options_list to retain only the options with weights bits equal to weights_manual_bit_width_override.
430
- node_qc_options_weights_list = _filter_options(node_qc_options_list, weights_manual_bit_width_override)
431
-
432
- if len(node_qc_options_weights_list) == 0:
433
- Logger.critical(f"Manually selected weights bit-width {weights_manual_bit_width_override} is invalid for node {node}.")
434
- else:
435
- # Update the base_config to one of the values from the filtered node_qc_options_list.
436
- # First, check if a configuration similar to the original base_config but with weights bits equal to weights_manual_bit_width_override exists.
437
- # If it does, use it as the base_config. If not, choose a different configuration from node_qc_options_list.
438
- updated_base_config = base_config.clone_and_edit()
439
-
440
- for bit_width, attr in weights_manual_bit_width_override:
441
- Logger.info(f"Setting node {node} bit-width to manually selected {attr} bit-width: {bit_width} bits.")
442
- updated_base_config = updated_base_config.clone_and_edit(attr_to_edit={attr : {WEIGHTS_N_BITS: bit_width}})
443
-
444
- if updated_base_config in node_qc_options_weights_list:
445
- # If a base_config with the specified weights_manual_bit_width_override exists in the node_qc_options_list,
446
- # point the base_config to this option.
447
- base_config = node_qc_options_weights_list[node_qc_options_weights_list.index(updated_base_config)]
448
- else:
449
- # Choose a different configuration from node_qc_options_list. If multiple options exist, issue a warning.
450
- base_config = node_qc_options_weights_list[0]
451
- if len(node_qc_options_weights_list) > 0 and not mixed_precision_enable:
452
- Logger.info(
453
- f"Request received to select weights bit-widths {weights_manual_bit_width_override}."
454
- f"However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
455
- f" Overriding base_config with an option that uses manually selected weights bit-widths {weights_manual_bit_width_override}.") # pragma: no cover
456
-
457
- return base_config, node_qc_options_weights_list
458
-
459
-
460
- def _is_valid_option(
461
- op_cfg: OpQuantizationConfig,
462
- attr: WeightAttrT,
463
- bit_width: int) -> bool:
464
- """
465
- Judge whether the specified option is valid based on the specified attribute and bit width.
466
-
467
- Args:
468
- op_cfg (OpQuantizationConfig): The quantization configuration to be judged.
469
- attr (WeightAttrT): The filtered node's attributes to apply bit-width manipulation to.
470
- bit_width (int): The bit width to be applied to the selected nodes.
471
-
472
- Returns:
473
- Result to judge whether the specified option is valid based on the specified attribute and bit width
474
- """
475
- weights_attrs = op_cfg.attr_weights_configs_mapping.keys()
476
-
477
- if attr not in weights_attrs:
478
- return False
479
-
480
- weights_n_bits = op_cfg.attr_weights_configs_mapping[attr].weights_n_bits
481
- return weights_n_bits == bit_width
482
-
483
-
484
- def _filter_options(
485
- node_qc_options_list: List[OpQuantizationConfig],
486
- weights_manual_bit_width_override: Tuple[int, WeightAttrT]) -> List[OpQuantizationConfig]:
487
- """
488
- Filter the options based on the specified bit width and attribute.
489
-
490
- Args:
491
- node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
492
- weights_manual_bit_width_override (Tuple[int, WeightAttrT])): Specifies a custom bit-width to override the node's weights bit-width.
493
-
494
- Returns:
495
- List[OpQuantizationConfig]: Filtered the options based on the specified bit width and attribute.
496
- """
497
- filtered_options = []
498
-
499
- for bit_width, attr in weights_manual_bit_width_override:
500
- for op_cfg in node_qc_options_list:
501
- if _is_valid_option(op_cfg, attr, bit_width):
502
- filtered_options.append(op_cfg)
503
-
504
- return filtered_options
161
+ manual_activation_bitwidths: nodes' manual activation bitwidth.
162
+
163
+ Raises:
164
+ ValueError: if the manual bitwidth is requested for non-existing attribute.
165
+ if the manual bitwidth is requested for un-quantized weights attribute.
166
+ if the manual bitwidth is not compatible with any candidate.
167
+ """
168
+ def qc_attr_nbits(qc, attr, n):
169
+ if attr == POSITIONAL_ATTR:
170
+ pos_attrs = qc.weights_quantization_cfg.pos_attributes_config_mapping
171
+ if not pos_attrs:
172
+ raise ValueError('Unexpected positional attribute in manual weights bit-width for node {n}.')
173
+ if any(cfg.enable_weights_quantization is False for cfg in pos_attrs.values()):
174
+ raise ValueError(f'Cannot apply manual bit-width configuration for positional attribute of node {n} as '
175
+ f'the attribute is not quantized.')
176
+ assert len({cfg.weights_n_bits for cfg in pos_attrs.values()}) == 1
177
+ return list(pos_attrs.values())[0].weights_n_bits
178
+ if attr not in qc.weights_quantization_cfg.all_weight_attrs:
179
+ raise ValueError(f'Unexpected attribute {attr} in manual weights bit-width configuration for node {n}.')
180
+ attr_cfg = qc.weights_quantization_cfg.get_attr_config(attr)
181
+ if not attr_cfg.enable_weights_quantization:
182
+ raise ValueError(f'Cannot apply manual bit-width configuration for weights attribute {attr} of node {n} as '
183
+ f'the attribute is not quantized.')
184
+ return qc.weights_quantization_cfg.get_attr_config(attr).weights_n_bits
185
+
186
+ for n, manual_wbits in manual_weights_bitwidths.items():
187
+ candidates = [qc for qc in n.candidates_quantization_cfg
188
+ if all(qc_attr_nbits(qc, attr, n) == w_nbits for attr, w_nbits in manual_wbits.items())]
189
+ if not candidates:
190
+ raise ValueError(f'Cannot apply manual weights bit-width configuration {manual_wbits} for node {n} as it '
191
+ f'does not match any of the quantization candidates.')
192
+ n.quantization_cfg.candidates_quantization_cfg = candidates
193
+ for attr, w_nbits in manual_wbits.items():
194
+ base_weights_cfg = n.quantization_cfg.base_quantization_cfg.weights_quantization_cfg
195
+ if attr == POSITIONAL_ATTR:
196
+ for pos_attr in base_weights_cfg.pos_attributes_config_mapping:
197
+ base_weights_cfg.get_attr_config(pos_attr).weights_n_bits = w_nbits
198
+ else:
199
+ base_weights_cfg.get_attr_config(attr).weights_n_bits = w_nbits
@@ -42,13 +42,12 @@ def apply_activation_bias_correction_to_graph(graph: Graph,
42
42
  n.final_activation_quantization_cfg.activation_bias_correction_term is not None:
43
43
  # If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was
44
44
  # calculated during model preparation, and is used now in the node's bias term.
45
- _apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
45
+ _apply_activation_bias_correction_to_node(n, fw_impl)
46
46
  return graph
47
47
 
48
48
 
49
49
  def _apply_activation_bias_correction_to_node(node: BaseNode,
50
- fw_impl: FrameworkImplementation,
51
- qc: QuantizationConfig):
50
+ fw_impl: FrameworkImplementation):
52
51
  """
53
52
  Set new bias to node using the activation bias correction term that is stored in the
54
53
  final activation quantization configuration.
@@ -56,7 +55,6 @@ def _apply_activation_bias_correction_to_node(node: BaseNode,
56
55
  Args:
57
56
  node: Node to set its corrected bias after activation bias correction.
58
57
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
59
- qc: QuantizationConfig containing parameters of how the model should be quantized.
60
58
 
61
59
  """
62
60
  correction = node.final_activation_quantization_cfg.activation_bias_correction_term
@@ -72,7 +70,6 @@ def _apply_activation_bias_correction_to_node(node: BaseNode,
72
70
  # Configure the quantization of the bias as disabled.
73
71
  node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
74
72
  WeightsAttrQuantizationConfig(
75
- qc,
76
73
  AttributeQuantizationConfig(
77
74
  enable_weights_quantization=False)))
78
75
  else:
@@ -77,7 +77,5 @@ def _apply_bias_correction_to_node(node: BaseNode,
77
77
  node.set_weights_by_keys(fw_impl.constants.BIAS, - correction)
78
78
  node.framework_attr[fw_impl.constants.USE_BIAS] = True # Mark the use_bias attribute of the node.
79
79
  node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
80
- WeightsAttrQuantizationConfig(
81
- qc,
82
- AttributeQuantizationConfig(
83
- enable_weights_quantization=False)))
80
+ WeightsAttrQuantizationConfig(AttributeQuantizationConfig(
81
+ enable_weights_quantization=False)))
@@ -24,7 +24,7 @@ from model_compression_toolkit.core.common.model_builder_mode import ModelBuilde
24
24
  from model_compression_toolkit.core.common.model_collector import ModelCollector
25
25
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
26
26
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
27
- import get_activations_qparams
27
+ import compute_activation_qparams
28
28
  from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
29
29
  from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
30
30
 
@@ -50,12 +50,11 @@ def _collect_and_assign_act_threshold(graph: Graph,
50
50
  for _data in tqdm(representative_data_gen()):
51
51
  mi.infer(_data)
52
52
 
53
- for n in list(graph.nodes):
53
+ for n in graph.nodes:
54
54
  if n.is_activation_quantization_enabled():
55
- activation_params = get_activations_qparams(
56
- activation_quant_cfg=n.final_activation_quantization_cfg,
57
- nodes_prior_info=n.prior_info,
58
- out_stats_container=graph.get_out_stats_collector(n))
55
+ activation_params = compute_activation_qparams(activation_quant_cfg=n.final_activation_quantization_cfg,
56
+ node_prior_info=n.prior_info,
57
+ out_stats_container=graph.get_out_stats_collector(n))
59
58
  n.final_activation_quantization_cfg.set_activation_quantization_param(activation_params)
60
59
 
61
60