mct-nightly 2.4.0.20250924.535-py3-none-any.whl → 2.4.2.20250926.532-py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py:

```diff
@@ -28,13 +28,15 @@ class PruningFrameworkImplementation(FrameworkImplementation):
     @abstractmethod
     def prune_entry_node(self,
                          node: BaseNode,
-                         output_mask: np.ndarray):
+                         output_mask: np.ndarray,
+                         fw_info: FrameworkInfo):
         """
         Abstract method to prune an entry node in the model.
 
         Args:
             node: The node to be pruned.
             output_mask: A numpy array representing the mask to be applied to the output channels.
+            fw_info: Framework-specific information.
 
         Raises:
             NotImplemented: If the method is not implemented in the subclass.
@@ -46,7 +48,8 @@ class PruningFrameworkImplementation(FrameworkImplementation):
     def prune_intermediate_node(self,
                                 node: BaseNode,
                                 input_mask: np.ndarray,
-                                output_mask: np.ndarray):
+                                output_mask: np.ndarray,
+                                fw_info: FrameworkInfo):
         """
         Abstract method to prune an intermediate node in the model.
 
@@ -54,6 +57,7 @@ class PruningFrameworkImplementation(FrameworkImplementation):
             node: The node to be pruned.
             input_mask: Mask to be applied to the input channels.
             output_mask: Mask to be applied to the output channels.
+            fw_info: Framework-specific information.
 
         Raises:
             NotImplemented: If the method is not implemented in the subclass.
@@ -64,13 +68,15 @@ class PruningFrameworkImplementation(FrameworkImplementation):
     @abstractmethod
     def prune_exit_node(self,
                         node: BaseNode,
-                        input_mask: np.ndarray):
+                        input_mask: np.ndarray,
+                        fw_info: FrameworkInfo):
         """
         Abstract method to prune an exit node in the model.
 
         Args:
             node: The node to be pruned.
             input_mask: Mask to be applied to the input channels.
+            fw_info: Framework-specific information.
 
         Raises:
             NotImplemented: If the method is not implemented in the subclass.
@@ -99,7 +105,8 @@ class PruningFrameworkImplementation(FrameworkImplementation):
     @abstractmethod
     def is_node_exit_node(self,
                           node: BaseNode,
-                          corresponding_entry_node: BaseNode) -> bool:
+                          corresponding_entry_node: BaseNode,
+                          fw_info: FrameworkInfo) -> bool:
 
         raise NotImplemented(f'{self.__class__.__name__} have to implement the '
                              f'framework\'s is_node_exit_node method.')  # pragma: no cover
@@ -122,7 +129,7 @@ class PruningFrameworkImplementation(FrameworkImplementation):
         raise NotImplemented(f'{self.__class__.__name__} have to implement the '
                              f'framework\'s is_node_intermediate_pruning_section method.')  # pragma: no cover
 
-    def attrs_oi_channels_info_for_pruning(self, node: BaseNode) -> Dict[str, Tuple[int, int]]:
+    def attrs_oi_channels_info_for_pruning(self, node: BaseNode, fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
         """
         Retrieves the attributes of a given node along with the output/input (OI) channel axis
         for each attribute used to prune these attributes.
@@ -139,6 +146,7 @@ class PruningFrameworkImplementation(FrameworkImplementation):
 
         Args:
             node (BaseNode): The node from the computational graph.
+            fw_info (FrameworkInfo): Contains framework-specific information and utilities.
 
         Returns:
             Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'kernel' or 'bias')
```
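Taken together, these hunks thread a `FrameworkInfo` argument back through every abstract pruning hook. Below is a minimal, self-contained sketch of the call pattern a concrete implementation might follow; `ToyNode`, `ToyFrameworkInfo`, and the boolean-mask semantics are illustrative assumptions, not MCT's own classes.

```python
import numpy as np

# Illustrative stand-ins (not MCT classes) so the sketch runs on its own.
class ToyNode:
    def __init__(self, kernel: np.ndarray):
        self.kernel = kernel

class ToyFrameworkInfo:
    # Assumed shape of fw_info: maps a layer type to the
    # (output_channel_axis, input_channel_axis) of its kernel.
    kernel_channels_mapping = {ToyNode: (0, 1)}

def prune_entry_node(node: ToyNode, output_mask: np.ndarray, fw_info: ToyFrameworkInfo):
    """Drop the output channels that the boolean mask marks as False."""
    out_axis, _ = fw_info.kernel_channels_mapping[type(node)]
    node.kernel = np.compress(output_mask, node.kernel, axis=out_axis)

node = ToyNode(kernel=np.ones((8, 4)))
prune_entry_node(node, np.array([1, 1, 0, 0, 1, 0, 1, 1], dtype=bool), ToyFrameworkInfo())
print(node.kernel.shape)  # -> (5, 4): three masked-out output channels removed
```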
model_compression_toolkit/core/common/pruning/pruning_section.py:

```diff
@@ -76,28 +76,34 @@ class PruningSection:
 
     def apply_inner_section_mask(self,
                                  pruning_section_mask: PruningSectionMask,
-                                 fw_impl: Any):
+                                 fw_impl: Any,
+                                 fw_info: FrameworkInfo):
         """
         Apply the provided pruning section mask to all nodes within the pruning section.
 
         Args:
             pruning_section_mask (PruningSectionMask): The mask to be applied to the pruning section.
             fw_impl (PruningFrameworkImplementation): Framework-specific implementation for applying the mask.
+            fw_info (FrameworkInfo): Framework-specific information needed to apply the mask.
         """
         fw_impl.prune_entry_node(node=self.entry_node,
-                                 output_mask=pruning_section_mask.entry_node_oc_mask)
+                                 output_mask=pruning_section_mask.entry_node_oc_mask,
+                                 fw_info=fw_info)
 
         for inter_node in self.intermediate_nodes:
             fw_impl.prune_intermediate_node(node=inter_node,
                                             input_mask=pruning_section_mask.entry_node_oc_mask,
-                                            output_mask=pruning_section_mask.entry_node_oc_mask)
+                                            output_mask=pruning_section_mask.entry_node_oc_mask,
+                                            fw_info=fw_info)
 
         fw_impl.prune_exit_node(self.exit_node,
-                                input_mask=pruning_section_mask.exit_node_ic_mask)
+                                input_mask=pruning_section_mask.exit_node_ic_mask,
+                                fw_info=fw_info)
 
     @staticmethod
     def has_matching_channel_count(exit_node: BaseNode,
-                                   corresponding_entry_node: BaseNode) -> bool:
+                                   corresponding_entry_node: BaseNode,
+                                   fw_info: FrameworkInfo) -> bool:
         """
         Checks if the number of input channels of the exit node matches the number of output channels
         of its corresponding entry node.
@@ -109,10 +115,13 @@ class PruningSection:
         Returns:
             bool: True if the channel counts match, False otherwise.
         """
-        exit_input_channel_axis = exit_node.
-        entry_output_channel_axis = corresponding_entry_node.
+        _, exit_input_channel_axis = fw_info.kernel_channels_mapping.get(exit_node.type)
+        entry_output_channel_axis, _ = fw_info.kernel_channels_mapping.get(corresponding_entry_node.type)
 
-
-
+        exit_node_attr = fw_info.get_kernel_op_attributes(exit_node.type)[0]
+        entry_node_attr = fw_info.get_kernel_op_attributes(corresponding_entry_node.type)[0]
+
+        exit_input_channels = exit_node.get_weights_by_keys(exit_node_attr).shape[exit_input_channel_axis]
+        entry_output_channels = corresponding_entry_node.get_weights_by_keys(entry_node_attr).shape[entry_output_channel_axis]
 
         return exit_input_channels == entry_output_channels
```
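The restored `has_matching_channel_count` resolves channel axes through `fw_info.kernel_channels_mapping` instead of reading them off the node. A standalone sketch of the same check follows; the `'Conv2D'` entry and the HWIO axis convention are assumptions made for illustration.

```python
import numpy as np

# Assumed (output_channel_axis, input_channel_axis) for an HWIO Conv2D kernel.
kernel_channels_mapping = {'Conv2D': (3, 2)}

entry_kernel = np.zeros((3, 3, 16, 32))  # 16 input channels, 32 output channels
exit_kernel = np.zeros((3, 3, 32, 64))   # 32 input channels, 64 output channels

entry_output_channel_axis, _ = kernel_channels_mapping['Conv2D']
_, exit_input_channel_axis = kernel_channels_mapping['Conv2D']

# The exit node's input-channel count must equal the entry node's output-channel count.
print(exit_kernel.shape[exit_input_channel_axis] == entry_kernel.shape[entry_output_channel_axis])  # True
```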
model_compression_toolkit/core/common/quantization/bit_width_config.py:

```diff
@@ -19,8 +19,8 @@ from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher
 from model_compression_toolkit.logger import Logger
 
-from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
-from model_compression_toolkit.target_platform_capabilities.constants import
+from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
+from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR
 
 
 @dataclass
@@ -95,7 +95,7 @@ class BitWidthConfig:
         for attr, bit_width, filter in zip(attrs, bit_widths, filters):
             self.manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection(filter, bit_width, attr)]
 
-    def
+    def get_nodes_to_manipulate_activation_bit_widths(self, graph: Graph) -> Dict:
         """
         Retrieve nodes from the graph that need their bit-widths for activation changed according to the manual bit-width selections.
 
@@ -108,7 +108,7 @@ class BitWidthConfig:
         activation_nodes_to_change_bit_width = self._construct_node_to_new_activation_bit_mapping(graph)
         return activation_nodes_to_change_bit_width
 
-    def
+    def get_nodes_to_manipulate_weights_bit_widths(self, graph: Graph) -> Dict:
         """
         Retrieve nodes from the graph that need their bit-widths for weights changed according to the manual bit-width selections.
 
@@ -166,7 +166,7 @@ class BitWidthConfig:
         attrs = BitWidthConfig._expand_to_list_core(filters, attrs)
         return attrs, bit_widths, filters
 
-    def _construct_node_to_new_activation_bit_mapping(self, graph) -> Dict
+    def _construct_node_to_new_activation_bit_mapping(self, graph) -> Dict:
         """
         Retrieve nodes from the graph that need their activation bit-widths changed according to the manual bit-width selections.
 
@@ -192,7 +192,7 @@ class BitWidthConfig:
                 unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
         return unit_nodes_to_change_bit_width
 
-    def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict
+    def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict:
         """
         Retrieve nodes from the graph that need their weights bit-widths changed according to the manual bit-width selections.
 
@@ -212,7 +212,7 @@ class BitWidthConfig:
                                 f"to change their bit width to {manual_bit_width_selection.bit_width}.")
 
         for n in filtered_nodes:
-            attr_to_change_bit_width =
+            attr_to_change_bit_width = []
 
             attrs_str = n.get_node_weights_attributes()
             if len(attrs_str) == 0:
@@ -225,8 +225,8 @@ class BitWidthConfig:
                     attr.append(attr_str)
                 # this is a positional attribute, so it needs to be handled separately.
                 # Search manual_bit_width_selection's attribute that contain the POS_ATTR string.
-                elif isinstance(attr_str, int) and
-                    attr.append(
+                elif isinstance(attr_str, int) and POS_ATTR in manual_bit_width_selection.attr:
+                    attr.append(POS_ATTR)
             if len(attr) == 0:
                 Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.')
@@ -239,7 +239,7 @@ class BitWidthConfig:
                     f"Node {n} has an existing manual bit width configuration of {manual_bit_width_selection.attr}."
                     f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
 
-            attr_to_change_bit_width[manual_bit_width_selection.
+            attr_to_change_bit_width.append([manual_bit_width_selection.bit_width, manual_bit_width_selection.attr])
             unit_nodes_to_change_bit_width.update({n: attr_to_change_bit_width})
 
         return unit_nodes_to_change_bit_width
```
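Beyond restoring the truncated method names, the fix in `_construct_node_to_new_weights_bit_mapping` initializes `attr_to_change_bit_width` as a list and appends `[bit_width, attr]` pairs, so a single node can carry several manual weight selections. A standalone sketch of that accumulation; the node name and selections are hypothetical.

```python
# Hypothetical manual selections for a single node: (bit_width, attribute) requests.
manual_selections = [(4, 'kernel'), (8, 'bias')]

unit_nodes_to_change_bit_width = {}
attr_to_change_bit_width = []
for bit_width, attr in manual_selections:
    # Mirrors the fixed line: append a [bit_width, attr] pair per selection.
    attr_to_change_bit_width.append([bit_width, attr])
unit_nodes_to_change_bit_width.update({'conv1': attr_to_change_bit_width})

print(unit_nodes_to_change_bit_width)  # {'conv1': [[4, 'kernel'], [8, 'bias']]}
```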
model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py:

```diff
@@ -12,133 +12,72 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import
-from dataclasses import dataclass, InitVar
-from typing import Callable, List, Optional
+from typing import Callable, List, Tuple
 
+from model_compression_toolkit.core import QuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
-    NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
+    NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig, \
+    OpQuantizationConfig
+from model_compression_toolkit.logger import Logger
 
 
-
+##########################################
+# Every node holds a quantization configuration
+# for its weights quantization, and a different quantization
+# configuration for its activation quantization configuration.
+##########################################
+
 class CandidateNodeQuantizationConfig(BaseNodeQuantizationConfig):
     """
-
+    Class for representing candidate node configuration, which includes weights and activation configuration combined.
     """
-    activation_quantization_cfg: NodeActivationQuantizationConfig
-    # TODO irena: None is passed in several places, need to check if it's handled properly or it's only passed in cases
-    # that do not affect anything (my guess is it's the second).
-    # I think in general it makes more sense to set it to None when there are no weights, and maybe when all weights
-    # are unquantized, and handle it properly everywhere.
-    weights_quantization_cfg: Optional[NodeWeightsQuantizationConfig]
-
-
-# TODO irena: currently all code still looks at candidates_quantization_cfg as previously, so this is just an initial
-# implementation. For now base config is completely separated from candidates (base config must be equal to one of the
-# candidates, but we create a separate copy), and updating in place is allowed. Also we require quantization mode to
-# be identical between all configs.
-@dataclass
-class NodeQuantizationConfig:
-    # quantization config for single precision
-    base_quantization_cfg: CandidateNodeQuantizationConfig
-    # quantization candidate configs for mixed precision
-    candidates_quantization_cfg: List[CandidateNodeQuantizationConfig]
-
-    validate: InitVar[bool] = True
-
-    def update_all(self, update_fn: Callable[[CandidateNodeQuantizationConfig], None], remove_duplicates: bool = True):
-        """
-        Apply update function on the base config and all candidates configs.
-
-        Args:
-            update_fn: function to apply.
-            remove_duplicates: remove duplicate candidates.
-        """
-        if self.base_quantization_cfg:
-            update_fn(self.base_quantization_cfg)
-        for cfg in self.candidates_quantization_cfg:
-            update_fn(cfg)
-        if remove_duplicates:
-            self.remove_duplicates()
 
-    def
+    def __init__(self,
+                 qc: QuantizationConfig = None,
+                 op_cfg: OpQuantizationConfig = None,
+                 activation_quantization_cfg: NodeActivationQuantizationConfig = None,
+                 activation_quantization_fn: Callable = None,
+                 activation_quantization_params_fn: Callable = None,
+                 weights_quantization_cfg: NodeWeightsQuantizationConfig = None,
+                 weights_channels_axis: Tuple[int, int] = None,
+                 node_attrs_list: List[str] = None):
         """
-        Update activation quantization mode for the base config and all candidates configs.
 
         Args:
-
-
-
-
-
-
-
-
-        """
-        Disable all weights quantization for the base config and all candidates configs.
-        """
-        self.update_all(lambda c: c.weights_quantization_cfg.disable_all_weights_quantization())
-
-    def get_activation_quant_mode(self) -> ActivationQuantizationMode:
-        """
-        Retrieve activation quantization mode.
-
-        Returns:
-            Activation quantization mode.
-
-        Raises:
-            ValueError if not all candidates contain the same mode.
-        """
-        self._validate_consistent_activation_quant_mode()
-        return self.base_quantization_cfg.activation_quantization_cfg.quant_mode
-
-    def remove_duplicates(self):
-        """
-        Remove duplicate candidates. First candidate among duplicates is kept, and the order is preserved.
+            qc: QuantizationConfig to create the node's config from.
+            op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
+            activation_quantization_cfg: An option to pass a NodeActivationQuantizationConfig to create a new config from.
+            activation_quantization_fn: Function to use when quantizing the node's activations.
+            activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
+            weights_quantization_cfg: An option to pass a NodeWeightsQuantizationConfig to create a new config from.
+            weights_channels_axis: Axis to quantize a node's weights attribute when quantizing per-channel.
+            node_attrs_list: A list of the node's weights attributes names.
         """
-        uniq_qcs = []
-        for qc in self.candidates_quantization_cfg:
-            if qc not in uniq_qcs:
-                uniq_qcs.append(qc)
-        self.candidates_quantization_cfg = uniq_qcs
 
-
-
-
-
-
-
-
-
-
-
+        if activation_quantization_cfg is not None:
+            self.activation_quantization_cfg = activation_quantization_cfg
+        else:
+            if any(v is None for v in (qc, op_cfg, activation_quantization_fn, activation_quantization_params_fn)):  # pragma: no cover
+                Logger.critical(
+                    "Missing required arguments to initialize a node activation quantization configuration. "
+                    "Ensure QuantizationConfig, OpQuantizationConfig, activation quantization function, "
+                    "and parameters function are provided.")
+            self.activation_quantization_cfg = (
+                NodeActivationQuantizationConfig(qc=qc,
+                                                 op_cfg=op_cfg,
+                                                 activation_quantization_fn=activation_quantization_fn,
+                                                 activation_quantization_params_fn=activation_quantization_params_fn))
+
+        if weights_quantization_cfg is not None:
+            self.weights_quantization_cfg = weights_quantization_cfg
+        elif all(v is not None for v in (qc, op_cfg, node_attrs_list)):
+            self.weights_quantization_cfg = NodeWeightsQuantizationConfig(qc=qc,
+                                                                          op_cfg=op_cfg,
+                                                                          weights_channels_axis=weights_channels_axis,
+                                                                          node_attrs_list=node_attrs_list)
+        else:
+            self.weights_quantization_cfg = None
+            Logger.debug("Setting weights quantization config as None during CandidateNodeQuantizationConfig creation."
+                         "Notice, this should happen only for FLN nodes.")
 
-    def _validate_consistent_activation_quant_mode(self):
-        """
-        Validate that base config and all candidates configs contain identical activation quantization mode.
-
-        Raises:
-            ValueError if activation quantization mode is not consistent.
-        """
-        activation_quant_mode = self.base_quantization_cfg.activation_quantization_cfg.quant_mode
-        if any(qc.activation_quantization_cfg.quant_mode != activation_quant_mode
-               for qc in self.candidates_quantization_cfg):
-            raise ValueError('Quantization candidates with different quantization modes are not currently supported.')
-
-    def _validate_consistent_weights_quant_mode(self):
-        """
-        Validate that base config and all candidates configs contain identical weights quantization mode per attribute,
-        i.e. quantization for each attribute should either be enabled in all configs, or disabled in all configs.
-
-        Raises:
-            ValueError if weights quantization is not consistent.
-        """
-        def get_weights_mode(qc):
-            # in graph fuser weights_quantization_cfg is set to None
-            if qc.weights_quantization_cfg is None:
-                return None
-            return {attr: attr_cfg.enable_weights_quantization for attr, attr_cfg
-                    in qc.weights_quantization_cfg.get_all_weight_attrs_configs().items()}
-        if any(get_weights_mode(self.base_quantization_cfg) != get_weights_mode(qc)
-               for qc in self.candidates_quantization_cfg):
-            raise ValueError('Quantization candidates with different quantization modes are not currently supported.')
```
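This hunk swaps the dataclass-based `NodeQuantizationConfig` for an explicit `__init__` on `CandidateNodeQuantizationConfig`: each sub-config is either taken as given or built from `(qc, op_cfg, ...)`, with weights falling back to `None` (expected only for FLN nodes). A compressed, runnable sketch of that construction logic, using dummy config classes in place of MCT's:

```python
from typing import List, Optional

class ActivationCfg:  # stand-in for NodeActivationQuantizationConfig
    pass

class WeightsCfg:  # stand-in for NodeWeightsQuantizationConfig
    pass

class CandidateConfig:
    def __init__(self,
                 qc=None,
                 op_cfg=None,
                 activation_quantization_cfg: Optional[ActivationCfg] = None,
                 weights_quantization_cfg: Optional[WeightsCfg] = None,
                 node_attrs_list: Optional[List[str]] = None):
        # Use a ready-made activation config if given, otherwise build one.
        if activation_quantization_cfg is not None:
            self.activation_quantization_cfg = activation_quantization_cfg
        elif qc is None or op_cfg is None:
            raise ValueError('Missing required arguments for the activation quantization config.')
        else:
            self.activation_quantization_cfg = ActivationCfg()

        # Same pattern for weights, with None as the documented fallback.
        if weights_quantization_cfg is not None:
            self.weights_quantization_cfg = weights_quantization_cfg
        elif all(v is not None for v in (qc, op_cfg, node_attrs_list)):
            self.weights_quantization_cfg = WeightsCfg()
        else:
            self.weights_quantization_cfg = None  # expected only for FLN nodes

cfg = CandidateConfig(qc=object(), op_cfg=object())  # no node_attrs_list -> weights cfg is None
print(cfg.weights_quantization_cfg)  # None
```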
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py:

```diff
@@ -21,6 +21,7 @@ from model_compression_toolkit.constants import FLOAT_BITWIDTH
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
 
+
 def filter_nodes_candidates(graph: Graph):
     """
     Filters the graph's nodes candidates configuration list.
@@ -33,7 +34,7 @@ def filter_nodes_candidates(graph: Graph):
     """
     nodes = list(graph.nodes)
     for n in nodes:
-        n.
+        n.candidates_quantization_cfg = filter_node_candidates(node=n, fw_info=graph.fw_info)
 
     return graph
 
@@ -70,7 +71,7 @@ def _filter_bit_method_dups(candidates: List[CandidateNodeQuantizationConfig],
     return final_candidates
 
 
-def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConfig]:
+def filter_node_candidates(node: BaseNode, fw_info) -> List[CandidateNodeQuantizationConfig]:
     """
     Updates a node's candidates configuration list.
     If the node's weights quantization is disabled (or it only has activations to quantize), then the updated list
@@ -80,13 +81,15 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf
 
     Args:
         node: Node to set its quantization configurations.
+        fw_info: FrameworkInfo object with information about the specific framework's model.
 
     """
 
     filtered_candidates = copy.deepcopy(node.candidates_quantization_cfg)
     final_candidates = copy.deepcopy(node.candidates_quantization_cfg)
+    kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
 
-    if (
+    if (kernel_attr is None or not node.is_weights_quantization_enabled(kernel_attr)) and not node.is_activation_quantization_enabled():
         # If activation quantization is disabled and the node doesn't have a kernel or doesn't quantize the kernel,
         # but for some reason the node has multiple candidates then replace it with a single dummy candidate with
         # default bit-width values.
@@ -94,17 +97,16 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf
         single_dummy_candidate.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
         single_dummy_candidate.activation_quantization_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO
 
-        if
-            kernel_config = single_dummy_candidate.weights_quantization_cfg.get_attr_config(
+        if kernel_attr is not None:
+            kernel_config = single_dummy_candidate.weights_quantization_cfg.get_attr_config(kernel_attr)
             kernel_config.weights_n_bits = FLOAT_BITWIDTH
             kernel_config.weights_quantization_method = QuantizationMethod.POWER_OF_TWO
 
         final_candidates = [single_dummy_candidate]
 
-    elif node.
+    elif not node.is_activation_quantization_enabled():
         # Remove candidates that have duplicated weights candidates for node with disabled activation quantization.
         # Replacing the activation n_bits in the remained configurations with default value to prevent confusion.
-        # Set the config of the non-quantized FLN node to POWER_OF_TWO.
         seen_candidates = set()
         filtered_candidates = [candidate for candidate in filtered_candidates if
                                candidate.weights_quantization_cfg not in seen_candidates
@@ -114,17 +116,9 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf
             c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
             c.activation_quantization_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO
 
-        final_candidates = _filter_bit_method_dups(filtered_candidates,
-
-    elif node.is_fln_no_quantization() or node.is_fln_quantization():
-        # Remove candidates that have duplicated weights candidates for node with disabled activation quantization.
-        seen_candidates = set()
-        filtered_candidates = [candidate for candidate in filtered_candidates if
-                               candidate.weights_quantization_cfg not in seen_candidates
-                               and not seen_candidates.add(candidate.weights_quantization_cfg)]
-        final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr)
+        final_candidates = _filter_bit_method_dups(filtered_candidates, kernel_attr)
 
-    elif
+    elif kernel_attr is None or not node.is_weights_quantization_enabled(kernel_attr):
         # TODO:
         # To allow MP on positional weights we need to modify this to consider all weights not only kernel.
         # Remove candidates that have duplicated activation candidates for node with disabled weights quantization.
@@ -135,11 +129,11 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf
                                and not seen_candidates.add(candidate.activation_quantization_cfg)]
 
         for c in filtered_candidates:
-            if
-                kernel_config = c.weights_quantization_cfg.get_attr_config(
+            if kernel_attr is not None:
+                kernel_config = c.weights_quantization_cfg.get_attr_config(kernel_attr)
                 kernel_config.weights_n_bits = FLOAT_BITWIDTH
                 kernel_config.weights_quantization_method = QuantizationMethod.POWER_OF_TWO
 
-        final_candidates = _filter_bit_method_dups(filtered_candidates,
+        final_candidates = _filter_bit_method_dups(filtered_candidates, kernel_attr)
 
     return final_candidates
```
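Several branches of `filter_node_candidates` rely on an order-preserving dedup idiom: `set.add` returns `None`, so `not seen.add(c)` always evaluates truthy and only executes once the membership test has passed. Shown standalone with made-up candidate labels:

```python
candidates = ['8w8a', '4w8a', '8w8a', '2w8a', '4w8a']  # dummy candidate labels

seen = set()
# First occurrence wins; order is preserved. `not seen.add(c)` is always True
# (set.add returns None) and runs only when `c not in seen` succeeds.
unique = [c for c in candidates if c not in seen and not seen.add(c)]
print(unique)  # ['8w8a', '4w8a', '2w8a']
```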