mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -19,6 +19,7 @@ from model_compression_toolkit.core.common.pruning.pruning_framework_implementat
|
|
19
19
|
PruningFrameworkImplementation
|
20
20
|
from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
|
21
21
|
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
|
22
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
22
23
|
from model_compression_toolkit.core.common import BaseNode
|
23
24
|
from model_compression_toolkit.core.pytorch.constants import BIAS, GROUPS, OUT_CHANNELS, OUT_FEATURES, NUM_FEATURES, \
|
24
25
|
IN_CHANNELS, IN_FEATURES, NUM_PARAMETERS
|
@@ -29,10 +30,6 @@ import numpy as np
|
|
29
30
|
from model_compression_toolkit.logger import Logger
|
30
31
|
|
31
32
|
|
32
|
-
# default output channel axis to use when it's not defined in node's fw_info.
|
33
|
-
_default_output_channel_axis = 1
|
34
|
-
|
35
|
-
|
36
33
|
class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplementation):
|
37
34
|
"""
|
38
35
|
Implementation of the PruningFramework for the Pytorch framework. This class provides
|
@@ -42,23 +39,27 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
42
39
|
|
43
40
|
def prune_entry_node(self,
|
44
41
|
node: BaseNode,
|
45
|
-
output_mask: np.ndarray
|
42
|
+
output_mask: np.ndarray,
|
43
|
+
fw_info: FrameworkInfo):
|
46
44
|
"""
|
47
45
|
Prunes the entry node of a model in Pytorch.
|
48
46
|
|
49
47
|
Args:
|
50
48
|
node (BaseNode): The entry node to be pruned.
|
51
49
|
output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
|
50
|
+
fw_info (FrameworkInfo): Framework-specific information object.
|
52
51
|
|
53
52
|
"""
|
54
53
|
return _prune_pytorch_edge_node(node=node,
|
55
54
|
mask=output_mask,
|
55
|
+
fw_info=fw_info,
|
56
56
|
is_exit_node=False)
|
57
57
|
|
58
58
|
def prune_intermediate_node(self,
|
59
59
|
node: BaseNode,
|
60
60
|
input_mask: np.ndarray,
|
61
|
-
output_mask: np.ndarray
|
61
|
+
output_mask: np.ndarray,
|
62
|
+
fw_info: FrameworkInfo):
|
62
63
|
"""
|
63
64
|
Prunes an intermediate node in a Pytorch model.
|
64
65
|
|
@@ -66,11 +67,12 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
66
67
|
node (BaseNode): The intermediate node to be pruned.
|
67
68
|
input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
|
68
69
|
output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
|
70
|
+
fw_info (FrameworkInfo): Framework-specific information object.
|
69
71
|
|
70
72
|
"""
|
71
73
|
# TODO (reuvenp/liord): Address handling of node parameters that can be either a single value across all channels or distinct per channel, e.g., PReLU. Consider developing a structured approach.
|
72
74
|
pruning_en = True
|
73
|
-
_edit_node_input_shape(node, input_mask)
|
75
|
+
_edit_node_input_shape(node, input_mask, fw_info)
|
74
76
|
pruned_parameters = {}
|
75
77
|
mask_bool = output_mask.astype(bool)
|
76
78
|
node.weights = pruned_parameters
|
@@ -89,17 +91,20 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
89
91
|
|
90
92
|
def prune_exit_node(self,
|
91
93
|
node: BaseNode,
|
92
|
-
input_mask: np.ndarray
|
94
|
+
input_mask: np.ndarray,
|
95
|
+
fw_info: FrameworkInfo):
|
93
96
|
"""
|
94
97
|
Prunes the exit node of a model in Pytorch.
|
95
98
|
|
96
99
|
Args:
|
97
100
|
node (BaseNode): The exit node to be pruned.
|
98
101
|
input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
|
102
|
+
fw_info (FrameworkInfo): Framework-specific information object.
|
99
103
|
|
100
104
|
"""
|
101
105
|
return _prune_pytorch_edge_node(node=node,
|
102
106
|
mask=input_mask,
|
107
|
+
fw_info=fw_info,
|
103
108
|
is_exit_node=True)
|
104
109
|
|
105
110
|
def is_node_entry_node(self, node: BaseNode) -> bool:
|
@@ -116,19 +121,22 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
116
121
|
|
117
122
|
def is_node_exit_node(self,
|
118
123
|
node: BaseNode,
|
119
|
-
corresponding_entry_node: BaseNode
|
124
|
+
corresponding_entry_node: BaseNode,
|
125
|
+
fw_info: FrameworkInfo) -> bool:
|
120
126
|
"""
|
121
127
|
Determines whether a node is an exit node in a Pytorch model.
|
122
128
|
|
123
129
|
Args:
|
124
130
|
node (BaseNode): The node to be checked.
|
125
131
|
corresponding_entry_node (BaseNode): The entry node of the pruning section that is checked.
|
132
|
+
fw_info (FrameworkInfo) Framework-specific information object.
|
126
133
|
|
127
134
|
Returns:
|
128
135
|
bool: Boolean indicating if the node is an exit node.
|
129
136
|
"""
|
130
137
|
return _is_pytorch_node_pruning_section_edge(node) and PruningSection.has_matching_channel_count(node,
|
131
|
-
corresponding_entry_node
|
138
|
+
corresponding_entry_node,
|
139
|
+
fw_info)
|
132
140
|
|
133
141
|
def is_node_intermediate_pruning_section(self, node: BaseNode) -> bool:
|
134
142
|
"""
|
@@ -147,7 +155,8 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
147
155
|
torch.nn.Linear]
|
148
156
|
|
149
157
|
def attrs_oi_channels_info_for_pruning(self,
|
150
|
-
node: BaseNode
|
158
|
+
node: BaseNode,
|
159
|
+
fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
|
151
160
|
"""
|
152
161
|
Retrieves the attributes of a given node along with the output/input (OI) channel axis
|
153
162
|
for each attribute used to prune these attributes.
|
@@ -164,6 +173,7 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
164
173
|
|
165
174
|
Args:
|
166
175
|
node (BaseNode): The node from the computational graph.
|
176
|
+
fw_info (FrameworkInfo): Contains framework-specific information and utilities.
|
167
177
|
|
168
178
|
Returns:
|
169
179
|
Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'weight' or 'bias')
|
@@ -171,8 +181,13 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
171
181
|
"""
|
172
182
|
|
173
183
|
attributes_with_axis = {}
|
174
|
-
if node.
|
175
|
-
|
184
|
+
if fw_info.is_kernel_op(node.type):
|
185
|
+
kernel_attributes = fw_info.get_kernel_op_attributes(node.type)
|
186
|
+
if kernel_attributes is None or len(kernel_attributes) == 0:
|
187
|
+
Logger.critical(f"Expected to find kernel attributes but none were identified for node '{node.name}' of type {node.type}.")
|
188
|
+
|
189
|
+
for attr in kernel_attributes:
|
190
|
+
attributes_with_axis[attr] = fw_info.kernel_channels_mapping.get(node.type)
|
176
191
|
|
177
192
|
# Bias is a vector at the length of the number of output channels.
|
178
193
|
# For this reason, input channel axis is irrelevant to the bias attribute.
|
@@ -187,17 +202,13 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
187
202
|
# If the number of float parameters is 1 or less - is the case where
|
188
203
|
# we have one parameter for all channels. For this case, we don't
|
189
204
|
# want to prune the parameter.
|
190
|
-
if node.get_num_parameters()[1] <= 1:
|
205
|
+
if node.get_num_parameters(fw_info)[1] <= 1:
|
191
206
|
attributes_with_axis[attr] = (None, None)
|
192
207
|
else:
|
193
208
|
attributes_with_axis[attr] = (-1, None)
|
194
209
|
|
195
210
|
return attributes_with_axis
|
196
211
|
|
197
|
-
@property
|
198
|
-
def default_output_channel_axis(self):
|
199
|
-
return _default_output_channel_axis
|
200
|
-
|
201
212
|
|
202
213
|
def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
|
203
214
|
"""
|
@@ -223,6 +234,7 @@ def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
|
|
223
234
|
|
224
235
|
def _prune_pytorch_edge_node(node: BaseNode,
|
225
236
|
mask: np.ndarray,
|
237
|
+
fw_info: FrameworkInfo,
|
226
238
|
is_exit_node: bool):
|
227
239
|
"""
|
228
240
|
Prunes the given Pytorch node by applying the mask to the node's weights (weights and biases).
|
@@ -231,18 +243,21 @@ def _prune_pytorch_edge_node(node: BaseNode,
|
|
231
243
|
Args:
|
232
244
|
node (BaseNode): The node to be pruned.
|
233
245
|
mask (np.ndarray): The pruning mask to be applied.
|
246
|
+
fw_info (FrameworkInfo): Framework-specific information object.
|
234
247
|
is_exit_node (bool): A boolean indicating whether the node is an exit node.
|
235
248
|
|
236
249
|
"""
|
237
250
|
|
238
251
|
# Retrieve the kernel attribute and the axes to prune.
|
239
|
-
|
240
|
-
|
252
|
+
kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
|
253
|
+
io_axis = fw_info.kernel_channels_mapping.get(node.type)
|
254
|
+
axis_to_prune = io_axis[int(is_exit_node)]
|
255
|
+
kernel = node.get_weights_by_keys(kernel_attr)
|
241
256
|
# Convert mask to boolean.
|
242
257
|
mask_bool = mask.astype(bool)
|
243
258
|
|
244
259
|
pruned_kernel = kernel.compress(mask_bool, axis=axis_to_prune)
|
245
|
-
node.set_weights_by_keys(name=
|
260
|
+
node.set_weights_by_keys(name=kernel_attr, tensor=pruned_kernel)
|
246
261
|
|
247
262
|
if not is_exit_node and node.framework_attr[BIAS]:
|
248
263
|
# Prune the bias if applicable and it's an entry node.
|
@@ -270,11 +285,12 @@ def _prune_pytorch_edge_node(node: BaseNode,
|
|
270
285
|
Logger.critical(f"{node.type} is currently not supported"
|
271
286
|
f"as an edge node in a pruning section")
|
272
287
|
# Adjust the input shape for the last node in the section.
|
273
|
-
_edit_node_input_shape(node, mask_bool)
|
288
|
+
_edit_node_input_shape(node, mask_bool, fw_info)
|
274
289
|
|
275
290
|
|
276
291
|
def _edit_node_input_shape(node: BaseNode,
|
277
|
-
input_mask: np.ndarray
|
292
|
+
input_mask: np.ndarray,
|
293
|
+
fw_info: FrameworkInfo):
|
278
294
|
"""
|
279
295
|
Adjusts the input shape of a node based on the given input mask.
|
280
296
|
|
@@ -285,13 +301,14 @@ def _edit_node_input_shape(node: BaseNode,
|
|
285
301
|
Args:
|
286
302
|
node (BaseNode): The node whose input shape needs to be adjusted.
|
287
303
|
input_mask (np.ndarray): A binary array where 1 indicates the channel is kept and 0 means pruned.
|
304
|
+
fw_info (FrameworkInfo): Framework-specific information object.
|
288
305
|
"""
|
289
306
|
# Start with the current input shape of the node.
|
290
307
|
new_input_shape = list(node.input_shape)
|
291
308
|
|
292
309
|
# Adjust the last dimension of the shape to match the number of unpruned (retained) channels.
|
293
310
|
# This is done by summing the mask, as each '1' in the mask represents a retained channel.
|
294
|
-
channel_axis =
|
311
|
+
channel_axis = fw_info.out_channel_axis_mapping.get(node.type)
|
295
312
|
new_input_shape[0][channel_axis] = int(np.sum(input_mask))
|
296
313
|
|
297
314
|
# Update the node's input shape with the new dimensions.
|
@@ -26,7 +26,7 @@ from torch.nn import Module, Sigmoid, Softmax
|
|
26
26
|
|
27
27
|
import model_compression_toolkit.core.pytorch.constants as pytorch_constants
|
28
28
|
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
|
29
|
-
from model_compression_toolkit.core import QuantizationConfig, CoreConfig
|
29
|
+
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig
|
30
30
|
from model_compression_toolkit.core import common
|
31
31
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
32
32
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
@@ -37,6 +37,7 @@ from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
|
|
37
37
|
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_kl_divergence, compute_cs
|
38
38
|
from model_compression_toolkit.core.pytorch.back2framework import get_pytorch_model_builder
|
39
39
|
from model_compression_toolkit.core.pytorch.data_util import data_gen_to_dataloader
|
40
|
+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
40
41
|
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_folding import \
|
41
42
|
pytorch_batchnorm_folding, pytorch_batchnorm_forward_folding
|
42
43
|
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_reconstruction import \
|
@@ -177,6 +178,7 @@ class PytorchImplementation(FrameworkImplementation):
|
|
177
178
|
graph: Graph,
|
178
179
|
mode: ModelBuilderMode,
|
179
180
|
append2output: List[Any] = None,
|
181
|
+
fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
|
180
182
|
return_float_outputs: bool = False) -> Tuple:
|
181
183
|
"""
|
182
184
|
Build a Pytorch module from a graph.
|
@@ -187,6 +189,7 @@ class PytorchImplementation(FrameworkImplementation):
|
|
187
189
|
graph: Graph to build the module from it.
|
188
190
|
mode: Mode for how to build the module.
|
189
191
|
append2output: List of Nodes to set as the module's outputs.
|
192
|
+
fw_info: FrameworkInfo object with information about the specific framework's module
|
190
193
|
return_float_outputs (bool): whether to return outputs before or after quantization nodes (default)
|
191
194
|
|
192
195
|
Returns:
|
@@ -195,6 +198,7 @@ class PytorchImplementation(FrameworkImplementation):
|
|
195
198
|
pytorch_model_builder = get_pytorch_model_builder(mode)
|
196
199
|
return pytorch_model_builder(graph=graph,
|
197
200
|
append2output=append2output,
|
201
|
+
fw_info=fw_info,
|
198
202
|
return_float_outputs=return_float_outputs).build_model()
|
199
203
|
|
200
204
|
def run_model_inference(self,
|
@@ -228,55 +232,63 @@ class PytorchImplementation(FrameworkImplementation):
|
|
228
232
|
|
229
233
|
def shift_negative_correction(self,
|
230
234
|
graph: Graph,
|
231
|
-
core_config: CoreConfig
|
235
|
+
core_config: CoreConfig,
|
236
|
+
fw_info: FrameworkInfo) -> Graph:
|
232
237
|
"""
|
233
238
|
Apply shift negative correction (SNC) on a graph.
|
234
239
|
|
235
240
|
Args:
|
236
241
|
graph: Graph to apply SNC on.
|
237
242
|
core_config: Quantization configuration.
|
243
|
+
fw_info: FrameworkInfo object with information about the specific framework's module.
|
238
244
|
|
239
245
|
Returns:
|
240
246
|
Graph after SNC.
|
241
247
|
"""
|
242
248
|
return pytorch_apply_shift_negative_correction(graph,
|
243
|
-
core_config
|
249
|
+
core_config,
|
250
|
+
fw_info)
|
244
251
|
|
245
252
|
def compute_activation_bias_correction(self,
|
246
253
|
graph: Graph,
|
247
|
-
quant_config: QuantizationConfig
|
254
|
+
quant_config: QuantizationConfig,
|
255
|
+
fw_info: FrameworkInfo):
|
248
256
|
"""
|
249
257
|
Compute activation bias correction on a graph.
|
250
258
|
|
251
259
|
Args:
|
252
260
|
graph: Graph to apply activation bias correction on.
|
253
261
|
quant_config: QuantizationConfig of how the model should be quantized.
|
262
|
+
fw_info: FrameworkInfo object with information about the specific framework's model.
|
254
263
|
|
255
264
|
Returns:
|
256
265
|
Graph after activation bias correction computing.
|
257
266
|
"""
|
258
267
|
return pytorch_compute_activation_bias_correction_of_graph(graph=graph,
|
259
268
|
quant_config=quant_config,
|
269
|
+
fw_info=fw_info,
|
260
270
|
fw_impl=self)
|
261
271
|
|
262
272
|
def get_substitutions_channel_equalization(self,
|
263
|
-
quant_config: QuantizationConfig
|
273
|
+
quant_config: QuantizationConfig,
|
274
|
+
fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
|
264
275
|
"""
|
265
276
|
Return a list of the framework substitutions used for channel equalization.
|
266
277
|
|
267
278
|
Args:
|
268
279
|
quant_config: QuantizationConfig to determine which substitutions to return.
|
280
|
+
fw_info: FrameworkInfo object with information about the specific framework's model.
|
269
281
|
|
270
282
|
Returns:
|
271
283
|
A list of the framework substitutions used after we collect statistics.
|
272
284
|
"""
|
273
285
|
substitutions_list = []
|
274
286
|
if quant_config.activation_channel_equalization:
|
275
|
-
substitutions_list.extend([ScaleEqualization(quant_config),
|
276
|
-
ScaleEqualizationWithPad(quant_config)])
|
287
|
+
substitutions_list.extend([ScaleEqualization(quant_config, fw_info),
|
288
|
+
ScaleEqualizationWithPad(quant_config, fw_info)])
|
277
289
|
return substitutions_list
|
278
290
|
|
279
|
-
def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
|
291
|
+
def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List[common.BaseSubstitution]:
|
280
292
|
"""
|
281
293
|
|
282
294
|
Returns: A list of the framework substitutions used before we collect the prior information.
|
@@ -287,7 +299,7 @@ class PytorchImplementation(FrameworkImplementation):
|
|
287
299
|
ScaledDotProductDecomposition(),
|
288
300
|
MatMulDecomposition(),
|
289
301
|
TransformFunctionCallMethod(),
|
290
|
-
FunctionalConvSubstitution(),
|
302
|
+
FunctionalConvSubstitution(fw_info),
|
291
303
|
FunctionalBatchNorm(),
|
292
304
|
FunctionalLayerNorm(),
|
293
305
|
FunctionalLinear(),
|
@@ -389,17 +401,20 @@ class PytorchImplementation(FrameworkImplementation):
|
|
389
401
|
|
390
402
|
def get_node_prior_info(self,
|
391
403
|
node: BaseNode,
|
404
|
+
fw_info: FrameworkInfo,
|
392
405
|
graph: Graph) -> NodePriorInfo:
|
393
406
|
"""
|
394
407
|
Get a NodePriorInfo object for a node that represents a Pytorch layer.
|
395
408
|
Args:
|
396
409
|
node: Node to get its prior info.
|
410
|
+
fw_info: Framework specific information needed to create the prior info of the node.
|
397
411
|
graph: Graph to check the next node type.
|
398
412
|
Returns:
|
399
413
|
NodePriorInfo with information about the node.
|
400
414
|
"""
|
401
415
|
|
402
416
|
return create_node_prior_info(node=node,
|
417
|
+
fw_info=fw_info,
|
403
418
|
graph=graph)
|
404
419
|
|
405
420
|
def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
|
@@ -461,19 +476,23 @@ class PytorchImplementation(FrameworkImplementation):
|
|
461
476
|
return node.layer_class not in [argmax, softmax, Softmax]
|
462
477
|
|
463
478
|
def get_node_mac_operations(self,
|
464
|
-
node: BaseNode
|
479
|
+
node: BaseNode,
|
480
|
+
fw_info: FrameworkInfo) -> float:
|
465
481
|
"""
|
466
482
|
Gets the MAC operation count for a given operation.
|
467
483
|
|
468
484
|
Args:
|
469
485
|
node: A graph node that wraps the operation for which the MAC count is computed.
|
486
|
+
fw_info: FrameworkInfo object with information about the Pytorch model.
|
470
487
|
|
471
488
|
Returns: The MAC count of the operation
|
472
489
|
"""
|
473
|
-
|
490
|
+
kernels = fw_info.get_kernel_op_attributes(node.type)
|
491
|
+
if not kernels or kernels[0] is None:
|
474
492
|
return 0
|
475
493
|
|
476
|
-
|
494
|
+
assert len(kernels) == 1
|
495
|
+
kernel_shape = node.get_weights_by_keys(kernels[0]).shape
|
477
496
|
|
478
497
|
if node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d):
|
479
498
|
h, w = node.get_output_shapes_list()[0][-2:]
|
@@ -481,7 +500,8 @@ class PytorchImplementation(FrameworkImplementation):
|
|
481
500
|
|
482
501
|
if node.is_match_type(Linear):
|
483
502
|
# IN * OUT * (all previous dims[:-1])
|
484
|
-
|
503
|
+
_, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
|
504
|
+
return node.get_total_output_params() * kernel_shape[input_channel_axis]
|
485
505
|
|
486
506
|
return 0
|
487
507
|
|
@@ -23,19 +23,23 @@ from model_compression_toolkit.core.pytorch.constants import MOVING_MEAN, MOVING
|
|
23
23
|
|
24
24
|
|
25
25
|
def create_node_prior_info(node: BaseNode,
|
26
|
+
fw_info: FrameworkInfo,
|
26
27
|
graph: Graph):
|
27
28
|
"""
|
28
29
|
Create a NodePriorInfo object for a given node.
|
29
30
|
|
30
31
|
Args:
|
31
32
|
node: Node to create its prior info.
|
33
|
+
fw_info: Information about a specific framework the node was generated from.
|
32
34
|
graph: Graph to check the next node type.
|
33
35
|
|
34
36
|
Returns:
|
35
37
|
NodePriorInfo object with info about the node.
|
36
38
|
"""
|
37
39
|
|
38
|
-
min_output, max_output =
|
40
|
+
min_output, max_output = None, None
|
41
|
+
if fw_info.layers_has_min_max(node.type):
|
42
|
+
min_output, max_output = fw_info.layer_min_max_mapping[node.type]
|
39
43
|
mean_output, std_output = _get_mean_std_outputs(node=node,
|
40
44
|
graph=graph)
|
41
45
|
return NodePriorInfo(min_output=min_output,
|
@@ -27,7 +27,7 @@ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler impor
|
|
27
27
|
from model_compression_toolkit.verify_packages import FOUND_TORCH
|
28
28
|
|
29
29
|
if FOUND_TORCH:
|
30
|
-
from model_compression_toolkit.core.pytorch.default_framework_info import
|
30
|
+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
31
31
|
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
|
32
32
|
from torch.nn import Module
|
33
33
|
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
|
@@ -38,7 +38,6 @@ if FOUND_TORCH:
|
|
38
38
|
PYTORCH_DEFAULT_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
|
39
39
|
|
40
40
|
|
41
|
-
@set_pytorch_info
|
42
41
|
def pytorch_resource_utilization_data(in_model: Module,
|
43
42
|
representative_data_gen: Callable,
|
44
43
|
core_config: CoreConfig = CoreConfig(),
|
@@ -94,6 +93,7 @@ if FOUND_TORCH:
|
|
94
93
|
representative_data_gen,
|
95
94
|
core_config,
|
96
95
|
target_platform_capabilities,
|
96
|
+
DEFAULT_PYTORCH_INFO,
|
97
97
|
fw_impl)
|
98
98
|
|
99
99
|
else:
|
@@ -18,7 +18,7 @@ from torch.nn import Conv2d, Linear, ConvTranspose2d
|
|
18
18
|
from model_compression_toolkit.core import QuantizationConfig
|
19
19
|
from model_compression_toolkit.core.common import Graph
|
20
20
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
21
|
-
from model_compression_toolkit.core.
|
21
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
22
22
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
23
23
|
from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
|
24
24
|
compute_activation_bias_correction_of_graph
|
@@ -33,6 +33,7 @@ def activation_bias_correction_node_matchers():
|
|
33
33
|
|
34
34
|
def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
|
35
35
|
quant_config: QuantizationConfig,
|
36
|
+
fw_info: FrameworkInfo,
|
36
37
|
fw_impl: FrameworkImplementation) -> Graph:
|
37
38
|
"""
|
38
39
|
Compute the activation bias correction term for graph based on a PyTorch model.
|
@@ -40,6 +41,7 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
|
|
40
41
|
Args:
|
41
42
|
graph: Graph with nodes to compute the activation bias correction.
|
42
43
|
quant_config: QuantizationConfig of how the model should be quantized.
|
44
|
+
fw_info: Framework info like lists of nodes their kernel should quantized.
|
43
45
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
44
46
|
|
45
47
|
Returns:
|
@@ -47,9 +49,9 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
|
|
47
49
|
"""
|
48
50
|
graph = compute_activation_bias_correction_of_graph(graph=graph,
|
49
51
|
quant_config=quant_config,
|
52
|
+
fw_info=fw_info,
|
50
53
|
fw_impl=fw_impl,
|
51
54
|
activation_bias_correction_node_matchers=
|
52
55
|
activation_bias_correction_node_matchers,
|
53
|
-
kernel_size=KERNEL_SIZE
|
54
|
-
get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
|
56
|
+
kernel_size=KERNEL_SIZE)
|
55
57
|
return graph
|
@@ -37,6 +37,7 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
|
|
37
37
|
def quantization_preparation_runner(graph: Graph,
|
38
38
|
representative_data_gen: Callable,
|
39
39
|
core_config: CoreConfig,
|
40
|
+
fw_info: FrameworkInfo,
|
40
41
|
fw_impl: FrameworkImplementation,
|
41
42
|
tb_w: TensorboardWriter = None,
|
42
43
|
hessian_info_service: HessianInfoService = None, ) -> Graph:
|
@@ -52,6 +53,8 @@ def quantization_preparation_runner(graph: Graph,
|
|
52
53
|
graph: A graph representation of the model to be quantized.
|
53
54
|
representative_data_gen: Dataset used for calibration.
|
54
55
|
core_config: CoreConfig containing parameters of how the model should be quantized
|
56
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
57
|
+
groups of layers by how they should be quantized, etc.).
|
55
58
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
56
59
|
tb_w: TensorboardWriter object for logging
|
57
60
|
hessian_info_service: HessianInfoService object for retrieving Hessian-based scores.
|
@@ -65,6 +68,7 @@ def quantization_preparation_runner(graph: Graph,
|
|
65
68
|
######################################
|
66
69
|
mi = ModelCollector(graph,
|
67
70
|
fw_impl,
|
71
|
+
fw_info,
|
68
72
|
hessian_info_service,
|
69
73
|
core_config.quantization_config) # Mark points for statistics collection
|
70
74
|
|
@@ -81,14 +85,14 @@ def quantization_preparation_runner(graph: Graph,
|
|
81
85
|
# Notice that not all actions affect at this stage (for example, actions that edit the final configuration as
|
82
86
|
# there are no final configurations at this stage of the optimization). For this reason we edit the graph
|
83
87
|
# again at the end of the optimization process.
|
84
|
-
edit_network_graph(graph, core_config.debug_config.network_editor)
|
88
|
+
edit_network_graph(graph, fw_info, core_config.debug_config.network_editor)
|
85
89
|
|
86
90
|
######################################
|
87
91
|
# Calculate quantization params
|
88
92
|
######################################
|
89
93
|
|
90
|
-
calculate_quantization_params(graph,
|
91
|
-
|
94
|
+
calculate_quantization_params(graph, fw_impl=fw_impl, repr_data_gen_fn=representative_data_gen,
|
95
|
+
hessian_info_service=hessian_info_service)
|
92
96
|
|
93
97
|
if tb_w is not None:
|
94
98
|
tb_w.add_graph(graph, 'thresholds_selection')
|
@@ -105,7 +109,8 @@ def quantization_preparation_runner(graph: Graph,
|
|
105
109
|
######################################
|
106
110
|
if core_config.quantization_config.shift_negative_activation_correction:
|
107
111
|
transformed_graph = fw_impl.shift_negative_correction(transformed_graph,
|
108
|
-
core_config
|
112
|
+
core_config,
|
113
|
+
fw_info)
|
109
114
|
if tb_w is not None:
|
110
115
|
tb_w.add_graph(transformed_graph, 'after_shift_negative_correction')
|
111
116
|
tb_w.add_all_statistics(transformed_graph, 'after_shift_negative_correction')
|
@@ -117,9 +122,9 @@ def quantization_preparation_runner(graph: Graph,
|
|
117
122
|
######################################
|
118
123
|
# Statistics Correction
|
119
124
|
######################################
|
120
|
-
tg_with_bias = statistics_correction_runner(transformed_graph, core_config, fw_impl, tb_w)
|
125
|
+
tg_with_bias = statistics_correction_runner(transformed_graph, core_config, fw_info, fw_impl, tb_w)
|
121
126
|
|
122
127
|
for n in tg_with_bias.nodes:
|
123
128
|
assert n.final_weights_quantization_cfg is None
|
124
129
|
|
125
|
-
return tg_with_bias
|
130
|
+
return tg_with_bias
|
@@ -16,6 +16,7 @@
|
|
16
16
|
import copy
|
17
17
|
from typing import Callable, Any, List, Optional
|
18
18
|
|
19
|
+
from model_compression_toolkit.core.common import FrameworkInfo
|
19
20
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
20
21
|
from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser
|
21
22
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
@@ -45,6 +46,7 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
|
|
45
46
|
def core_runner(in_model: Any,
|
46
47
|
representative_data_gen: Callable,
|
47
48
|
core_config: CoreConfig,
|
49
|
+
fw_info: FrameworkInfo,
|
48
50
|
fw_impl: FrameworkImplementation,
|
49
51
|
fqc: FrameworkQuantizationCapabilities,
|
50
52
|
target_resource_utilization: ResourceUtilization = None,
|
@@ -63,6 +65,7 @@ def core_runner(in_model: Any,
|
|
63
65
|
in_model: Model to quantize.
|
64
66
|
representative_data_gen: Dataset used for calibration.
|
65
67
|
core_config: CoreConfig containing parameters of how the model should be quantized
|
68
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
66
69
|
groups of layers by how they should be quantized, etc.).
|
67
70
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
68
71
|
fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
|
@@ -96,6 +99,7 @@ def core_runner(in_model: Any,
|
|
96
99
|
graph = graph_preparation_runner(in_model,
|
97
100
|
representative_data_gen,
|
98
101
|
core_config.quantization_config,
|
102
|
+
fw_info,
|
99
103
|
fw_impl,
|
100
104
|
fqc,
|
101
105
|
core_config.bit_width_config,
|
@@ -108,6 +112,7 @@ def core_runner(in_model: Any,
|
|
108
112
|
tg = quantization_preparation_runner(graph=graph,
|
109
113
|
representative_data_gen=representative_data_gen,
|
110
114
|
core_config=core_config,
|
115
|
+
fw_info=fw_info,
|
111
116
|
fw_impl=fw_impl,
|
112
117
|
tb_w=tb_w,
|
113
118
|
hessian_info_service=hessian_info_service)
|
@@ -118,8 +123,9 @@ def core_runner(in_model: Any,
|
|
118
123
|
if core_config.is_mixed_precision_enabled:
|
119
124
|
if core_config.mixed_precision_config.configuration_overwrite is None:
|
120
125
|
|
121
|
-
filter_candidates_for_mixed_precision(graph, target_resource_utilization)
|
126
|
+
filter_candidates_for_mixed_precision(graph, target_resource_utilization, fw_info, fqc)
|
122
127
|
bit_widths_config = search_bit_width(tg,
|
128
|
+
fw_info,
|
123
129
|
fw_impl,
|
124
130
|
target_resource_utilization,
|
125
131
|
core_config.mixed_precision_config,
|
@@ -147,20 +153,22 @@ def core_runner(in_model: Any,
|
|
147
153
|
######################################
|
148
154
|
if core_config.quantization_config.activation_bias_correction:
|
149
155
|
tg = fw_impl.compute_activation_bias_correction(graph=tg,
|
150
|
-
quant_config=core_config.quantization_config
|
156
|
+
quant_config=core_config.quantization_config,
|
157
|
+
fw_info=fw_info)
|
151
158
|
|
152
159
|
# Edit the graph again after finalizing the configurations.
|
153
160
|
# This is since some actions regard the final configuration and should be edited.
|
154
|
-
edit_network_graph(tg, core_config.debug_config.network_editor)
|
161
|
+
edit_network_graph(tg, fw_info, core_config.debug_config.network_editor)
|
155
162
|
|
156
163
|
_set_final_resource_utilization(graph=tg,
|
157
164
|
final_bit_widths_config=bit_widths_config,
|
158
165
|
target_resource_utilization=target_resource_utilization,
|
166
|
+
fw_info=fw_info,
|
159
167
|
fw_impl=fw_impl)
|
160
168
|
|
161
169
|
if core_config.is_mixed_precision_enabled:
|
162
170
|
# Retrieve lists of tuples (node, node's final weights/activation bitwidth)
|
163
|
-
weights_conf_nodes_bitwidth = tg.get_final_weights_config()
|
171
|
+
weights_conf_nodes_bitwidth = tg.get_final_weights_config(fw_info)
|
164
172
|
activation_conf_nodes_bitwidth = tg.get_final_activation_config()
|
165
173
|
|
166
174
|
if len(weights_conf_nodes_bitwidth) > 0:
|
@@ -192,6 +200,7 @@ def core_runner(in_model: Any,
|
|
192
200
|
def _set_final_resource_utilization(graph: Graph,
|
193
201
|
final_bit_widths_config: List[int],
|
194
202
|
target_resource_utilization: Optional[ResourceUtilization],
|
203
|
+
fw_info: FrameworkInfo,
|
195
204
|
fw_impl: FrameworkImplementation):
|
196
205
|
"""
|
197
206
|
Computing the resource utilization of the model according to the final bit-width configuration,
|
@@ -201,13 +210,14 @@ def _set_final_resource_utilization(graph: Graph,
|
|
201
210
|
graph: Graph to compute the resource utilization for.
|
202
211
|
final_bit_widths_config: The final bit-width configuration to quantize the model accordingly.
|
203
212
|
target_resource_utilization: Requested target resource utilization if relevant.
|
213
|
+
fw_info: A FrameworkInfo object.
|
204
214
|
fw_impl: FrameworkImplementation object with specific framework methods implementation.
|
205
215
|
|
206
216
|
"""
|
207
217
|
ru_targets = target_resource_utilization.get_restricted_targets() if target_resource_utilization else None
|
208
218
|
final_ru = None
|
209
219
|
if ru_targets:
|
210
|
-
ru_calculator = ResourceUtilizationCalculator(graph, fw_impl)
|
220
|
+
ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
|
211
221
|
w_qcs = {n.name: n.final_weights_quantization_cfg for n in graph.nodes}
|
212
222
|
a_qcs = {n.name: n.final_activation_quantization_cfg for n in graph.nodes}
|
213
223
|
final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused,
|