mct-nightly 2.4.0.20250924.535__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -28,13 +28,15 @@ class PruningFrameworkImplementation(FrameworkImplementation):
28
28
  @abstractmethod
29
29
  def prune_entry_node(self,
30
30
  node: BaseNode,
31
- output_mask: np.ndarray):
31
+ output_mask: np.ndarray,
32
+ fw_info: FrameworkInfo):
32
33
  """
33
34
  Abstract method to prune an entry node in the model.
34
35
 
35
36
  Args:
36
37
  node: The node to be pruned.
37
38
  output_mask: A numpy array representing the mask to be applied to the output channels.
39
+ fw_info: Framework-specific information.
38
40
 
39
41
  Raises:
40
42
  NotImplemented: If the method is not implemented in the subclass.
@@ -46,7 +48,8 @@ class PruningFrameworkImplementation(FrameworkImplementation):
46
48
  def prune_intermediate_node(self,
47
49
  node: BaseNode,
48
50
  input_mask: np.ndarray,
49
- output_mask: np.ndarray):
51
+ output_mask: np.ndarray,
52
+ fw_info: FrameworkInfo):
50
53
  """
51
54
  Abstract method to prune an intermediate node in the model.
52
55
 
@@ -54,6 +57,7 @@ class PruningFrameworkImplementation(FrameworkImplementation):
54
57
  node: The node to be pruned.
55
58
  input_mask: Mask to be applied to the input channels.
56
59
  output_mask: Mask to be applied to the output channels.
60
+ fw_info: Framework-specific information.
57
61
 
58
62
  Raises:
59
63
  NotImplemented: If the method is not implemented in the subclass.
@@ -64,13 +68,15 @@ class PruningFrameworkImplementation(FrameworkImplementation):
64
68
  @abstractmethod
65
69
  def prune_exit_node(self,
66
70
  node: BaseNode,
67
- input_mask: np.ndarray):
71
+ input_mask: np.ndarray,
72
+ fw_info: FrameworkInfo):
68
73
  """
69
74
  Abstract method to prune an exit node in the model.
70
75
 
71
76
  Args:
72
77
  node: The node to be pruned.
73
78
  input_mask: Mask to be applied to the input channels.
79
+ fw_info: Framework-specific information.
74
80
 
75
81
  Raises:
76
82
  NotImplemented: If the method is not implemented in the subclass.
@@ -99,7 +105,8 @@ class PruningFrameworkImplementation(FrameworkImplementation):
99
105
  @abstractmethod
100
106
  def is_node_exit_node(self,
101
107
  node: BaseNode,
102
- corresponding_entry_node: BaseNode) -> bool:
108
+ corresponding_entry_node: BaseNode,
109
+ fw_info: FrameworkInfo) -> bool:
103
110
 
104
111
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
105
112
  f'framework\'s is_node_exit_node method.') # pragma: no cover
@@ -122,7 +129,7 @@ class PruningFrameworkImplementation(FrameworkImplementation):
122
129
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
123
130
  f'framework\'s is_node_intermediate_pruning_section method.') # pragma: no cover
124
131
 
125
- def attrs_oi_channels_info_for_pruning(self, node: BaseNode) -> Dict[str, Tuple[int, int]]:
132
+ def attrs_oi_channels_info_for_pruning(self, node: BaseNode, fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
126
133
  """
127
134
  Retrieves the attributes of a given node along with the output/input (OI) channel axis
128
135
  for each attribute used to prune these attributes.
@@ -139,6 +146,7 @@ class PruningFrameworkImplementation(FrameworkImplementation):
139
146
 
140
147
  Args:
141
148
  node (BaseNode): The node from the computational graph.
149
+ fw_info (FrameworkInfo): Contains framework-specific information and utilities.
142
150
 
143
151
  Returns:
144
152
  Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'kernel' or 'bias')
@@ -76,28 +76,34 @@ class PruningSection:
76
76
 
77
77
  def apply_inner_section_mask(self,
78
78
  pruning_section_mask: PruningSectionMask,
79
- fw_impl: Any):
79
+ fw_impl: Any,
80
+ fw_info: FrameworkInfo):
80
81
  """
81
82
  Apply the provided pruning section mask to all nodes within the pruning section.
82
83
 
83
84
  Args:
84
85
  pruning_section_mask (PruningSectionMask): The mask to be applied to the pruning section.
85
86
  fw_impl (PruningFrameworkImplementation): Framework-specific implementation for applying the mask.
87
+ fw_info (FrameworkInfo): Framework-specific information needed to apply the mask.
86
88
  """
