mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250927.534__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -12,22 +12,74 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import List, Tuple, Dict
15
+ import copy
16
+ from typing import List, Tuple, Dict, Optional
16
17
 
18
+ from mct_quantizers.common.constants import WEIGHTS_N_BITS, ACTIVATION_N_BITS
19
+ from model_compression_toolkit.constants import WEIGHTS, ACTIVATION
17
20
  from model_compression_toolkit.core.common import BaseNode
18
- from model_compression_toolkit.core.common.graph.base_graph import Graph
19
21
  from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
20
- from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
21
22
  from model_compression_toolkit.logger import Logger
22
- from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
23
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
24
+ from model_compression_toolkit.core.common.graph.base_graph import Graph
25
+ from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
26
+ CandidateNodeQuantizationConfig
27
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig, \
28
+ ActivationQuantizationMode
29
+ from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
30
+ QuantizationErrorMethod
31
+ from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
32
+ get_activation_quantization_params_fn
33
+ from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
23
34
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
24
35
  QuantizationConfigOptions
25
- from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
26
36
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
27
37
  FrameworkQuantizationCapabilities
38
+ from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
39
+
40
+
41
+ def set_quantization_configuration_to_graph(graph: Graph,
42
+ quant_config: QuantizationConfig,
43
+ bit_width_config: BitWidthConfig = None,
44
+ mixed_precision_enable: bool = False,
45
+ running_gptq: bool = False) -> Graph:
46
+ """
47
+ Add quantization configuration for each graph node.
48
+
49
+ Args:
50
+ graph (Graph): Graph for which to add quantization info to each node.
51
+ quant_config (QuantizationConfig): Quantization configuration containing parameters for how the graph should be quantized.
52
+ bit_width_config (BitWidthConfig): Configuration for manual bit width selection. Defaults to None.
53
+ mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
54
+ running_gptq (bool): Whether or not a GPTQ optimization is planned to run after the PTQ process. Defaults to False.
55
+
56
+ Returns:
57
+ Graph: The graph with quantization configurations attached to each node in it.
58
+ """
59
+
60
+ if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
61
+ if not running_gptq:
62
+ raise ValueError(f"The HMSE error method for parameters selection is only supported when running GPTQ "
63
+ f"optimization due to long execution time that is not suitable for basic PTQ.")
64
+ Logger.warning("Using the HMSE error method for weights quantization parameters search. "
65
+ "Note: This method may significantly increase runtime during the parameter search process.")
66
+
67
+ nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
68
+ nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_weights_bit_widths(graph)
69
+
70
+ for n in graph.get_topo_sorted_nodes():
71
+ manual_bit_width_override = {ACTIVATION: nodes_to_manipulate_activation_bit_widths.get(n),
72
+ WEIGHTS: nodes_to_manipulate_weights_bit_widths.get(n)}
73
+ set_quantization_configs_to_node(node=n,
74
+ graph=graph,
75
+ quant_config=quant_config,
76
+ fw_info=graph.fw_info,
77
+ fqc=graph.fqc,
78
+ mixed_precision_enable=mixed_precision_enable,
79
+ manual_bit_width_override=manual_bit_width_override)
80
+ return graph
28
81
 
29
82
 
