mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250927.534__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0

model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py

@@ -14,6 +14,8 @@
 # ==============================================================================
 import copy
 
+from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
+from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
@@ -21,6 +23,7 @@ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_s
 
 
 def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
+                                   core_config: CoreConfig,
                                    fw_impl: FrameworkImplementation) -> Graph:
     """
     Get a graph, where each node has a final weights quantization configuration (with a bias
@@ -28,6 +31,7 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
 
     Args:
         graph_to_apply_bias_correction: Graph to apply bias correction to.
+        core_config: CoreConfig containing parameters of how the model should be quantized.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -36,14 +40,21 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
 
     graph = copy.deepcopy(graph_to_apply_bias_correction)
     for n in graph.nodes:
-        if (n.final_weights_quantization_cfg and n.final_weights_quantization_cfg.bias_corrected is not None and
-                not n.final_weights_quantization_cfg.weights_second_moment_correction):
-            _apply_bias_correction_to_node(n, fw_impl)
+        # bias correction is only relevant for nodes with kernel op
+        kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
+        if core_config.quantization_config.weights_bias_correction and kernel_attr is not None and \
+                n.is_weights_quantization_enabled(kernel_attr) and \
+                not n.final_weights_quantization_cfg.weights_second_moment_correction:
+            # If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg,
+            # a bias correction term was calculated during model preparation, and is used now in the node's bias term.
+            if n.final_weights_quantization_cfg.weights_bias_correction:
+                _apply_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
     return graph
 
 
 def _apply_bias_correction_to_node(node: BaseNode,
-                                   fw_impl: FrameworkImplementation):
+                                   fw_impl: FrameworkImplementation,
+                                   qc: QuantizationConfig):
     """
     Set new bias to node using the bias-correction term that is stored in the
     final weights quantization configuration.
@@ -67,5 +78,7 @@ def _apply_bias_correction_to_node(node: BaseNode,
     node.set_weights_by_keys(fw_impl.constants.BIAS, - correction)
     node.framework_attr[fw_impl.constants.USE_BIAS] = True  # Mark the use_bias attribute of the node.
     node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
-                                                        WeightsAttrQuantizationConfig(AttributeQuantizationConfig(
-                                                            enable_weights_quantization=False)))
+                                                        WeightsAttrQuantizationConfig(
+                                                            qc,
+                                                            AttributeQuantizationConfig(
+                                                                enable_weights_quantization=False)))
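
For orientation, a minimal call-site sketch of the updated entry point. Everything except the diffed signature is an assumption here: graph and fw_impl are taken to come from the surrounding quantization flow, and CoreConfig's defaults are assumed to leave weights_bias_correction enabled.

# Sketch only, not part of the diff.
from model_compression_toolkit.core import CoreConfig

core_config = CoreConfig()  # assumed default quantization_config
corrected_graph = apply_bias_correction_to_graph(graph_to_apply_bias_correction=graph,
                                                 core_config=core_config,
                                                 fw_impl=fw_impl)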

model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py

@@ -24,7 +24,7 @@ from model_compression_toolkit.core.common.model_builder_mode import ModelBuilde
 from model_compression_toolkit.core.common.model_collector import ModelCollector
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
-    import compute_activation_qparams
+    import get_activations_qparams
 from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
 from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
 
@@ -32,6 +32,7 @@ from model_compression_toolkit.core.common.substitutions.apply_substitutions imp
 def _collect_and_assign_act_threshold(graph: Graph,
                                       representative_data_gen: Callable,
                                       core_config: CoreConfig,
+                                      fw_info: FrameworkInfo,
                                       fw_impl: FrameworkImplementation):
     """
     Collect statistics after second moment correction and assign new thresholds to activations.
@@ -40,32 +41,36 @@
         representative_data_gen (Callable): Dataset used for calibration.
         core_config (CoreConfig): Configuration object containing parameters of how the model should be
             quantized, including mixed precision parameters.
+        fw_info: FrameworkInfo object with information about the specific framework's model.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
     """
 
     mi = ModelCollector(graph,
                         fw_impl,
-                        core_config.quantization_config)  # Mark points for statistics collection
+                        fw_info,
+                        core_config.quantization_config)  # Mark points for statistics collection
 
     for _data in tqdm(representative_data_gen()):
         mi.infer(_data)
 
-    for n in graph.nodes:
+    for n in list(graph.nodes):
         if n.is_activation_quantization_enabled():
-            activation_params = compute_activation_qparams(quant_cfg=core_config.quantization_config,
-                                                           node_activation_quant_cfg=n.final_activation_quantization_cfg,
-                                                           node_prior_info=n.prior_info,
-                                                           out_stats_container=graph.get_out_stats_collector(n))
+            activation_params = get_activations_qparams(
+                activation_quant_cfg=n.final_activation_quantization_cfg,
+                nodes_prior_info=n.prior_info,
+                out_stats_container=graph.get_out_stats_collector(n))
             n.final_activation_quantization_cfg.set_activation_quantization_param(activation_params)
 
 
 def quantized_model_builder_for_second_moment_correction(graph: common.Graph,
+                                                         fw_info: FrameworkInfo,
                                                          fw_impl: Any):
     """
     Build a framework model from a graph for second moment correction.
 
     Args:
-        graph: Graph to build from.
+        graph: Graph to build the from.
+        fw_info: FrameworkInfo object with information about the specific framework's model.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -74,13 +79,15 @@ def quantized_model_builder_for_second_moment_correction(graph: common.Graph,
     quantized_tg = quantize_graph_weights(graph)
 
     quantized_model, user_info = fw_impl.model_builder(quantized_tg,
-                                                       mode=ModelBuilderMode.FLOAT)
+                                                       mode=ModelBuilderMode.FLOAT,
+                                                       fw_info=fw_info)
     return quantized_model
 
 
 def apply_second_moment_correction_to_graph(graph: Graph,
                                             representative_data_gen: Callable,
                                             core_config: CoreConfig,
+                                            fw_info: FrameworkInfo,
                                             fw_impl: FrameworkImplementation) -> Graph:
     """
     Apply second moment correction on graph.
@@ -89,14 +96,15 @@ def apply_second_moment_correction_to_graph(graph: Graph,
         representative_data_gen (Callable): Dataset used for calibration.
         core_config (CoreConfig): Configuration object containing parameters of how the model should be
             quantized, including mixed precision parameters.
+        fw_info: FrameworkInfo object with information about the specific framework's model.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
         Graph after second moment correction.
     """
-    semi_quantized_model = quantized_model_builder_for_second_moment_correction(graph, fw_impl)
+    semi_quantized_model = quantized_model_builder_for_second_moment_correction(graph, fw_info, fw_impl)
     fw_impl.apply_second_moment_correction(semi_quantized_model, core_config, representative_data_gen, graph)
     graph = substitute(graph, fw_impl.get_substitutions_after_second_moment_correction(core_config.quantization_config))
-    _collect_and_assign_act_threshold(graph, representative_data_gen, core_config, fw_impl)
+    _collect_and_assign_act_threshold(graph, representative_data_gen, core_config, fw_info, fw_impl)
 
     return graph
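
The call sites of this module change shape the same way; a sketch, assuming graph, representative_data_gen, core_config, fw_info and fw_impl are already held by the surrounding runner:

# Sketch only: fw_info is now threaded explicitly through the second-moment flow.
graph = apply_second_moment_correction_to_graph(graph,
                                                representative_data_gen,
                                                core_config,
                                                fw_info,
                                                fw_impl)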

model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py

@@ -18,7 +18,7 @@ from typing import Any, Callable
 from model_compression_toolkit.core import QuantizationConfig
 from model_compression_toolkit.core.common import BaseNode, Graph
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
-from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 
 
 def get_previous_node_with_activation_quantization(linear_node: BaseNode,
@@ -64,11 +64,11 @@ def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray:
 
 def compute_activation_bias_correction(graph: Graph,
                                        quant_config: QuantizationConfig,
+                                       fw_info: FrameworkInfo,
                                        fw_impl: FrameworkImplementation,
                                        linear_node: BaseNode,
                                        prev_node: BaseNode,
-                                       kernel_size: str,
-                                       get_activation_quantization_fn_factory: Callable) -> Graph:
+                                       kernel_size: str) -> Graph:
     """
     Compute the activation bias correction term, and store it in the final activation
     quantization configuration.
@@ -76,11 +76,11 @@ def compute_activation_bias_correction(graph: Graph,
     Args:
         graph: Graph with nodes to compute the activation bias correction for each node's final activation quantization configuration.
         quant_config: QuantizationConfig of how the model should be quantized.
+        fw_info: Framework info like lists of nodes their kernel should quantized.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         linear_node: Node to compute the activation bias correction for.
         prev_node: Node to compute the activation error caused by his activation quantization.
         kernel_size: The framework specific attribute name of the convolution layer's kernel size.
-        get_activation_quantization_fn_factory: activation quantization functions factory.
 
     Returns:
         Graph with activation bias correction term for each node.
@@ -107,9 +107,7 @@ def compute_activation_bias_correction(graph: Graph,
     float_centers = calculate_bin_centers(float_bins)
 
     # Quantize the bin edges and calculate the centers of the quantized bins
-    activation_quantizer = get_activation_quantization_fn(prev_node_act_quant_cfg,
-                                                          get_activation_quantization_fn_factory)
-    quant_bins = activation_quantizer(fw_impl.to_tensor(float_bins))
+    quant_bins = prev_node_act_quant_cfg.quantize_node_output(fw_impl.to_tensor(float_bins))
     quant_bins = fw_impl.to_numpy(quant_bins)
     quant_centers = calculate_bin_centers(quant_bins)
 
@@ -129,18 +127,19 @@ def compute_activation_bias_correction(graph: Graph,
     if normalized_bias < quant_config.activation_bias_correction_threshold:
         return graph
 
-    kernel = linear_node.get_weights_by_keys(linear_node.kernel_attr)
+    kernel = linear_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(linear_node.type)[0])
 
     # Compute the activation bias correction by applying the quantization error to the kernel, resulting in an output
     # size matching the number of output channels.
     if kernel is not None:
 
         # Get the axes that are not the output channel.
+        output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type)
         axis_not_output_channel = list(range(len(kernel.shape)))
-        axis_not_output_channel.remove(linear_node.channel_axis.output)
+        axis_not_output_channel.remove(output_channel_index)
 
         # Special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters.
-        if linear_node.channel_axis.output == linear_node.channel_axis.input:
+        if output_channel_index == input_channel_index:
             axis_not_output_channel.remove(3)  # 3 is the depth multiplier index.
 
         activation_bias_correction_term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
@@ -151,20 +150,21 @@ def compute_activation_bias_correction(graph: Graph,
 
 def compute_activation_bias_correction_of_graph(graph: Graph,
                                                 quant_config: QuantizationConfig,
+                                                fw_info: FrameworkInfo,
                                                 fw_impl: FrameworkImplementation,
                                                 activation_bias_correction_node_matchers: Callable,
-                                                kernel_size: str,
-                                                get_activation_quantization_fn_factory: Callable) -> Graph:
+                                                kernel_size: str) -> Graph:
     """
     Compute the activation bias correction term for the graph.
 
     Args:
         graph: Graph with nodes to compute the activation bias correction.
        quant_config: QuantizationConfig of how the model should be quantized.
+        fw_info: Framework info like lists of nodes their kernel should quantized.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         activation_bias_correction_node_matchers: Function to match the layers for activation bias correction.
         kernel_size: The framework specific attribute name of the convolution layer's kernel size.
-        get_activation_quantization_fn_factory: activation quantization functions factory.
+
 
     Returns:
         Graph with activation bias correction term for each relevant node.
@@ -177,9 +177,9 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
         if prev_node is not None:
             graph = compute_activation_bias_correction(graph=graph,
                                                        quant_config=quant_config,
+                                                       fw_info=fw_info,
                                                        fw_impl=fw_impl,
                                                        linear_node=n,
                                                        prev_node=prev_node,
-                                                       kernel_size=kernel_size,
-                                                       get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
+                                                       kernel_size=kernel_size)
     return graph
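
A small worked example of the channel-axis arithmetic above. The (3, 2) output/input mapping is the typical one for a Keras Conv2D kernel of shape (kh, kw, cin, cout); the shapes and the error term are illustrative, not taken from the package.

# Sketch only.
import numpy as np

kernel = np.random.rand(3, 3, 16, 32)             # (kh, kw, cin, cout)
output_channel_index, input_channel_index = 3, 2  # assumed kernel_channels_mapping entry
axis_not_output_channel = list(range(kernel.ndim))
axis_not_output_channel.remove(output_channel_index)  # -> [0, 1, 2]
mean_diff = 0.01                                  # illustrative activation quantization error
# Summing over every non-output axis leaves one correction value per output channel.
term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
assert term.shape == (32,)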

model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py

@@ -18,6 +18,7 @@ from typing import Any
 import numpy as np
 
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common import BaseNode, Graph
 from model_compression_toolkit.core.common.quantization.quantize_node import get_quantized_weights_attr_by_qc
 from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
@@ -25,6 +26,7 @@ from model_compression_toolkit.logger import Logger
 
 
 def compute_bias_correction_of_graph(graph: Graph,
+                                     fw_info: FrameworkInfo,
                                      fw_impl: FrameworkImplementation) -> Graph:
     """
     For each node in a graph, and for each candidate weights quantization configuration,
@@ -33,6 +35,7 @@ def compute_bias_correction_of_graph(graph: Graph,
     Args:
         graph: Graph with nodes to compute the bias correction for
             each node's weights quantization configuration candidates.
+        fw_info: Framework info like lists of nodes their kernel should quantized.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -43,14 +46,25 @@
     for n in graph.nodes:
         # Bias correction is computed based on the quantized kernel, so we need to get the specific kernel attribute
         # name out of all the weights attributes of the node.
-        if n.kernel_attr and n.is_weights_quantization_enabled(n.kernel_attr) and not n.has_positional_weights:
-            _compute_bias_correction_per_candidate_qc(n, n.kernel_attr, graph.get_in_stats_collector(n),
-                                                      fw_impl=fw_impl)
+        if fw_info.is_kernel_op(n.type):
+            kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
+            if n.is_weights_quantization_enabled(kernel_attr):
+                # Bias correction is not applied to layers with constant inputs.
+                if n.has_positional_weights:
+                    for candidate_qc in n.candidates_quantization_cfg:
+                        candidate_qc.weights_quantization_cfg.weights_bias_correction = False
+                else:
+                    _compute_bias_correction_per_candidate_qc(n,
+                                                              kernel_attr,
+                                                              fw_info,
+                                                              graph.get_in_stats_collector(n),
+                                                              fw_impl=fw_impl)
     return graph
 
 
 def _compute_bias_correction_per_candidate_qc(node: BaseNode,
                                               kernel_attr: str,
+                                              fw_info: FrameworkInfo,
                                               node_in_stats_collector: BaseStatsCollector,
                                               fw_impl: FrameworkImplementation):
     """
@@ -60,13 +74,15 @@ def _compute_bias_correction_per_candidate_qc(node: BaseNode,
     Args:
         node: Node to compute the bias correction for its different candidates.
         kernel_attr: The name of the kernel attribute of the node.
+        fw_info: Framework info like lists of nodes their kernel should quantized.
         node_in_stats_collector: Statistics collector of the node for the mean per-channel.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     """
 
     for candidate_qc in node.candidates_quantization_cfg:
-        if not candidate_qc.weights_quantization_cfg.weights_second_moment_correction:
+        if candidate_qc.weights_quantization_cfg.weights_bias_correction and not \
+                candidate_qc.weights_quantization_cfg.weights_second_moment_correction:
 
             quantized_kernel, io_channels_axes = get_quantized_weights_attr_by_qc(kernel_attr,
                                                                                   node,
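
The net effect of the two hunks above is that the decision moves from one structural check into per-candidate flags; a condensed sketch of the resulting gate, with node assumed to be a kernel-op node:

# Sketch only: positional-weight layers opt out by clearing weights_bias_correction
# on every candidate, and the per-candidate pass then re-checks that same flag.
for candidate_qc in node.candidates_quantization_cfg:
    wq = candidate_qc.weights_quantization_cfg
    if wq.weights_bias_correction and not wq.weights_second_moment_correction:
        pass  # bias correction is computed for this candidate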

model_compression_toolkit/core/common/statistics_correction/statistics_correction.py

@@ -32,6 +32,7 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
 
 def statistics_correction_runner(transformed_graph: Graph,
                                  core_config: CoreConfig,
+                                 fw_info: FrameworkInfo,
                                  fw_impl: FrameworkImplementation,
                                  tb_w: TensorboardWriter = None, ) -> Graph:
     """
@@ -40,6 +41,7 @@ def statistics_correction_runner(transformed_graph: Graph,
         transformed_graph: Graph to add statistics correction.
         core_config (CoreConfig): Configuration object containing parameters of how the model should be
             quantized, including mixed precision parameters.
+        fw_info: FrameworkInfo object with information about the specific framework's model.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
 
@@ -56,9 +58,9 @@
     ########################################################
     # Compute bias correction to nodes' config candidates
     ########################################################
-    if core_config.quantization_config.weights_bias_correction:
-        tg_with_bias = compute_bias_correction_of_graph(tg_with_bias,
-                                                        fw_impl)
+    tg_with_bias = compute_bias_correction_of_graph(tg_with_bias,
+                                                    fw_info,
+                                                    fw_impl)
 
     if tb_w is not None:
         tb_w.add_graph(tg_with_bias, 'statistics_computation')
@@ -69,6 +71,7 @@
 def apply_statistics_correction(transformed_graph: Graph,
                                 representative_data_gen: Callable,
                                 core_config: CoreConfig,
+                                fw_info: FrameworkInfo,
                                 fw_impl: FrameworkImplementation,
                                 tb_w: TensorboardWriter = None, ) -> Graph:
     """
@@ -78,6 +81,7 @@ def apply_statistics_correction(transformed_graph: Graph,
         representative_data_gen (Callable): Dataset used for calibration.
         core_config (CoreConfig): Configuration object containing parameters of how the model should be
             quantized, including mixed precision parameters.
+        fw_info: FrameworkInfo object with information about the specific framework's model.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
 
@@ -90,13 +94,14 @@ def apply_statistics_correction(transformed_graph: Graph,
     #############################################
     if core_config.quantization_config.weights_second_moment_correction:
         transformed_graph = apply_second_moment_correction_to_graph(transformed_graph, representative_data_gen,
-                                                                    core_config, fw_impl)
+                                                                    core_config, fw_info, fw_impl)
 
     #############################################
     # Apply Bias Correction
     #############################################
     if core_config.quantization_config.weights_bias_correction:
         transformed_graph = apply_bias_correction_to_graph(transformed_graph,
+                                                           core_config,
                                                            fw_impl=fw_impl)
     if tb_w is not None:
         tb_w.add_graph(transformed_graph, 'after_statistics_correction')
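
Note that compute_bias_correction_of_graph is now called unconditionally (each candidate carries its own weights_bias_correction flag), while applying the correction stays gated on the config. A sketch of the runner-level calls, assuming the surrounding objects exist:

# Sketch only: both entry points now take fw_info.
tg = statistics_correction_runner(transformed_graph, core_config, fw_info, fw_impl, tb_w)
tg = apply_statistics_correction(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)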

model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py

@@ -20,6 +20,7 @@ from typing import Callable
 import numpy as np
 
 from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
     ActivationQuantizationMode
@@ -83,28 +84,30 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
         # If the linear operator is part of a reused group (it is the "base" node, or a reused node),
         # we should skip the substitution.
         if source_node.is_reused():
+            for qc in source_node.candidates_quantization_cfg:
+                qc.weights_quantization_cfg.weights_second_moment_correction = False
             return graph
 
         # We apply only on nodes with folded BatchNormalization.
         if source_node.prior_info.std_output is None or source_node.prior_info.mean_output is None:
+            for qc in source_node.candidates_quantization_cfg:
+                qc.weights_quantization_cfg.weights_second_moment_correction = False
             return graph
 
         # This feature disabled for models with weights quantization method of Power of 2
         for qc in source_node.candidates_quantization_cfg:
             # this feature is relevant only for layers with kernel op
-            if source_node.kernel_attr is None:
+            kernel_attr = graph.fw_info.get_kernel_op_attributes(source_node.type)
+            if kernel_attr is None:
                 Logger.error(f"Can't preform BatchNorm reconstruction on a node {source_node.name} without a kernel op.")
-            if (qc.weights_quantization_cfg.get_attr_config(source_node.kernel_attr).weights_quantization_method
+            if (qc.weights_quantization_cfg.get_attr_config(kernel_attr[0]).weights_quantization_method
                     == QuantizationMethod.POWER_OF_TWO):
                 Logger.warning("Second moment statistics correction feature disabled for models with weights "
                                "quantization method of Power of 2")
+                for qc_inner in source_node.candidates_quantization_cfg:
+                    qc_inner.weights_quantization_cfg.weights_second_moment_correction = False
                 return graph
 
-        # turn on second moment correction flag
-        def set_second_moment_correction(qc):
-            qc.weights_quantization_cfg.weights_second_moment_correction = True
-        source_node.quantization_cfg.update_all(set_second_moment_correction)
-
         eps = self.epsilon_val
 
         original_gamma = source_node.prior_info.std_output
@@ -122,7 +125,7 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
 
         bn_node.prior_info = copy.deepcopy(source_node.prior_info)
 
-        bn_node.quantization_cfg = copy.deepcopy(source_node.quantization_cfg)
+        bn_node.candidates_quantization_cfg = copy.deepcopy(source_node.candidates_quantization_cfg)
 
         for qc in bn_node.candidates_quantization_cfg:
             qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
@@ -137,6 +140,7 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
             # reconstructed node BN attributes need to be quantized and how.
             qc.weights_quantization_cfg.set_attr_config(attr,
                                                         WeightsAttrQuantizationConfig(
+                                                            QuantizationConfig(),
                                                             AttributeQuantizationConfig(
                                                                 enable_weights_quantization=False)))
 
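
One behavioral note recoverable from these hunks: the polarity of the second-moment flag is inverted. The old code switched it on for eligible nodes; the new code assumes it is on and clears it on every bail-out path. A condensed sketch of that pattern:

# Sketch only: each early-return path clears the flag on all candidates.
def _disable_second_moment_correction(node):
    for qc in node.candidates_quantization_cfg:
        qc.weights_quantization_cfg.weights_second_moment_correction = False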

model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py

@@ -157,7 +157,7 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
         graph.remove_node(bn_node)
         graph.remove_node(source_node)
 
-        self._calc_weights_quantization_params(conv_bn, weights_scale)
+        self._calc_weights_quantization_params(conv_bn, weights_scale, graph.fw_info)
 
         assert num_nodes_before_substitution - len(graph.nodes) == 1
         assert num_edges_before_substitution - len(graph.edges) == 1
@@ -165,15 +165,18 @@
 
     def _calc_weights_quantization_params(self,
                                           conv_bn: BaseNode,
-                                          weights_scale: np.ndarray):
+                                          weights_scale: np.ndarray,
+                                          fw_info):
         """
         Update node weights quantization params.
         Args:
             conv_bn: Convolution node to update the weights quantization params.
            weights_scale: Weight scale factor in which to multiply the conv node's weight.
+            fw_info: FrameworkInfo object with information about the specific framework's model
         """
         # Conv layer is ensured to have a kernel attribute
-        conv_bn_kernel_cfg = conv_bn.final_weights_quantization_cfg.get_attr_config(conv_bn.kernel_attr)
+        kernel_attr = fw_info.get_kernel_op_attributes(conv_bn.type)[0]
+        conv_bn_kernel_cfg = conv_bn.final_weights_quantization_cfg.get_attr_config(kernel_attr)
         # In case of SYMMETRIC weight quantization method, we update the threshold by weights_scale
         if conv_bn_kernel_cfg.weights_quantization_method == QuantizationMethod.SYMMETRIC:
             original_threshold = conv_bn_kernel_cfg.weights_quantization_params[THRESHOLD]

model_compression_toolkit/core/common/substitutions/scale_equalization.py

@@ -20,6 +20,8 @@ import scipy
 
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph, BaseNode
+from model_compression_toolkit.defaultdict import DefaultDict
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 
 
@@ -75,6 +77,7 @@ def fixed_second_moment_after_relu(mu: np.ndarray,
 
 def scale_reshaping(scale: np.ndarray,
                     op2d: common.BaseNode,
+                    kernel_channel_mapping: DefaultDict,
                     kernel_str: str,
                     in_channels: bool = True) -> np.ndarray:
     """
@@ -86,6 +89,7 @@
     Args:
         scale: Scale factor to scale the kernel channels by.
         op2d: Node to scale its kernel.
+        kernel_channel_mapping: Mapping from a layer to a tuple of indices of its output/input kernel channels.
         kernel_str: The framework specific attribute name of the convolution layer's weight/kernel.
         in_channels: Kernel's index of input channels.
 
@@ -95,11 +99,12 @@
 
     op_ndims = op2d.get_weights_by_keys(kernel_str).ndim
     reshape_target = np.ones(op_ndims, dtype=np.int32)
-    reshape_target[op2d.channel_axis.input if in_channels else op2d.channel_axis.output] = -1
+    reshape_target[kernel_channel_mapping.get(op2d.type)[int(in_channels)]] = -1
     return np.reshape(scale, reshape_target)
 
 
-def update_linear_nodes(first_op2d_node: BaseNode,
+def update_linear_nodes(fw_info: FrameworkInfo,
+                        first_op2d_node: BaseNode,
                         second_op2d_node: BaseNode,
                         scale_factor: np.ndarray,
                         kernel_str: str,
@@ -111,6 +116,7 @@ def update_linear_nodes(first_op2d_node: BaseNode,
     The scale factor contain a scale value per-channel.
 
     Args:
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
            groups of layers by how they should be quantized, etc.)
         first_op2d_node: Node to multiply its kernel by the scale factor.
         second_op2d_node: Node to divide its kernel by the scale factor.
@@ -119,12 +125,15 @@ def update_linear_nodes(first_op2d_node: BaseNode,
         kernel_str: The framework specific attribute name of the convolution layer's weight/kernel.
 
     """
+
     w2_fixed = second_op2d_node.get_weights_by_keys(kernel_str) / scale_reshaping(scale_factor,
                                                                                   second_op2d_node,
+                                                                                  fw_info.kernel_channels_mapping,
                                                                                   kernel_str)
 
     w1_fixed = first_op2d_node.get_weights_by_keys(kernel_str) * scale_reshaping(scale_factor,
                                                                                  first_op2d_node,
+                                                                                 fw_info.kernel_channels_mapping,
                                                                                  kernel_str,
                                                                                  in_channels=False)
 
@@ -159,7 +168,8 @@ def calculate_scale_correction(first_op2d_node: BaseNode) -> tuple:
     return scale_factor
 
 
-def scale_equalization_lnl(first_op2d_node: BaseNode,
+def scale_equalization_lnl(fw_info: FrameworkInfo,
+                           first_op2d_node: BaseNode,
                            second_op2d_node: BaseNode,
                            kernel_str: str,
                            bias_str: str):
@@ -169,6 +179,7 @@ def scale_equalization_lnl(first_op2d_node: BaseNode,
     follows the activation node to get the same expected output without the scaling.
 
     Args:
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
            groups of layers by how they should be quantized, etc.)
         first_op2d_node: Node to multiply its kernel by the scale factor.
         second_op2d_node: Node to divide its kernel by the scale factor.
@@ -178,7 +189,8 @@ def scale_equalization_lnl(first_op2d_node: BaseNode,
     """
     scale_factor = calculate_scale_correction(first_op2d_node)
 
-    update_linear_nodes(first_op2d_node,
+    update_linear_nodes(fw_info,
+                        first_op2d_node,
                         second_op2d_node,
                         scale_factor,
                         kernel_str,
@@ -194,6 +206,7 @@ class BaseScaleEqualization(common.BaseSubstitution):
 
     def __init__(self,
                  quant_config: QuantizationConfig,
+                 fw_info: FrameworkInfo,
                  matcher_instance,
                  kernel_str: str,
                  bias_str: str):
@@ -201,11 +214,13 @@ class BaseScaleEqualization(common.BaseSubstitution):
         Initialize a ScaleEqualization object.
         Args:
             quant_config: QuantizationConfig containing parameters of how the model should be quantized.
+            fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
                groups of layers by how they should be quantized, etc.)
             matcher_instance: Per substitution matcher instance of type WalkMatcher
         """
 
         self.quant_config = quant_config
+        self.fw_info = fw_info
         self.kernel_str = kernel_str
         self.bias_str = bias_str
         super().__init__(matcher_instance=matcher_instance)
@@ -228,7 +243,8 @@ class BaseScaleEqualization(common.BaseSubstitution):
         act_node = nodes_list[1]
         second_op2d_node = nodes_list[-1]
         if first_op2d_node.prior_info.std_output is not None and act_node.is_activation_quantization_enabled():
-            scale_equalization_lnl(first_op2d_node,
+            scale_equalization_lnl(self.fw_info,
+                                   first_op2d_node,
                                    second_op2d_node,
                                    self.kernel_str,
                                    self.bias_str)
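
Substitutions built on BaseScaleEqualization now receive fw_info at construction time and forward it down to scale_equalization_lnl. A minimal subclass sketch; MATCHER and the 'kernel'/'bias' attribute names are hypothetical stand-ins, not taken from the package:

# Sketch only.
class MyScaleEqualization(BaseScaleEqualization):
    def __init__(self, quant_config, fw_info):
        super().__init__(quant_config=quant_config,
                         fw_info=fw_info,
                         matcher_instance=MATCHER,
                         kernel_str='kernel',
                         bias_str='bias')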