87
89
  fw_impl.prune_entry_node(node=self.entry_node,
88
- output_mask=pruning_section_mask.entry_node_oc_mask)
90
+ output_mask=pruning_section_mask.entry_node_oc_mask,
91
+ fw_info=fw_info)
89
92
 
90
93
  for inter_node in self.intermediate_nodes:
91
94
  fw_impl.prune_intermediate_node(node=inter_node,
92
95
  input_mask=pruning_section_mask.entry_node_oc_mask,
93
- output_mask=pruning_section_mask.entry_node_oc_mask)
96
+ output_mask=pruning_section_mask.entry_node_oc_mask,
97
+ fw_info=fw_info)
94
98
 
95
99
  fw_impl.prune_exit_node(self.exit_node,
96
- input_mask=pruning_section_mask.exit_node_ic_mask)
100
+ input_mask=pruning_section_mask.exit_node_ic_mask,
101
+ fw_info=fw_info)
97
102
 
98
103
  @staticmethod
99
104
  def has_matching_channel_count(exit_node: BaseNode,
100
- corresponding_entry_node: BaseNode) -> bool:
105
+ corresponding_entry_node: BaseNode,
106
+ fw_info: FrameworkInfo) -> bool:
101
107
  """
102
108
  Checks if the number of input channels of the exit node matches the number of output channels
103
109
  of its corresponding entry node.
@@ -109,10 +115,13 @@ class PruningSection:
109
115
  Returns:
110
116
  bool: True if the channel counts match, False otherwise.
111
117
  """
112
- exit_input_channel_axis = exit_node.channel_axis.input
113
- entry_output_channel_axis = corresponding_entry_node.channel_axis.output
118
+ _, exit_input_channel_axis = fw_info.kernel_channels_mapping.get(exit_node.type)
119
+ entry_output_channel_axis, _ = fw_info.kernel_channels_mapping.get(corresponding_entry_node.type)
114
120
 
115
- exit_input_channels = exit_node.get_weights_by_keys(exit_node.kernel_attr).shape[exit_input_channel_axis]
116
- entry_output_channels = corresponding_entry_node.get_weights_by_keys(corresponding_entry_node.kernel_attr).shape[entry_output_channel_axis]
121
+ exit_node_attr = fw_info.get_kernel_op_attributes(exit_node.type)[0]
122
+ entry_node_attr = fw_info.get_kernel_op_attributes(corresponding_entry_node.type)[0]
123
+
124
+ exit_input_channels = exit_node.get_weights_by_keys(exit_node_attr).shape[exit_input_channel_axis]
125
+ entry_output_channels = corresponding_entry_node.get_weights_by_keys(entry_node_attr).shape[entry_output_channel_axis]
117
126
 
118
127
  return exit_input_channels == entry_output_channels
@@ -19,8 +19,8 @@ from model_compression_toolkit.core.common import Graph
19
19
  from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher
20
20
  from model_compression_toolkit.logger import Logger
21
21
 
22
- from model_compression_toolkit.core.common.graph.base_node import WeightAttrT, BaseNode
23
- from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
22
+ from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
23
+ from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR
24
24
 
25
25
 
26
26
  @dataclass
@@ -95,7 +95,7 @@ class BitWidthConfig:
95
95
  for attr, bit_width, filter in zip (attrs, bit_widths, filters):
96
96
  self.manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection(filter, bit_width, attr)]
97
97
 
98
- def get_nodes_activation_bit_widths(self, graph: Graph) -> Dict[BaseNode, int]:
98
+ def get_nodes_to_manipulate_activation_bit_widths(self, graph: Graph) -> Dict:
99
99
  """
100
100
  Retrieve nodes from the graph that need their bit-widths for activation changed according to the manual bit-width selections.
101
101
 
@@ -108,7 +108,7 @@ class BitWidthConfig:
108
108
  activation_nodes_to_change_bit_width = self._construct_node_to_new_activation_bit_mapping(graph)
109
109
  return activation_nodes_to_change_bit_width
110
110
 
111
- def get_nodes_weights_bit_widths(self, graph: Graph) -> Dict[BaseNode, Dict[str, int]]:
111
+ def get_nodes_to_manipulate_weights_bit_widths(self, graph: Graph) -> Dict:
112
112
  """
113
113
  Retrieve nodes from the graph that need their bit-widths for weights changed according to the manual bit-width selections.
114
114
 
@@ -166,7 +166,7 @@ class BitWidthConfig:
166
166
  attrs = BitWidthConfig._expand_to_list_core(filters, attrs)
167
167
  return attrs, bit_widths, filters
168
168
 
169
- def _construct_node_to_new_activation_bit_mapping(self, graph) -> Dict[BaseNode, int]:
169
+ def _construct_node_to_new_activation_bit_mapping(self, graph) -> Dict:
170
170
  """
171
171
  Retrieve nodes from the graph that need their activation bit-widths changed according to the manual bit-width selections.
172
172
 
@@ -192,7 +192,7 @@ class BitWidthConfig:
192
192
  unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
193
193
  return unit_nodes_to_change_bit_width
194
194
 
195
- def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict[BaseNode, Dict[str, int]]:
195
+ def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict:
196
196
  """
197
197
  Retrieve nodes from the graph that need their weights bit-widths changed according to the manual bit-width selections.
198
198
 
@@ -212,7 +212,7 @@ class BitWidthConfig:
212
212
  f"to change their bit width to {manual_bit_width_selection.bit_width}.")
213
213
 
214
214
  for n in filtered_nodes:
215
- attr_to_change_bit_width = {}
215
+ attr_to_change_bit_width = []
216
216
 
217
217
  attrs_str = n.get_node_weights_attributes()
218
218
  if len(attrs_str) == 0:
@@ -225,8 +225,8 @@ class BitWidthConfig:
225
225
  attr.append(attr_str)
226
226
  # this is a positional attribute, so it needs to be handled separately.
227
227
  # Search manual_bit_width_selection's attribute that contain the POS_ATTR string.
228
- elif isinstance(attr_str, int) and POSITIONAL_ATTR in manual_bit_width_selection.attr:
229
- attr.append(POSITIONAL_ATTR)
228
+ elif isinstance(attr_str, int) and POS_ATTR in manual_bit_width_selection.attr:
229
+ attr.append(POS_ATTR)
230
230
  if len(attr) == 0:
231
231
  Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.')
232
232
 
@@ -239,7 +239,7 @@ class BitWidthConfig:
239
239
  f"Node {n} has an existing manual bit width configuration of {manual_bit_width_selection.attr}."
240
240
  f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
241
241
 
242
- attr_to_change_bit_width[manual_bit_width_selection.attr] = manual_bit_width_selection.bit_width
242
+ attr_to_change_bit_width.append([manual_bit_width_selection.bit_width, manual_bit_width_selection.attr])
243
243
  unit_nodes_to_change_bit_width.update({n: attr_to_change_bit_width})
244
244
 
245
245
  return unit_nodes_to_change_bit_width
@@ -12,133 +12,72 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- import copy
16
- from dataclasses import dataclass, InitVar
17
- from typing import Callable, List, Optional
15
+ from typing import Callable, List, Tuple
18
16
 
17
+ from model_compression_toolkit.core import QuantizationConfig
19
18
  from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
20
- NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig, ActivationQuantizationMode
19
+ NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
20
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig, \
21
+ OpQuantizationConfig
22
+ from model_compression_toolkit.logger import Logger
21
23
 
22
24
 
23
- @dataclass(eq=True)
25
+ ##########################################
26
+ # Every node holds a quantization configuration
27
+ # for its weights quantization, and a different quantization
28
+ # configuration for its activation quantization configuration.
29
+ ##########################################
30
+
24
31
  class CandidateNodeQuantizationConfig(BaseNodeQuantizationConfig):
25
32
  """
26
- Candidate quantization configuration for a node.
33
+ Class for representing candidate node configuration, which includes weights and activation configuration combined.
27
34
  """