30
- # TODO irena refactor (if needed) and move to load_fqc
31
83
  def filter_node_qco_by_graph(node: BaseNode,
32
84
  fqc: FrameworkQuantizationCapabilities,
33
85
  graph: Graph,
@@ -50,8 +102,6 @@ def filter_node_qco_by_graph(node: BaseNode,
50
102
  that are compatible with next nodes supported input bit-widths.
51
103
 
52
104
  """
53
- from model_compression_toolkit.quantization_preparation.load_fqc import fetch_qc_options_for_node
54
-
55
105
  # Filter quantization config options that don't match the graph.
56
106
  _base_config = node_qc_options.base_config
57
107
  _node_qc_options = node_qc_options.quantization_configurations
@@ -61,7 +111,7 @@ def filter_node_qco_by_graph(node: BaseNode,
61
111
  next_nodes = []
62
112
  while len(_next_nodes):
63
113
  n = _next_nodes.pop(0)
64
- qco = fetch_qc_options_for_node(n, fqc)
114
+ qco = n.get_qco(fqc)
65
115
  qp = [qc.quantization_preserving for qc in qco.quantization_configurations]
66
116
  if not all(qp) and any(qp):
67
117
  Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizaionConfigOptions in {n}.')
@@ -71,8 +121,7 @@ def filter_node_qco_by_graph(node: BaseNode,
71
121
 
72
122
  if len(next_nodes) == 0:
73
123
  return _base_config, _node_qc_options
74
-
75
- next_nodes_qc_options = [fetch_qc_options_for_node(_node, fqc) for _node in next_nodes]
124
+ next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
76
125
  all_next_nodes_supported_input_bitwidth = [max_input_activation_n_bits(op_cfg)
77
126
  for qc_opts in next_nodes_qc_options
78
127
  for op_cfg in qc_opts.quantization_configurations
@@ -102,98 +151,368 @@ def filter_node_qco_by_graph(node: BaseNode,
102
151
  return _base_config, _node_qc_options
103
152
 
104
153
 
105
- def set_manual_bitwidth_config(graph, bit_width_config: BitWidthConfig):
154
+ def set_quantization_configs_to_node(node: BaseNode,
155
+ graph: Graph,
156
+ quant_config: QuantizationConfig,
157
+ fw_info: FrameworkInfo,
158
+ fqc: FrameworkQuantizationCapabilities,
159
+ mixed_precision_enable: bool = False,
160
+ manual_bit_width_override: Optional[Dict] = None):
106
161
  """
107
- Filters candidates per manual bit-width config.
162
+ Create and set quantization configurations to a node (for both weights and activation).
108
163
 
109
164
  Args:
110
- graph: graph after candidates have been set on nodes.
111
- bit_width_config: bit-width config.
165
+ node (BaseNode): Node to set its quantization configurations.
166
+ graph (Graph): Model's internal representation graph.
167
+ quant_config (QuantizationConfig): Quantization configuration to generate the node's configurations from.
168
+ fw_info (FrameworkInfo): Information needed for quantization about the specific framework.
169
+ fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to get default OpQuantizationConfig.
170
+ mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
171
+ manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
112
172
  """
113
- manual_activation_bitwidths = bit_width_config.get_nodes_activation_bit_widths(graph)
114
- manual_weights_bitwidths = bit_width_config.get_nodes_weights_bit_widths(graph)
173
+ node_qc_options = node.get_qco(fqc)
174
+ base_config, node_qc_options_list = filter_node_qco_by_graph(node, fqc, graph, node_qc_options)
115
175
 
116
- if manual_activation_bitwidths:
117
- _set_manual_activation_bitwidths(manual_activation_bitwidths)
176
+ # If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation and weights bits equal to manual_bit_width_override,
177
+ # and update base_config accordingly.
178
+ if manual_bit_width_override is None:
179
+ manual_bit_width_override = {ACTIVATION: None, WEIGHTS: None}
180
+
181
+ base_config, node_qc_options_list = filter_qc_options_with_manual_bit_width(
182
+ node=node,
183
+ node_qc_options_list=node_qc_options_list,
184
+ base_config=base_config,
185
+ manual_bit_width_override=manual_bit_width_override,
186
+ mixed_precision_enable=mixed_precision_enable)
118
187
 
119
- if manual_weights_bitwidths:
120
- _set_manual_weights_bitwidths(manual_weights_bitwidths)
188
+ # Create QC candidates for weights and activation combined
189
+ weight_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
190
+ node.candidates_quantization_cfg = _create_node_candidates_qc(quant_config,
191
+ fw_info,
192
+ weight_channel_axis,
193
+ node_qc_options_list,
194
+ base_config,
195
+ node,
196
+ mixed_precision_enable=mixed_precision_enable)
121
197
 
198
+ # sorting the candidates by kernel attribute weights number of bits first and then by activation number of bits
199
+ # (in reversed order). since only kernel attribute is quantized in weights mixed precision,
200
+ # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
201
+ node.sort_node_candidates(fw_info)
122
202
 
123
- # TODO irena: check coverage and add missing tests
124
- def _set_manual_activation_bitwidths(manual_activation_bitwidths: Dict[BaseNode, int]):
203
+ for candidate_qc in node.candidates_quantization_cfg:
204
+ if candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.QUANT and \
205
+ not node.get_has_activation():
206
+ candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
207
+ elif candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.PRESERVE_QUANT:
208
+ prev_nodes = graph.get_prev_nodes(node)
209
+ if len(prev_nodes) != 1:
210
+ # Preserving the quantization of more than 1 previous node is ambiguous, so disable it.
211
+ Logger.info(f"Disabling Quantization-Preserving for node {node.name} because it has more than 1 input activations.")
212
+ candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
213
+ elif not prev_nodes[0].is_quantization_preserving() and not prev_nodes[0].is_activation_quantization_enabled():
214
+ # Preserving the quantization of an unquantized node isn't possible, so disable it.
215
+ Logger.info(f"Disabling Quantization-Preserving for node {node.name} because previous node activation quantization is disabled.")
216
+ candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
217
+
218
+
219
+ def create_node_activation_qc(qc: QuantizationConfig,
220
+ fw_info: FrameworkInfo,
221
+ op_cfg: OpQuantizationConfig) -> NodeActivationQuantizationConfig:
125
222
  """
126
- Filters out candidates that don't match the requested manual activation bitwidths, and updates the
127
- activation bitwidth in the base quantization config.
223
+ Create an activation quantization configuration from a QuantizationConfig object.
128
224
 
129
225
  Args:
130
- manual_activation_bitwidths: nodes' manual activation bitwidth.
131
-
132
- Raises:
133
- ValueError: if the manual bitwidth is requested for un-quantized node.
134
- if the manual bitwidth is not compatible with any candidate.
135
- """
136
- for n, a_nbits in manual_activation_bitwidths.items():
137
- quant_mode = n.quantization_cfg.get_activation_quant_mode()
138
- # TODO irena: for FLN I think it should be ignored with warning for layer filter, and error for name filter
139
- if quant_mode != ActivationQuantizationMode.QUANT:
140
- raise ValueError(f'Cannot apply manual activation bit-width for node {n} with activation quantization mode'
141
- f'{quant_mode}, as it does not have its own quantization configuration.')
142
- candidates = [qc for qc in n.candidates_quantization_cfg
143
- if qc.activation_quantization_cfg.activation_n_bits == a_nbits]
144
- if not candidates:
145
- valid_nbits = sorted(list({qc.activation_quantization_cfg.activation_n_bits
146
- for qc in n.candidates_quantization_cfg}))
147
- raise ValueError(
148
- f'Manually selected activation bit-width {a_nbits} is invalid for node {n}. '
149
- f'Valid bit-widths: {valid_nbits}.')
150
- n.quantization_cfg.candidates_quantization_cfg = candidates
151
- n.quantization_cfg.base_quantization_cfg.activation_quantization_cfg.activation_n_bits = a_nbits
152
-
153
-
154
- # TODO irena: check coverage
155
- def _set_manual_weights_bitwidths(manual_weights_bitwidths: Dict[BaseNode, Dict[str, int]]):
156
- """
157
- Filters out candidates that don't match the requested weight attributes manual bitwidths, and updates the bitwidths
158
- in the base quantization config.
226
+ qc: QuantizationConfig to create the node's config from.
227
+ fw_info: Information about the specific framework the node was created from (e.g., whether or not its
228
+ weights/activations should be quantized)
229
+ op_cfg: OpQuantizationConfig with quantizers types to set in node quantization configuration.
230
+
231
+ Returns:
232
+ Activation quantization configuration of a node.
233
+ """
234
+
235
+ activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
236
+ if activation_quantization_fn is None:
237
+ Logger.critical('Unknown activation quantization method specified.') # pragma: no cover
238
+
239
+ activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
240
+
241
+ return NodeActivationQuantizationConfig(qc,
242
+ op_cfg,
243
+ activation_quantization_fn,
244
+ activation_quantization_params_fn)
245
+
246
+
247
+ def _create_node_single_candidate_qc(qc: QuantizationConfig,
248
+ fw_info: FrameworkInfo,
249
+ weight_channel_axis: Tuple[int, int],
250
+ op_cfg: OpQuantizationConfig,
251
+ node_attrs_list: List[str]) -> CandidateNodeQuantizationConfig:
252
+ """
253
+ Create quantization configuration candidate from a QuantizationConfig object.
254
+ Creates both weights and activation quantization configurations
255
+ and initialize a candidate object that encapsulates both.
159
256
 
160
257
  Args:
161
- manual_activation_bitwidths: nodes' manual activation bitwidth.
162
-
163
- Raises:
164
- ValueError: if the manual bitwidth is requested for non-existing attribute.
165
- if the manual bitwidth is requested for un-quantized weights attribute.
166
- if the manual bitwidth is not compatible with any candidate.
167
- """
168
- def qc_attr_nbits(qc, attr, n):
169
- if attr == POSITIONAL_ATTR:
170
- pos_attrs = qc.weights_quantization_cfg.pos_attributes_config_mapping
171
- if not pos_attrs:
172
- raise ValueError('Unexpected positional attribute in manual weights bit-width for node {n}.')
173
- if any(cfg.enable_weights_quantization is False for cfg in pos_attrs.values()):
174
- raise ValueError(f'Cannot apply manual bit-width configuration for positional attribute of node {n} as '
175
- f'the attribute is not quantized.')
176
- assert len({cfg.weights_n_bits for cfg in pos_attrs.values()}) == 1
177
- return list(pos_attrs.values())[0].weights_n_bits
178
- if attr not in qc.weights_quantization_cfg.all_weight_attrs:
179
- raise ValueError(f'Unexpected attribute {attr} in manual weights bit-width configuration for node {n}.')
180
- attr_cfg = qc.weights_quantization_cfg.get_attr_config(attr)
181
- if not attr_cfg.enable_weights_quantization:
182
- raise ValueError(f'Cannot apply manual bit-width configuration for weights attribute {attr} of node {n} as '
183
- f'the attribute is not quantized.')
184
- return qc.weights_quantization_cfg.get_attr_config(attr).weights_n_bits
185
-
186
- for n, manual_wbits in manual_weights_bitwidths.items():
187
- candidates = [qc for qc in n.candidates_quantization_cfg
188
- if all(qc_attr_nbits(qc, attr, n) == w_nbits for attr, w_nbits in manual_wbits.items())]
189
- if not candidates:
190
- raise ValueError(f'Cannot apply manual weights bit-width configuration {manual_wbits} for node {n} as it '
191
- f'does not match any of the quantization candidates.')
192
- n.quantization_cfg.candidates_quantization_cfg = candidates
193
- for attr, w_nbits in manual_wbits.items():
194
- base_weights_cfg = n.quantization_cfg.base_quantization_cfg.weights_quantization_cfg
195
- if attr == POSITIONAL_ATTR:
196
- for pos_attr in base_weights_cfg.pos_attributes_config_mapping:
197
- base_weights_cfg.get_attr_config(pos_attr).weights_n_bits = w_nbits
198
- else:
199
- base_weights_cfg.get_attr_config(attr).weights_n_bits = w_nbits
258
+ qc: QuantizationConfig to create the node's config from.
259
+ fw_info: Information about the specific framework the node was created from (e.g., whether its
260
+ weights/activations should be quantized)
261
+ weight_channel_axis: (Output, Input) channel index of the node's kernel.
262
+ op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
263
+ node_attrs_list: A list of the node's weights attributes names.
264
+
265
+ Returns: a CandidateNodeQuantizationConfig object with both weights and activation quantization config objects.
266
+
267
+ """
268
+
269
+ # parameters for weights attributes quantization are set within CandidateNodeQuantizationConfig initialization
270
+
271
+ # get parameters for activation quantization
272
+ activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
273
+ if activation_quantization_fn is None:
274
+ Logger.critical('Unknown activation quantization method specified.') # pragma: no cover
275
+
276
+ activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
277
+
278
+ # TODO: remove this validation and warning once enabling all attributes quantization by default
279
+ attrs_with_enabled_quantization = [attr for attr, cfg in op_cfg.attr_weights_configs_mapping.items()
280
+ if cfg.enable_weights_quantization]
281
+ if len(attrs_with_enabled_quantization) > 1:
282
+ Logger.warning(f"Multiple weights attributes quantization is enabled via the provided FQC."
283
+ f"Quantizing any attribute other than the kernel is experimental "
284
+ f"and may be subject to unstable behavior."
285
+ f"Attributes with enabled weights quantization: {attrs_with_enabled_quantization}.")
286
+
287
+ return CandidateNodeQuantizationConfig(qc=qc,
288
+ op_cfg=op_cfg,
289
+ activation_quantization_fn=activation_quantization_fn,
290
+ activation_quantization_params_fn=activation_quantization_params_fn,
291
+ weights_channels_axis=weight_channel_axis,
292
+ node_attrs_list=node_attrs_list)
293
+
294
+
295
def _create_node_candidates_qc(qc: QuantizationConfig,
                               fw_info: FrameworkInfo,
                               weight_channel_axis: Tuple[int, int],
                               node_qc_options_list: List[OpQuantizationConfig],
                               base_config: OpQuantizationConfig,
                               node: BaseNode,
                               mixed_precision_enable: bool = False) -> List[CandidateNodeQuantizationConfig]:
    """
    Build the list of weights/activation quantization candidates for a node.

    With mixed precision enabled, one candidate is produced per entry of ``node_qc_options_list``,
    each built from its own deep copy of ``qc``. Otherwise a single candidate is produced from
    ``base_config`` using ``qc`` as-is.

    Args:
        qc (QuantizationConfig): Quantization configuration the quantization process should follow.
        fw_info (FrameworkInfo): Framework information (e.g., which layers should have their kernels quantized).
        weight_channel_axis (Tuple[int, int]): (Output, Input) channel index of the node's kernel.
        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs of node.
        base_config (OpQuantizationConfig): Base quantization config for node.
        node (BaseNode): A node to set quantization configuration candidates to.
        mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.

    Returns:
        List[CandidateNodeQuantizationConfig]: The candidates to set for the node.
    """
    weight_attr_names = node.get_node_weights_attributes()

    if not mixed_precision_enable:
        # Single candidate derived from the base config; the caller's qc object is used directly.
        return [_create_node_single_candidate_qc(qc,
                                                 fw_info,
                                                 weight_channel_axis,
                                                 base_config,
                                                 weight_attr_names)]

    # One candidate per option, each receiving an independent deep copy of qc.
    return [_create_node_single_candidate_qc(copy.deepcopy(qc),
                                             fw_info,
                                             weight_channel_axis,
                                             option_cfg,
                                             weight_attr_names)
            for option_cfg in node_qc_options_list]
338
+
339
+
340
def filter_qc_options_with_manual_bit_width(
        node: BaseNode,
        node_qc_options_list: List[OpQuantizationConfig],
        base_config: OpQuantizationConfig,
        manual_bit_width_override: Optional[Dict],
        mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
    """
    Update the quantization configurations for a node, allowing manual bit-width overrides if specified.

    The activation override is applied first, and the weights override is applied to its result.

    Args:
        node (BaseNode): A node to set quantization configuration candidates to.
        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
        base_config (OpQuantizationConfig): Base quantization config for the node.
        manual_bit_width_override (Optional[Dict]): Mapping that may hold an ACTIVATION entry (activation
            bit-width) and/or a WEIGHTS entry (weights bit-width overrides). None means no override.
        mixed_precision_enable (bool): Whether mixed precision is enabled.

    Returns:
        Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the
        filtered list of quantization configs.
    """
    # The parameter is Optional; without this guard, calling .get on None raises AttributeError.
    if manual_bit_width_override is None:
        manual_bit_width_override = {}

    base_config, node_qc_options_list = filter_activation_qc_options_with_manual_bit_width(
        node,
        node_qc_options_list,
        base_config,
        manual_bit_width_override.get(ACTIVATION),
        mixed_precision_enable)

    base_config, node_qc_options_list = filter_weights_qc_options_with_manual_bit_width(
        node,
        node_qc_options_list,
        base_config,
        manual_bit_width_override.get(WEIGHTS),
        mixed_precision_enable)
    return base_config, node_qc_options_list
371
+
372
+
373
def filter_activation_qc_options_with_manual_bit_width(
        node: BaseNode,
        node_qc_options_list: List[OpQuantizationConfig],
        base_config: OpQuantizationConfig,
        activation_manual_bit_width_override: Optional[int],
        mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
    """
    Update the activation quantization configurations for a node, allowing a manual bit-width override if specified.

    Args:
        node (BaseNode): A node to set quantization configuration candidates to.
        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
        base_config (OpQuantizationConfig): Base quantization config for the node.
        activation_manual_bit_width_override (Optional[int]): Requested activation bit-width. When None, the
            inputs are returned unchanged.
        mixed_precision_enable (bool): Whether mixed precision is enabled.

    Returns:
        Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the
        filtered list of quantization configs.

    Notes:
        If no option uses the requested activation bit-width, Logger.critical is called.
    """
    if activation_manual_bit_width_override is None:
        return base_config, node_qc_options_list

    # Keep only the options whose activation bit-width equals the requested one.
    node_qc_options_list = [op_cfg for op_cfg in node_qc_options_list if
                            activation_manual_bit_width_override == op_cfg.activation_n_bits]
    if len(node_qc_options_list) == 0:
        Logger.critical(f"Manually selected activation bit-width {activation_manual_bit_width_override} is invalid for node {node}.")
    else:
        # Prefer an option equal to the original base_config with only its activation bit-width overridden;
        # if no such option exists, fall back to the first remaining option.
        Logger.info(f"Setting node {node} bit-width to manually selected bit-width: {activation_manual_bit_width_override} bits.")
        # The edited field must be passed as a keyword argument. The previous code passed a set literal
        # ({ACTIVATION_N_BITS, bits}), which carries no field->value association. The clone kept the old
        # activation bit-width, so the membership check below could never match the intended option.
        updated_base_config = base_config.clone_and_edit(**{ACTIVATION_N_BITS: activation_manual_bit_width_override})
        if updated_base_config in node_qc_options_list:
            # Point base_config at the matching option in the filtered list.
            base_config = node_qc_options_list[node_qc_options_list.index(updated_base_config)]
        else:
            # The edited base config is not among the options; use the first one that satisfies the override.
            base_config = node_qc_options_list[0]
            # The list is known to be non-empty here, so the former `len(...) > 0` check was always true.
            if not mixed_precision_enable:
                Logger.info(
                    f"Request received to select {activation_manual_bit_width_override} activation bits. However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
                    f" Overriding base_config with an option that uses {activation_manual_bit_width_override} bit activations.")  # pragma: no cover

    return base_config, node_qc_options_list
419
+
420
+
421
def filter_weights_qc_options_with_manual_bit_width(
        node: BaseNode,
        node_qc_options_list: List[OpQuantizationConfig],
        base_config: OpQuantizationConfig,
        weights_manual_bit_width_override: Optional[List[Tuple[int, WeightAttrT]]],
        mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
    """
    Update the weights quantization configurations for a node, allowing manual bit-width overrides if specified.

    Args:
        node (BaseNode): A node to set quantization configuration candidates to.
        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
        base_config (OpQuantizationConfig): Base quantization config for the node.
        weights_manual_bit_width_override (Optional[List[Tuple[int, WeightAttrT]]]): Sequence of
            (bit_width, attr) pairs requesting a custom bit-width per weights attribute. When empty or
            None, the inputs are returned unchanged.
        mixed_precision_enable (bool): Whether mixed precision is enabled.

    Returns:
        Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the
        filtered list of quantization configs.

    Notes:
        If no option satisfies the requested weights bit-widths, Logger.critical is called.
    """
    if not weights_manual_bit_width_override:
        return base_config, node_qc_options_list

    # Keep only the options that match the requested weights bit-widths.
    node_qc_options_weights_list = _filter_options(node_qc_options_list, weights_manual_bit_width_override)

    if len(node_qc_options_weights_list) == 0:
        Logger.critical(f"Manually selected weights bit-width {weights_manual_bit_width_override} is invalid for node {node}.")
    else:
        # Prefer an option equal to the original base_config with every requested attribute bit-width applied;
        # if no such option exists, fall back to the first remaining option.
        updated_base_config = base_config.clone_and_edit()

        for bit_width, attr in weights_manual_bit_width_override:
            Logger.info(f"Setting node {node} bit-width to manually selected {attr} bit-width: {bit_width} bits.")
            updated_base_config = updated_base_config.clone_and_edit(attr_to_edit={attr: {WEIGHTS_N_BITS: bit_width}})

        if updated_base_config in node_qc_options_weights_list:
            # Point base_config at the matching option in the filtered list.
            base_config = node_qc_options_weights_list[node_qc_options_weights_list.index(updated_base_config)]
        else:
            # The edited base config is not among the options; use the first one that satisfies the override.
            base_config = node_qc_options_weights_list[0]
            # The list is known to be non-empty here, so the former `len(...) > 0` check was always true.
            if not mixed_precision_enable:
                Logger.info(
                    f"Request received to select weights bit-widths {weights_manual_bit_width_override}. "
                    f"However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
                    f" Overriding base_config with an option that uses manually selected weights bit-widths {weights_manual_bit_width_override}.")  # pragma: no cover

    return base_config, node_qc_options_weights_list
472
+
473
+
474
+ def _is_valid_option(
475
+ op_cfg: OpQuantizationConfig,
476
+ attr: WeightAttrT,
477
+ bit_width: int) -> bool:
478
+ """
479
+ Judge whether the specified option is valid based on the specified attribute and bit width.
480
+
481
+ Args:
482
+ op_cfg (OpQuantizationConfig): The quantization configuration to be judged.
483
+ attr (WeightAttrT): The filtered node's attributes to apply bit-width manipulation to.
484
+ bit_width (int): The bit width to be applied to the selected nodes.
485
+
486
+ Returns:
487
+ Result to judge whether the specified option is valid based on the specified attribute and bit width
488
+ """
489
+ weights_attrs = op_cfg.attr_weights_configs_mapping.keys()
490
+
491
+ if attr not in weights_attrs:
492
+ return False
493
+
494
+ weights_n_bits = op_cfg.attr_weights_configs_mapping[attr].weights_n_bits
495
+ return weights_n_bits == bit_width
496
+
497
+
498
def _filter_options(
        node_qc_options_list: List[OpQuantizationConfig],
        weights_manual_bit_width_override: List[Tuple[int, WeightAttrT]]) -> List[OpQuantizationConfig]:
    """
    Filter the options, keeping only those that satisfy every requested (bit_width, attr) override.

    Args:
        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
        weights_manual_bit_width_override (List[Tuple[int, WeightAttrT]]): Sequence of (bit_width, attr)
            pairs requesting a custom bit-width per weights attribute.

    Returns:
        List[OpQuantizationConfig]: The options whose every overridden attribute uses the requested bit width,
        in their original order and without duplicates. An empty override yields an empty list.
    """
    overrides = list(weights_manual_bit_width_override)
    if not overrides:
        # Preserve the historical result for an empty request (no option was ever selected).
        return []

    # An option must match all requested pairs at once. This matches how the caller applies every override
    # to the base config together. The previous union-style loop kept options that matched any single pair,
    # and appended an option once per matching pair, which produced duplicate candidates.
    return [op_cfg for op_cfg in node_qc_options_list
            if all(_is_valid_option(op_cfg, attr, bit_width) for bit_width, attr in overrides)]
@@ -38,16 +38,18 @@ def apply_activation_bias_correction_to_graph(graph: Graph,
38
38
 
39
39
  for n in graph.nodes:
40
40
  # Activation bias correction is only relevant for nodes with kernel op
41
- if core_config.quantization_config.activation_bias_correction and n.kernel_attr is not None and \
41
+ kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
42
+ if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \
42
43
  n.final_activation_quantization_cfg.activation_bias_correction_term is not None:
43
44
  # If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was
44
45
  # calculated during model preparation, and is used now in the node's bias term.
45
- _apply_activation_bias_correction_to_node(n, fw_impl)
46
+ _apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
46
47
  return graph
47
48
 
48
49
 
49
50
  def _apply_activation_bias_correction_to_node(node: BaseNode,
50
- fw_impl: FrameworkImplementation):
51
+ fw_impl: FrameworkImplementation,
52
+ qc: QuantizationConfig):
51
53
  """
52
54
  Set new bias to node using the activation bias correction term that is stored in the
53
55
  final activation quantization configuration.
@@ -55,6 +57,7 @@ def _apply_activation_bias_correction_to_node(node: BaseNode,
55
57
  Args:
56
58
  node: Node to set its corrected bias after activation bias correction.
57
59
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
60
+ qc: QuantizationConfig containing parameters of how the model should be quantized.
58
61
 
59
62
  """
60
63
  correction = node.final_activation_quantization_cfg.activation_bias_correction_term
@@ -70,6 +73,7 @@ def _apply_activation_bias_correction_to_node(node: BaseNode,
70
73
  # Configure the quantization of the bias as disabled.
71
74
  node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
72
75
  WeightsAttrQuantizationConfig(
76
+ qc,
73
77
  AttributeQuantizationConfig(
74
78
  enable_weights_quantization=False)))
75
79
  else: