mct-nightly 2.4.0.20250616.616__py3-none-any.whl → 2.4.0.20250618.606__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/RECORD +120 -120
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +2 -5
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
- model_compression_toolkit/core/common/framework_implementation.py +10 -22
- model_compression_toolkit/core/common/framework_info.py +105 -68
- model_compression_toolkit/core/common/graph/base_graph.py +15 -42
- model_compression_toolkit/core/common/graph/base_node.py +103 -42
- model_compression_toolkit/core/common/graph/functional_node.py +18 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
- model_compression_toolkit/core/common/model_collector.py +10 -20
- model_compression_toolkit/core/common/model_validation.py +1 -4
- model_compression_toolkit/core/common/network_editors/actions.py +14 -38
- model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
- model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
- model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
- model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
- model_compression_toolkit/core/common/pruning/pruner.py +1 -6
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
- model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
- model_compression_toolkit/core/graph_prep_runner.py +2 -16
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/keras/default_framework_info.py +138 -87
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
- model_compression_toolkit/core/keras/keras_implementation.py +15 -35
- model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
- model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +0 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/quantization_prep_runner.py +4 -9
- model_compression_toolkit/core/runner.py +5 -15
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
- model_compression_toolkit/gptq/common/gptq_training.py +1 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
- model_compression_toolkit/gptq/keras/graph_info.py +4 -6
- model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/runner.py +1 -7
- model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
- model_compression_toolkit/ptq/runner.py +1 -4
- model_compression_toolkit/qat/common/qat_config.py +2 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
- model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
- model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
- model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
- model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
- model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
- model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/top_level.txt +0 -0
@@ -34,18 +34,16 @@ class MemoryCalculator:
|
|
34
34
|
which is crucial for deploying models on memory-constrained devices or optimizing for computational efficiency.
|
35
35
|
"""
|
36
36
|
|
37
|
-
def __init__(self, graph: Graph,
|
37
|
+
def __init__(self, graph: Graph, fw_impl: PruningFrameworkImplementation):
|
38
38
|
"""
|
39
39
|
Initializes the MemoryCalculator with necessary information about the model's graph,
|
40
40
|
framework-specific details, and pruning implementation.
|
41
41
|
|
42
42
|
Args:
|
43
43
|
graph (Graph): Computational graph of the model.
|
44
|
-
fw_info (FrameworkInfo): Contains framework-specific information.
|
45
44
|
fw_impl (PruningFrameworkImplementation): Implementation details for pruning.
|
46
45
|
"""
|
47
46
|
self.graph = graph
|
48
|
-
self.fw_info = fw_info
|
49
47
|
self.fw_impl = fw_impl
|
50
48
|
|
51
49
|
def get_pruned_graph_memory(self,
|
@@ -204,19 +202,13 @@ class MemoryCalculator:
|
|
204
202
|
if node == section.exit_node:
|
205
203
|
return masks.get(section.entry_node)
|
206
204
|
|
207
|
-
kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)
|
208
|
-
# Ensure only one kernel attribute exists for the given node.
|
209
|
-
if len(kernel_attr) != 1:
|
210
|
-
Logger.critical(f"Expected a single attribute, but found {len(kernel_attr)} attributes for node '{node}'. Ensure the node configuration is correct.")
|
211
|
-
kernel_attr = kernel_attr[0]
|
212
|
-
|
213
205
|
# Retrieve and validate the axis index for the output channels.
|
214
|
-
|
206
|
+
ic_axis = node.channel_axis.input
|
215
207
|
if ic_axis is None or int(ic_axis) != ic_axis:
|
216
208
|
Logger.critical(f"Invalid input channel axis type for node '{node}': expected integer but got '{ic_axis}'.")
|
217
209
|
|
218
210
|
# Get the number of output channels based on the kernel attribute and axis.
|
219
|
-
num_ic = node.get_weights_by_keys(kernel_attr).shape[ic_axis]
|
211
|
+
num_ic = node.get_weights_by_keys(node.kernel_attr).shape[ic_axis]
|
220
212
|
mask = np.ones(num_ic, dtype=bool)
|
221
213
|
return mask
|
222
214
|
|
@@ -289,7 +281,7 @@ class MemoryCalculator:
|
|
289
281
|
int: The total number of parameters in the node after pruning.
|
290
282
|
"""
|
291
283
|
total_params = 0
|
292
|
-
attributes_and_oc_axis = self.fw_impl.attrs_oi_channels_info_for_pruning(node
|
284
|
+
attributes_and_oc_axis = self.fw_impl.attrs_oi_channels_info_for_pruning(node)
|
293
285
|
|
294
286
|
# Iterate over the node's weights and apply pruning based on the masks.
|
295
287
|
for w_attr, w in node.weights.items():
|
@@ -311,7 +303,7 @@ class MemoryCalculator:
|
|
311
303
|
num_oc = np.sum(output_mask)
|
312
304
|
else:
|
313
305
|
# Get the node channel axis from framework info
|
314
|
-
channel_axis =
|
306
|
+
channel_axis = node.out_channel_axis
|
315
307
|
if channel_axis is None:
|
316
308
|
Logger.critical(f"The channel axis is undefined. Please ensure the channel axis is explicitly defined for node {node.type} in the framework info.")
|
317
309
|
|
@@ -27,7 +27,6 @@ from model_compression_toolkit.logger import Logger
|
|
27
27
|
|
28
28
|
def build_pruned_graph(graph: Graph,
|
29
29
|
masks: Dict[BaseNode, np.ndarray],
|
30
|
-
fw_info: FrameworkInfo,
|
31
30
|
fw_impl: FrameworkImplementation) -> Graph:
|
32
31
|
"""
|
33
32
|
Prunes the provided graph according to the given pruning output-channels masks.
|
@@ -35,7 +34,6 @@ def build_pruned_graph(graph: Graph,
|
|
35
34
|
Args:
|
36
35
|
graph: The original computational graph to be pruned.
|
37
36
|
masks: A dictionary mapping each prunable node to its pruning mask.
|
38
|
-
fw_info: Framework-specific information object.
|
39
37
|
fw_impl: Framework-specific implementation object.
|
40
38
|
|
41
39
|
Returns:
|
@@ -66,8 +64,7 @@ def build_pruned_graph(graph: Graph,
|
|
66
64
|
section_mask = PruningSectionMask(entry_node_oc_mask=mask,
|
67
65
|
exit_node_ic_mask=mask)
|
68
66
|
pruning_section.apply_inner_section_mask(section_mask,
|
69
|
-
fw_impl
|
70
|
-
fw_info)
|
67
|
+
fw_impl)
|
71
68
|
|
72
69
|
return graph_to_prune
|
73
70
|
|
@@ -40,7 +40,6 @@ class Pruner:
|
|
40
40
|
"""
|
41
41
|
def __init__(self,
|
42
42
|
float_graph: Graph,
|
43
|
-
fw_info: FrameworkInfo,
|
44
43
|
fw_impl: PruningFrameworkImplementation,
|
45
44
|
target_resource_utilization: ResourceUtilization,
|
46
45
|
representative_data_gen: Callable,
|
@@ -49,7 +48,6 @@ class Pruner:
|
|
49
48
|
"""
|
50
49
|
Args:
|
51
50
|
float_graph (Graph): The floating-point representation of the model's computation graph.
|
52
|
-
fw_info (FrameworkInfo): Contains metadata and helper functions for the framework.
|
53
51
|
fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.
|
54
52
|
target_resource_utilization (ResourceUtilization): The target resource utilization to be achieved after pruning.
|
55
53
|
representative_data_gen (Callable): Generator function for representative dataset used in pruning analysis.
|
@@ -57,7 +55,6 @@ class Pruner:
|
|
57
55
|
target_platform_capabilities (FrameworkQuantizationCapabilities): Object encapsulating the capabilities of the target hardware platform.
|
58
56
|
"""
|
59
57
|
self.float_graph = float_graph
|
60
|
-
self.fw_info = fw_info
|
61
58
|
self.fw_impl = fw_impl
|
62
59
|
self.target_resource_utilization = target_resource_utilization
|
63
60
|
self.representative_data_gen = representative_data_gen
|
@@ -84,7 +81,6 @@ class Pruner:
|
|
84
81
|
# Apply Greedy strategy to compute masks based on importance scores.
|
85
82
|
if self.pruning_config.channels_filtering_strategy == ChannelsFilteringStrategy.GREEDY:
|
86
83
|
mask_calculator = GreedyMaskCalculator(entry_nodes,
|
87
|
-
self.fw_info,
|
88
84
|
self.simd_scores,
|
89
85
|
self.target_resource_utilization,
|
90
86
|
self.float_graph,
|
@@ -99,7 +95,6 @@ class Pruner:
|
|
99
95
|
Logger.info("Start pruning graph...")
|
100
96
|
_pruned_graph = build_pruned_graph(self.float_graph,
|
101
97
|
self.per_oc_mask,
|
102
|
-
self.fw_info,
|
103
98
|
self.fw_impl)
|
104
99
|
return _pruned_graph
|
105
100
|
|
@@ -116,7 +111,7 @@ class Pruner:
|
|
116
111
|
# Retrieve and initialize the importance metric.
|
117
112
|
im = get_importance_metric(self.pruning_config.importance_metric, graph=self.float_graph,
|
118
113
|
representative_data_gen=self.representative_data_gen, fw_impl=self.fw_impl,
|
119
|
-
pruning_config=self.pruning_config
|
114
|
+
pruning_config=self.pruning_config)
|
120
115
|
entry_node_to_simd_score, simd_groups_indices = im.get_entry_node_to_simd_score(entry_nodes)
|
121
116
|
return entry_node_to_simd_score, simd_groups_indices
|
122
117
|
|
@@ -28,15 +28,13 @@ class PruningFrameworkImplementation(FrameworkImplementation):
|
|
28
28
|
@abstractmethod
|
29
29
|
def prune_entry_node(self,
|
30
30
|
node: BaseNode,
|
31
|
-
output_mask: np.ndarray
|
32
|
-
fw_info: FrameworkInfo):
|
31
|
+
output_mask: np.ndarray):
|
33
32
|
"""
|
34
33
|
Abstract method to prune an entry node in the model.
|
35
34
|
|
36
35
|
Args:
|
37
36
|
node: The node to be pruned.
|
38
37
|
output_mask: A numpy array representing the mask to be applied to the output channels.
|
39
|
-
fw_info: Framework-specific information.
|
40
38
|
|
41
39
|
Raises:
|
42
40
|
NotImplemented: If the method is not implemented in the subclass.
|
@@ -48,8 +46,7 @@ class PruningFrameworkImplementation(FrameworkImplementation):
|
|
48
46
|
def prune_intermediate_node(self,
|
49
47
|
node: BaseNode,
|
50
48
|
input_mask: np.ndarray,
|
51
|
-
output_mask: np.ndarray
|
52
|
-
fw_info: FrameworkInfo):
|
49
|
+
output_mask: np.ndarray):
|
53
50
|
"""
|
54
51
|
Abstract method to prune an intermediate node in the model.
|
55
52
|
|
@@ -57,7 +54,6 @@ class PruningFrameworkImplementation(FrameworkImplementation):
|
|
57
54
|
node: The node to be pruned.
|
58
55
|
input_mask: Mask to be applied to the input channels.
|
59
56
|
output_mask: Mask to be applied to the output channels.
|
60
|
-
fw_info: Framework-specific information.
|
61
57
|
|
62
58
|
Raises:
|
63
59
|
NotImplemented: If the method is not implemented in the subclass.
|
@@ -68,15 +64,13 @@ class PruningFrameworkImplementation(FrameworkImplementation):
|
|
68
64
|
@abstractmethod
|
69
65
|
def prune_exit_node(self,
|
70
66
|
node: BaseNode,
|
71
|
-
input_mask: np.ndarray
|
72
|
-
fw_info: FrameworkInfo):
|
67
|
+
input_mask: np.ndarray):
|
73
68
|
"""
|
74
69
|
Abstract method to prune an exit node in the model.
|
75
70
|
|
76
71
|
Args:
|
77
72
|
node: The node to be pruned.
|
78
73
|
input_mask: Mask to be applied to the input channels.
|
79
|
-
fw_info: Framework-specific information.
|
80
74
|
|
81
75
|
Raises:
|
82
76
|
NotImplemented: If the method is not implemented in the subclass.
|
@@ -105,8 +99,7 @@ class PruningFrameworkImplementation(FrameworkImplementation):
|
|
105
99
|
@abstractmethod
|
106
100
|
def is_node_exit_node(self,
|
107
101
|
node: BaseNode,
|
108
|
-
corresponding_entry_node: BaseNode
|
109
|
-
fw_info: FrameworkInfo) -> bool:
|
102
|
+
corresponding_entry_node: BaseNode) -> bool:
|
110
103
|
|
111
104
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
112
105
|
f'framework\'s is_node_exit_node method.') # pragma: no cover
|
@@ -129,7 +122,7 @@ class PruningFrameworkImplementation(FrameworkImplementation):
|
|
129
122
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
130
123
|
f'framework\'s is_node_intermediate_pruning_section method.') # pragma: no cover
|
131
124
|
|
132
|
-
def attrs_oi_channels_info_for_pruning(self, node: BaseNode
|
125
|
+
def attrs_oi_channels_info_for_pruning(self, node: BaseNode) -> Dict[str, Tuple[int, int]]:
|
133
126
|
"""
|
134
127
|
Retrieves the attributes of a given node along with the output/input (OI) channel axis
|
135
128
|
for each attribute used to prune these attributes.
|
@@ -146,7 +139,6 @@ class PruningFrameworkImplementation(FrameworkImplementation):
|
|
146
139
|
|
147
140
|
Args:
|
148
141
|
node (BaseNode): The node from the computational graph.
|
149
|
-
fw_info (FrameworkInfo): Contains framework-specific information and utilities.
|
150
142
|
|
151
143
|
Returns:
|
152
144
|
Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'kernel' or 'bias')
|
@@ -76,34 +76,28 @@ class PruningSection:
|
|
76
76
|
|
77
77
|
def apply_inner_section_mask(self,
|
78
78
|
pruning_section_mask: PruningSectionMask,
|
79
|
-
fw_impl: Any
|
80
|
-
fw_info: FrameworkInfo):
|
79
|
+
fw_impl: Any):
|
81
80
|
"""
|
82
81
|
Apply the provided pruning section mask to all nodes within the pruning section.
|
83
82
|
|
84
83
|
Args:
|
85
84
|
pruning_section_mask (PruningSectionMask): The mask to be applied to the pruning section.
|
86
85
|
fw_impl (PruningFrameworkImplementation): Framework-specific implementation for applying the mask.
|
87
|
-
fw_info (FrameworkInfo): Framework-specific information needed to apply the mask.
|
88
86
|
"""
|
89
87
|
fw_impl.prune_entry_node(node=self.entry_node,
|
90
|
-
output_mask=pruning_section_mask.entry_node_oc_mask
|
91
|
-
fw_info=fw_info)
|
88
|
+
output_mask=pruning_section_mask.entry_node_oc_mask)
|
92
89
|
|
93
90
|
for inter_node in self.intermediate_nodes:
|
94
91
|
fw_impl.prune_intermediate_node(node=inter_node,
|
95
92
|
input_mask=pruning_section_mask.entry_node_oc_mask,
|
96
|
-
output_mask=pruning_section_mask.entry_node_oc_mask
|
97
|
-
fw_info=fw_info)
|
93
|
+
output_mask=pruning_section_mask.entry_node_oc_mask)
|
98
94
|
|
99
95
|
fw_impl.prune_exit_node(self.exit_node,
|
100
|
-
input_mask=pruning_section_mask.exit_node_ic_mask
|
101
|
-
fw_info=fw_info)
|
96
|
+
input_mask=pruning_section_mask.exit_node_ic_mask)
|
102
97
|
|
103
98
|
@staticmethod
|
104
99
|
def has_matching_channel_count(exit_node: BaseNode,
|
105
|
-
corresponding_entry_node: BaseNode
|
106
|
-
fw_info: FrameworkInfo) -> bool:
|
100
|
+
corresponding_entry_node: BaseNode) -> bool:
|
107
101
|
"""
|
108
102
|
Checks if the number of input channels of the exit node matches the number of output channels
|
109
103
|
of its corresponding entry node.
|
@@ -115,13 +109,10 @@ class PruningSection:
|
|
115
109
|
Returns:
|
116
110
|
bool: True if the channel counts match, False otherwise.
|
117
111
|
"""
|
118
|
-
|
119
|
-
entry_output_channel_axis
|
112
|
+
exit_input_channel_axis = exit_node.channel_axis.input
|
113
|
+
entry_output_channel_axis = corresponding_entry_node.channel_axis.output
|
120
114
|
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
exit_input_channels = exit_node.get_weights_by_keys(exit_node_attr).shape[exit_input_channel_axis]
|
125
|
-
entry_output_channels = corresponding_entry_node.get_weights_by_keys(entry_node_attr).shape[entry_output_channel_axis]
|
115
|
+
exit_input_channels = exit_node.get_weights_by_keys(exit_node.kernel_attr).shape[exit_input_channel_axis]
|
116
|
+
entry_output_channels = corresponding_entry_node.get_weights_by_keys(corresponding_entry_node.kernel_attr).shape[entry_output_channel_axis]
|
126
117
|
|
127
118
|
return exit_input_channels == entry_output_channels
|
@@ -15,6 +15,7 @@
|
|
15
15
|
from typing import Callable, List, Tuple
|
16
16
|
|
17
17
|
from model_compression_toolkit.core import QuantizationConfig
|
18
|
+
from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
|
18
19
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
|
19
20
|
NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
|
20
21
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig, \
|
@@ -40,7 +41,7 @@ class CandidateNodeQuantizationConfig(BaseNodeQuantizationConfig):
|
|
40
41
|
activation_quantization_fn: Callable = None,
|
41
42
|
activation_quantization_params_fn: Callable = None,
|
42
43
|
weights_quantization_cfg: NodeWeightsQuantizationConfig = None,
|
43
|
-
weights_channels_axis:
|
44
|
+
weights_channels_axis: ChannelAxisMapping = None,
|
44
45
|
node_attrs_list: List[str] = None):
|
45
46
|
"""
|
46
47
|
|
@@ -34,7 +34,7 @@ def filter_nodes_candidates(graph: Graph):
|
|
34
34
|
"""
|
35
35
|
nodes = list(graph.nodes)
|
36
36
|
for n in nodes:
|
37
|
-
n.candidates_quantization_cfg = filter_node_candidates(node=n
|
37
|
+
n.candidates_quantization_cfg = filter_node_candidates(node=n)
|
38
38
|
|
39
39
|
return graph
|
40
40
|
|
@@ -71,7 +71,7 @@ def _filter_bit_method_dups(candidates: List[CandidateNodeQuantizationConfig],
|
|
71
71
|
return final_candidates
|
72
72
|
|
73
73
|
|
74
|
-
def filter_node_candidates(node: BaseNode
|
74
|
+
def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConfig]:
|
75
75
|
"""
|
76
76
|
Updates a node's candidates configuration list.
|
77
77
|
If the node's weights quantization is disabled (or it only has activations to quantize), then the updated list
|
@@ -81,15 +81,13 @@ def filter_node_candidates(node: BaseNode, fw_info) -> List[CandidateNodeQuantiz
|
|
81
81
|
|
82
82
|
Args:
|
83
83
|
node: Node to set its quantization configurations.
|
84
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
85
84
|
|
86
85
|
"""
|
87
86
|
|
88
87
|
filtered_candidates = copy.deepcopy(node.candidates_quantization_cfg)
|
89
88
|
final_candidates = copy.deepcopy(node.candidates_quantization_cfg)
|
90
|
-
kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
|
91
89
|
|
92
|
-
if (kernel_attr is None or not node.is_weights_quantization_enabled(kernel_attr)) and not node.is_activation_quantization_enabled():
|
90
|
+
if (node.kernel_attr is None or not node.is_weights_quantization_enabled(node.kernel_attr)) and not node.is_activation_quantization_enabled():
|
93
91
|
# If activation quantization is disabled and the node doesn't have a kernel or doesn't quantize the kernel,
|
94
92
|
# but for some reason the node has multiple candidates then replace it with a single dummy candidate with
|
95
93
|
# default bit-width values.
|
@@ -97,8 +95,8 @@ def filter_node_candidates(node: BaseNode, fw_info) -> List[CandidateNodeQuantiz
|
|
97
95
|
single_dummy_candidate.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
|
98
96
|
single_dummy_candidate.activation_quantization_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO
|
99
97
|
|
100
|
-
if kernel_attr is not None:
|
101
|
-
kernel_config = single_dummy_candidate.weights_quantization_cfg.get_attr_config(kernel_attr)
|
98
|
+
if node.kernel_attr is not None:
|
99
|
+
kernel_config = single_dummy_candidate.weights_quantization_cfg.get_attr_config(node.kernel_attr)
|
102
100
|
kernel_config.weights_n_bits = FLOAT_BITWIDTH
|
103
101
|
kernel_config.weights_quantization_method = QuantizationMethod.POWER_OF_TWO
|
104
102
|
|
@@ -116,9 +114,9 @@ def filter_node_candidates(node: BaseNode, fw_info) -> List[CandidateNodeQuantiz
|
|
116
114
|
c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
|
117
115
|
c.activation_quantization_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO
|
118
116
|
|
119
|
-
final_candidates = _filter_bit_method_dups(filtered_candidates, kernel_attr)
|
117
|
+
final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr)
|
120
118
|
|
121
|
-
elif kernel_attr is None or not node.is_weights_quantization_enabled(kernel_attr):
|
119
|
+
elif node.kernel_attr is None or not node.is_weights_quantization_enabled(node.kernel_attr):
|
122
120
|
# TODO:
|
123
121
|
# To allow MP on positional weights we need to modify this to consider all weights not only kernel.
|
124
122
|
# Remove candidates that have duplicated activation candidates for node with disabled weights quantization.
|
@@ -129,11 +127,11 @@ def filter_node_candidates(node: BaseNode, fw_info) -> List[CandidateNodeQuantiz
|
|
129
127
|
and not seen_candidates.add(candidate.activation_quantization_cfg)]
|
130
128
|
|
131
129
|
for c in filtered_candidates:
|
132
|
-
if kernel_attr is not None:
|
133
|
-
kernel_config = c.weights_quantization_cfg.get_attr_config(kernel_attr)
|
130
|
+
if node.kernel_attr is not None:
|
131
|
+
kernel_config = c.weights_quantization_cfg.get_attr_config(node.kernel_attr)
|
134
132
|
kernel_config.weights_n_bits = FLOAT_BITWIDTH
|
135
133
|
kernel_config.weights_quantization_method = QuantizationMethod.POWER_OF_TWO
|
136
134
|
|
137
|
-
final_candidates = _filter_bit_method_dups(filtered_candidates, kernel_attr)
|
135
|
+
final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr)
|
138
136
|
|
139
137
|
return final_candidates
|
@@ -18,6 +18,7 @@ from typing import Callable, Any, List, Tuple, Union, Dict, TYPE_CHECKING
|
|
18
18
|
from enum import Enum, auto
|
19
19
|
import numpy as np
|
20
20
|
|
21
|
+
from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
|
21
22
|
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
|
22
23
|
from model_compression_toolkit.logger import Logger
|
23
24
|
from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
|
@@ -262,7 +263,7 @@ class WeightsAttrQuantizationConfig:
|
|
262
263
|
def __init__(self,
|
263
264
|
qc: QuantizationConfig,
|
264
265
|
weights_attr_cfg: AttributeQuantizationConfig,
|
265
|
-
weights_channels_axis:
|
266
|
+
weights_channels_axis: ChannelAxisMapping = None):
|
266
267
|
"""
|
267
268
|
|
268
269
|
Args:
|
@@ -352,7 +353,7 @@ class WeightsAttrQuantizationConfig:
|
|
352
353
|
p=self.l_p_value,
|
353
354
|
n_bits=self.weights_n_bits,
|
354
355
|
per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
|
355
|
-
channel_axis=self.weights_channels_axis
|
356
|
+
channel_axis=self.weights_channels_axis.output, # output channel axis
|
356
357
|
min_threshold=min_threshold)[0] # Take only first output, the q-params, as axis is already chosen.
|
357
358
|
)
|
358
359
|
else:
|
@@ -400,7 +401,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
400
401
|
"""
|
401
402
|
def __init__(self, qc: QuantizationConfig,
|
402
403
|
op_cfg: OpQuantizationConfig,
|
403
|
-
weights_channels_axis:
|
404
|
+
weights_channels_axis: ChannelAxisMapping,
|
404
405
|
node_attrs_list: List[str]):
|
405
406
|
"""
|
406
407
|
|
@@ -20,6 +20,7 @@ from typing import List, Callable, Generator
|
|
20
20
|
from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
|
21
21
|
from model_compression_toolkit.core import QuantizationErrorMethod
|
22
22
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
23
|
+
from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
|
23
24
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
24
25
|
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
|
25
26
|
HessianScoresGranularity
|
@@ -44,11 +45,8 @@ def _collect_nodes_for_hmse(nodes_list: List[BaseNode], graph: Graph) -> List[Ba
|
|
44
45
|
"""
|
45
46
|
hmse_nodes = []
|
46
47
|
for n in nodes_list:
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
if kernel_attr_name is not None and n.is_weights_quantization_enabled(kernel_attr_name) and \
|
51
|
-
all([c.weights_quantization_cfg.get_attr_config(kernel_attr_name).weights_error_method ==
|
48
|
+
if n.kernel_attr is not None and n.is_weights_quantization_enabled(n.kernel_attr) and \
|
49
|
+
all([c.weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_error_method ==
|
52
50
|
QuantizationErrorMethod.HMSE for c in n.candidates_quantization_cfg]):
|
53
51
|
hmse_nodes.append(n)
|
54
52
|
|
@@ -114,11 +112,7 @@ def calculate_quantization_params(graph: Graph,
|
|
114
112
|
if attr_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
|
115
113
|
# Although we collected nodes for HMSE before running the loop, we keep this verification to
|
116
114
|
# notify the user in case of HMSE configured for node that is not compatible for this method
|
117
|
-
|
118
|
-
if len(kernel_attr_name) > 0:
|
119
|
-
kernel_attr_name = kernel_attr_name[0]
|
120
|
-
|
121
|
-
if kernel_attr_name is None or kernel_attr_name not in attr:
|
115
|
+
if n.kernel_attr is None or n.kernel_attr not in attr:
|
122
116
|
Logger.warning(f"The HMSE error method for parameters selection is only supported for "
|
123
117
|
f"kernel weights attributes. Running parameters selection for attribute "
|
124
118
|
f"'{attr}' in node '{n.name}' with the default MSE error method instead.")
|
@@ -132,7 +126,7 @@ def calculate_quantization_params(graph: Graph,
|
|
132
126
|
node=n,
|
133
127
|
hessian_info_service=hessian_info_service,
|
134
128
|
num_hessian_samples=num_hessian_samples)
|
135
|
-
attr_cfg.weights_channels_axis = (output_channels_axis, attr_cfg.weights_channels_axis
|
129
|
+
attr_cfg.weights_channels_axis = ChannelAxisMapping(output_channels_axis, attr_cfg.weights_channels_axis.input)
|
136
130
|
attr_cfg.set_weights_quantization_param(weights_params)
|
137
131
|
|
138
132
|
if n.is_activation_quantization_enabled():
|
@@ -20,7 +20,7 @@ from model_compression_toolkit.constants import WEIGHTS, ACTIVATION
|
|
20
20
|
from model_compression_toolkit.core.common import BaseNode
|
21
21
|
from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
|
22
22
|
from model_compression_toolkit.logger import Logger
|
23
|
-
from model_compression_toolkit.core.common.framework_info import
|
23
|
+
from model_compression_toolkit.core.common.framework_info import get_fw_info, ChannelAxisMapping
|
24
24
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
25
25
|
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
|
26
26
|
CandidateNodeQuantizationConfig
|
@@ -73,7 +73,6 @@ def set_quantization_configuration_to_graph(graph: Graph,
|
|
73
73
|
set_quantization_configs_to_node(node=n,
|
74
74
|
graph=graph,
|
75
75
|
quant_config=quant_config,
|
76
|
-
fw_info=graph.fw_info,
|
77
76
|
fqc=graph.fqc,
|
78
77
|
mixed_precision_enable=mixed_precision_enable,
|
79
78
|
manual_bit_width_override=manual_bit_width_override)
|
@@ -154,7 +153,6 @@ def filter_node_qco_by_graph(node: BaseNode,
|
|
154
153
|
def set_quantization_configs_to_node(node: BaseNode,
|
155
154
|
graph: Graph,
|
156
155
|
quant_config: QuantizationConfig,
|
157
|
-
fw_info: FrameworkInfo,
|
158
156
|
fqc: FrameworkQuantizationCapabilities,
|
159
157
|
mixed_precision_enable: bool = False,
|
160
158
|
manual_bit_width_override: Optional[Dict] = None):
|
@@ -165,7 +163,6 @@ def set_quantization_configs_to_node(node: BaseNode,
|
|
165
163
|
node (BaseNode): Node to set its quantization configurations.
|
166
164
|
graph (Graph): Model's internal representation graph.
|
167
165
|
quant_config (QuantizationConfig): Quantization configuration to generate the node's configurations from.
|
168
|
-
fw_info (FrameworkInfo): Information needed for quantization about the specific framework.
|
169
166
|
fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to get default OpQuantizationConfig.
|
170
167
|
mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
|
171
168
|
manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
|
@@ -186,10 +183,8 @@ def set_quantization_configs_to_node(node: BaseNode,
|
|
186
183
|
mixed_precision_enable=mixed_precision_enable)
|
187
184
|
|
188
185
|
# Create QC candidates for weights and activation combined
|
189
|
-
weight_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
|
190
186
|
node.candidates_quantization_cfg = _create_node_candidates_qc(quant_config,
|
191
|
-
|
192
|
-
weight_channel_axis,
|
187
|
+
node.channel_axis,
|
193
188
|
node_qc_options_list,
|
194
189
|
base_config,
|
195
190
|
node,
|
@@ -198,7 +193,7 @@ def set_quantization_configs_to_node(node: BaseNode,
|
|
198
193
|
# sorting the candidates by kernel attribute weights number of bits first and then by activation number of bits
|
199
194
|
# (in reversed order). since only kernel attribute is quantized in weights mixed precision,
|
200
195
|
# if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
|
201
|
-
node.sort_node_candidates(
|
196
|
+
node.sort_node_candidates()
|
202
197
|
|
203
198
|
for candidate_qc in node.candidates_quantization_cfg:
|
204
199
|
if candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.QUANT and \
|
@@ -217,14 +212,12 @@ def set_quantization_configs_to_node(node: BaseNode,
|
|
217
212
|
|
218
213
|
|
219
214
|
def create_node_activation_qc(qc: QuantizationConfig,
|
220
|
-
fw_info: FrameworkInfo,
|
221
215
|
op_cfg: OpQuantizationConfig) -> NodeActivationQuantizationConfig:
|
222
216
|
"""
|
223
217
|
Create an activation quantization configuration from a QuantizationConfig object.
|
224
218
|
|
225
219
|
Args:
|
226
220
|
qc: QuantizationConfig to create the node's config from.
|
227
|
-
fw_info: Information about the specific framework the node was created from (e.g., whether or not its
|
228
221
|
weights/activations should be quantized)
|
229
222
|
op_cfg: OpQuantizationConfig with quantizers types to set in node quantization configuration.
|
230
223
|
|
@@ -232,7 +225,7 @@ def create_node_activation_qc(qc: QuantizationConfig,
|
|
232
225
|
Activation quantization configuration of a node.
|
233
226
|
"""
|
234
227
|
|
235
|
-
activation_quantization_fn =
|
228
|
+
activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
|
236
229
|
if activation_quantization_fn is None:
|
237
230
|
Logger.critical('Unknown activation quantization method specified.') # pragma: no cover
|
238
231
|
|
@@ -245,8 +238,7 @@ def create_node_activation_qc(qc: QuantizationConfig,
|
|
245
238
|
|
246
239
|
|
247
240
|
def _create_node_single_candidate_qc(qc: QuantizationConfig,
|
248
|
-
|
249
|
-
weight_channel_axis: Tuple[int, int],
|
241
|
+
weight_channel_axis: ChannelAxisMapping,
|
250
242
|
op_cfg: OpQuantizationConfig,
|
251
243
|
node_attrs_list: List[str]) -> CandidateNodeQuantizationConfig:
|
252
244
|
"""
|
@@ -256,8 +248,6 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig,
|
|
256
248
|
|
257
249
|
Args:
|
258
250
|
qc: QuantizationConfig to create the node's config from.
|
259
|
-
fw_info: Information about the specific framework the node was created from (e.g., whether its
|
260
|
-
weights/activations should be quantized)
|
261
251
|
weight_channel_axis: (Output, Input) channel index of the node's kernel.
|
262
252
|
op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
|
263
253
|
node_attrs_list: A list of the node's weights attributes names.
|
@@ -269,7 +259,7 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig,
|
|
269
259
|
# parameters for weights attributes quantization are set within CandidateNodeQuantizationConfig initialization
|
270
260
|
|
271
261
|
# get parameters for activation quantization
|
272
|
-
activation_quantization_fn =
|
262
|
+
activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
|
273
263
|
if activation_quantization_fn is None:
|
274
264
|
Logger.critical('Unknown activation quantization method specified.') # pragma: no cover
|
275
265
|
|
@@ -293,8 +283,7 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig,
|
|
293
283
|
|
294
284
|
|
295
285
|
def _create_node_candidates_qc(qc: QuantizationConfig,
|
296
|
-
|
297
|
-
weight_channel_axis: Tuple[int, int],
|
286
|
+
weight_channel_axis: ChannelAxisMapping,
|
298
287
|
node_qc_options_list: List[OpQuantizationConfig],
|
299
288
|
base_config: OpQuantizationConfig,
|
300
289
|
node: BaseNode,
|
@@ -304,8 +293,7 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
|
|
304
293
|
|
305
294
|
Args:
|
306
295
|
qc (QuantizationConfig): Quantization configuration the quantization process should follow.
|
307
|
-
|
308
|
-
weight_channel_axis (Tuple[int, int]): (Output, Input) channel index of the node's kernel.
|
296
|
+
weight_channel_axis (ChannelAxisMapping): (Output, Input) channel index of the node's kernel.
|
309
297
|
node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs of node.
|
310
298
|
base_config (OpQuantizationConfig): Base quantization config for node.
|
311
299
|
node (BaseNode): A node to set quantization configuration candidates to.
|
@@ -322,14 +310,12 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
|
|
322
310
|
for op_cfg in node_qc_options_list:
|
323
311
|
candidate_qc = copy.deepcopy(qc)
|
324
312
|
candidates.append(_create_node_single_candidate_qc(candidate_qc,
|
325
|
-
fw_info,
|
326
313
|
weight_channel_axis,
|
327
314
|
op_cfg,
|
328
315
|
node_attrs_list))
|
329
316
|
|
330
317
|
else:
|
331
318
|
candidates.append(_create_node_single_candidate_qc(qc,
|
332
|
-
fw_info,
|
333
319
|
weight_channel_axis,
|
334
320
|
base_config,
|
335
321
|
node_attrs_list))
|
@@ -38,8 +38,7 @@ def apply_activation_bias_correction_to_graph(graph: Graph,
|
|
38
38
|
|
39
39
|
for n in graph.nodes:
|
40
40
|
# Activation bias correction is only relevant for nodes with kernel op
|
41
|
-
|
42
|
-
if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \
|
41
|
+
if core_config.quantization_config.activation_bias_correction and n.kernel_attr is not None and \
|
43
42
|
n.final_activation_quantization_cfg.activation_bias_correction_term is not None:
|
44
43
|
# If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was
|
45
44
|
# calculated during model preparation, and is used now in the node's bias term.
|
model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py
CHANGED
@@ -41,9 +41,8 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
|
|
41
41
|
graph = copy.deepcopy(graph_to_apply_bias_correction)
|
42
42
|
for n in graph.nodes:
|
43
43
|
# bias correction is only relevant for nodes with kernel op
|
44
|
-
|
45
|
-
|
46
|
-
n.is_weights_quantization_enabled(kernel_attr) and \
|
44
|
+
if core_config.quantization_config.weights_bias_correction and n.kernel_attr is not None and \
|
45
|
+
n.is_weights_quantization_enabled(n.kernel_attr) and \
|
47
46
|
not n.final_weights_quantization_cfg.weights_second_moment_correction:
|
48
47
|
# If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg,
|
49
48
|
# a bias correction term was calculated during model preparation, and is used now in the node's bias term.
|