28
- activation_quantization_cfg: NodeActivationQuantizationConfig
29
- # TODO irena: None is passed in several places, need to check if it's handled properly or it's only passed in cases
30
- # that do not affect anything (my guess is it's the second).
31
- # I think in general it makes more sense to set it to None when there are no weights, and maybe when all weights
32
- # are unquantized, and handle it properly everywhere.
33
- weights_quantization_cfg: Optional[NodeWeightsQuantizationConfig]
34
-
35
-
36
- # TODO irena: currently all code still looks at candidates_quantization_cfg as previously, so this is just an initial
37
- # implementation. For now base config is completely separated from candidates (base config must be equal to one of the
38
- # candidates, but we create a separate copy), and updating in place is allowed. Also we require quantization mode to
39
- # be identical between all configs.
40
- @dataclass
41
- class NodeQuantizationConfig:
42
- # quantization config for single precision
43
- base_quantization_cfg: CandidateNodeQuantizationConfig
44
- # quantization candidate configs for mixed precision
45
- candidates_quantization_cfg: List[CandidateNodeQuantizationConfig]
46
-
47
- validate: InitVar[bool] = True
48
-
49
- def update_all(self, update_fn: Callable[[CandidateNodeQuantizationConfig], None], remove_duplicates: bool = True):
50
- """
51
- Apply update function on the base config and all candidates configs.
52
-
53
- Args:
54
- update_fn: function to apply.
55
- remove_duplicates: remove duplicate candidates.
56
- """
57
- if self.base_quantization_cfg:
58
- update_fn(self.base_quantization_cfg)
59
- for cfg in self.candidates_quantization_cfg:
60
- update_fn(cfg)
61
- if remove_duplicates:
62
- self.remove_duplicates()
63
35
 
64
- def update_activation_quantization_mode(self, mode: ActivationQuantizationMode):
36
+ def __init__(self,
37
+ qc: QuantizationConfig = None,
38
+ op_cfg: OpQuantizationConfig = None,
39
+ activation_quantization_cfg: NodeActivationQuantizationConfig = None,
40
+ activation_quantization_fn: Callable = None,
41
+ activation_quantization_params_fn: Callable = None,
42
+ weights_quantization_cfg: NodeWeightsQuantizationConfig = None,
43
+ weights_channels_axis: Tuple[int, int] = None,
44
+ node_attrs_list: List[str] = None):
65
45
  """
66
- Update activation quantization mode for the base config and all candidates configs.
67
46
 
68
47
  Args:
69
- mode: quantization mode.
70
- """
71
- def fn(c):
72
- c.activation_quantization_cfg.quant_mode = mode
73
-
74
- self.update_all(fn)
75
-
76
- def disable_weights_quantization(self):
77
- """
78
- Disable all weights quantization for the base config and all candidates configs.
79
- """
80
- self.update_all(lambda c: c.weights_quantization_cfg.disable_all_weights_quantization())
81
-
82
- def get_activation_quant_mode(self) -> ActivationQuantizationMode:
83
- """
84
- Retrieve activation quantization mode.
85
-
86
- Returns:
87
- Activation quantization mode.
88
-
89
- Raises:
90
- ValueError if not all candidates contain the same mode.
91
- """
92
- self._validate_consistent_activation_quant_mode()
93
- return self.base_quantization_cfg.activation_quantization_cfg.quant_mode
94
-
95
- def remove_duplicates(self):
96
- """
97
- Remove duplicate candidates. First candidate among duplicates is kept, and the order is preserved.
48
+ qc: QuantizationConfig to create the node's config from.
49
+ op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
50
+ activation_quantization_cfg: An option to pass a NodeActivationQuantizationConfig to create a new config from.
51
+ activation_quantization_fn: Function to use when quantizing the node's activations.
52
+ activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
53
+ weights_quantization_cfg: An option to pass a NodeWeightsQuantizationConfig to create a new config from.
54
+ weights_channels_axis: Axis to quantize a node's weights attribute when quantizing per-channel.
55
+ node_attrs_list: A list of the node's weights attributes names.
98
56
  """
99
- uniq_qcs = []
100
- for qc in self.candidates_quantization_cfg:
101
- if qc not in uniq_qcs:
102
- uniq_qcs.append(qc)
103
- self.candidates_quantization_cfg = uniq_qcs
104
57
 
105
- def __post_init__(self, validate=True):
106
- if validate:
107
- if not any(self.base_quantization_cfg == qc for qc in self.candidates_quantization_cfg):
108
- raise ValueError('Candidates should contain the base config.')
109
- self._validate_consistent_activation_quant_mode()
110
- self._validate_consistent_weights_quant_mode()
111
- # TODO irena
112
- # for now make sure they are separate objects so that one doesnt inadvertently modify the other
113
- if any(self.base_quantization_cfg is qc for qc in self.candidates_quantization_cfg):
114
- self.base_quantization_cfg = copy.deepcopy(self.base_quantization_cfg)
58
+ if activation_quantization_cfg is not None:
59
+ self.activation_quantization_cfg = activation_quantization_cfg
60
+ else:
61
+ if any(v is None for v in (qc, op_cfg, activation_quantization_fn, activation_quantization_params_fn)): # pragma: no cover
62
+ Logger.critical(
63
+ "Missing required arguments to initialize a node activation quantization configuration. "
64
+ "Ensure QuantizationConfig, OpQuantizationConfig, activation quantization function, "
65
+ "and parameters function are provided.")
66
+ self.activation_quantization_cfg = (
67
+ NodeActivationQuantizationConfig(qc=qc,
68
+ op_cfg=op_cfg,
69
+ activation_quantization_fn=activation_quantization_fn,
70
+ activation_quantization_params_fn=activation_quantization_params_fn))
71
+
72
+ if weights_quantization_cfg is not None:
73
+ self.weights_quantization_cfg = weights_quantization_cfg
74
+ elif all(v is not None for v in (qc, op_cfg, node_attrs_list)):
75
+ self.weights_quantization_cfg = NodeWeightsQuantizationConfig(qc=qc,
76
+ op_cfg=op_cfg,
77
+ weights_channels_axis=weights_channels_axis,
78
+ node_attrs_list=node_attrs_list)
79
+ else:
80
+ self.weights_quantization_cfg = None
81
+ Logger.debug("Setting weights quantization config as None during CandidateNodeQuantizationConfig creation."
82
+ "Notice, this should happen only for FLN nodes.")
115
83
 
116
- def _validate_consistent_activation_quant_mode(self):
117
- """
118
- Validate that base config and all candidates configs contain identical activation quantization mode.
119
-
120
- Raises:
121
- ValueError if activation quantization mode is not consistent.
122
- """
123
- activation_quant_mode = self.base_quantization_cfg.activation_quantization_cfg.quant_mode
124
- if any(qc.activation_quantization_cfg.quant_mode != activation_quant_mode
125
- for qc in self.candidates_quantization_cfg):
126
- raise ValueError('Quantization candidates with different quantization modes are not currently supported.')
127
-
128
- def _validate_consistent_weights_quant_mode(self):
129
- """
130
- Validate that base config and all candidates configs contain identical weights quantization mode per attribute,
131
- i.e. quantization for each attribute should either be enabled in all configs, or disabled in all configs.
132
-
133
- Raises:
134
- ValueError if weights quantization is not consistent.
135
- """
136
- def get_weights_mode(qc):
137
- # in graph fuser weights_quantization_cfg is set to None
138
- if qc.weights_quantization_cfg is None:
139
- return None
140
- return {attr: attr_cfg.enable_weights_quantization for attr, attr_cfg
141
- in qc.weights_quantization_cfg.get_all_weight_attrs_configs().items()}
142
- if any(get_weights_mode(self.base_quantization_cfg) != get_weights_mode(qc)
143
- for qc in self.candidates_quantization_cfg):
144
- raise ValueError('Quantization candidates with different quantization modes are not currently supported.')
@@ -21,6 +21,7 @@ from model_compression_toolkit.constants import FLOAT_BITWIDTH
21
21
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
22
22
  CandidateNodeQuantizationConfig
23
23
 
24
+
24
25
  def filter_nodes_candidates(graph: Graph):
25
26
  """
26
27
  Filters the graph's nodes candidates configuration list.
@@ -33,7 +34,7 @@ def filter_nodes_candidates(graph: Graph):
33
34
  """
34
35
  nodes = list(graph.nodes)
35
36
  for n in nodes:
36
- n.quantization_cfg.candidates_quantization_cfg = filter_node_candidates(node=n)
37
+ n.candidates_quantization_cfg = filter_node_candidates(node=n, fw_info=graph.fw_info)
37
38
 
38
39
  return graph
39
40
 
@@ -70,7 +71,7 @@ def _filter_bit_method_dups(candidates: List[CandidateNodeQuantizationConfig],
70
71
  return final_candidates
71
72
 
72
73
 
73
- def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConfig]:
74
+ def filter_node_candidates(node: BaseNode, fw_info) -> List[CandidateNodeQuantizationConfig]:
74
75
  """
75
76
  Updates a node's candidates configuration list.
76
77
  If the node's weights quantization is disabled (or it only has activations to quantize), then the updated list
@@ -80,13 +81,15 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf
80
81
 
81
82
  Args:
82
83
  node: Node to set its quantization configurations.
84
+ fw_info: FrameworkInfo object with information about the specific framework's model.
83
85
 
84
86
  """
85
87
 
86
88
  filtered_candidates = copy.deepcopy(node.candidates_quantization_cfg)
87
89
  final_candidates = copy.deepcopy(node.candidates_quantization_cfg)
90
+ kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
88
91
 
89
- if (node.kernel_attr is None or not node.is_weights_quantization_enabled(node.kernel_attr)) and node.is_no_quantization():
92
+ if (kernel_attr is None or not node.is_weights_quantization_enabled(kernel_attr)) and not node.is_activation_quantization_enabled():
90
93
  # If activation quantization is disabled and the node doesn't have a kernel or doesn't quantize the kernel,
91
94
  # but for some reason the node has multiple candidates then replace it with a single dummy candidate with
92
95
  # default bit-width values.
@@ -94,17 +97,16 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf
94
97
  single_dummy_candidate.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
95
98
  single_dummy_candidate.activation_quantization_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO
96
99
 
97
- if node.kernel_attr is not None:
98
- kernel_config = single_dummy_candidate.weights_quantization_cfg.get_attr_config(node.kernel_attr)
100
+ if kernel_attr is not None:
101
+ kernel_config = single_dummy_candidate.weights_quantization_cfg.get_attr_config(kernel_attr)
99
102
  kernel_config.weights_n_bits = FLOAT_BITWIDTH
100
103
  kernel_config.weights_quantization_method = QuantizationMethod.POWER_OF_TWO
101
104
 
102
105
  final_candidates = [single_dummy_candidate]
103
106
 
104
- elif node.is_no_quantization():
107
+ elif not node.is_activation_quantization_enabled():
105
108
  # Remove candidates that have duplicated weights candidates for node with disabled activation quantization.
106
109
  # Replacing the activation n_bits in the remained configurations with default value to prevent confusion.
107
- # Set the config of the non-quantized FLN node to POWER_OF_TWO.
108
110
  seen_candidates = set()
109
111
  filtered_candidates = [candidate for candidate in filtered_candidates if
110
112
  candidate.weights_quantization_cfg not in seen_candidates
@@ -114,17 +116,9 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf
114
116
  c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
115
117
  c.activation_quantization_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO
116
118
 
117
- final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr)
118
-
119
- elif node.is_fln_no_quantization() or node.is_fln_quantization():
120
- # Remove candidates that have duplicated weights candidates for node with disabled activation quantization.
121
- seen_candidates = set()
122
- filtered_candidates = [candidate for candidate in filtered_candidates if
123
- candidate.weights_quantization_cfg not in seen_candidates
124
- and not seen_candidates.add(candidate.weights_quantization_cfg)]
125
- final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr)
119
+ final_candidates = _filter_bit_method_dups(filtered_candidates, kernel_attr)
126
120
 
127
- elif node.kernel_attr is None or not node.is_weights_quantization_enabled(node.kernel_attr):
121
+ elif kernel_attr is None or not node.is_weights_quantization_enabled(kernel_attr):
128
122
  # TODO:
129
123
  # To allow MP on positional weights we need to modify this to consider all weights not only kernel.
130
124
  # Remove candidates that have duplicated activation candidates for node with disabled weights quantization.
@@ -135,11 +129,11 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf
135
129
  and not seen_candidates.add(candidate.activation_quantization_cfg)]
136
130
 
137
131
  for c in filtered_candidates:
138
- if node.kernel_attr is not None:
139
- kernel_config = c.weights_quantization_cfg.get_attr_config(node.kernel_attr)
132
+ if kernel_attr is not None:
133
+ kernel_config = c.weights_quantization_cfg.get_attr_config(kernel_attr)
140
134
  kernel_config.weights_n_bits = FLOAT_BITWIDTH
141
135
  kernel_config.weights_quantization_method = QuantizationMethod.POWER_OF_TWO
142
136
 
143
- final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr)
137
+ final_candidates = _filter_bit_method_dups(filtered_candidates, kernel_attr)
144
138
 
145
139
  return final_candidates