mct-nightly 2.4.0.20250924.535__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250924.535.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -20,9 +20,15 @@ from typing import Callable
|
|
20
20
|
from mct_quantizers import QuantizationMethod
|
21
21
|
from model_compression_toolkit.core.common import Graph
|
22
22
|
from model_compression_toolkit.logger import Logger
|
23
|
-
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
24
23
|
|
25
24
|
|
25
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
26
|
+
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
27
|
+
from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
|
28
|
+
get_activation_quantization_params_fn, get_weights_quantization_params_fn
|
29
|
+
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
|
30
|
+
get_weights_quantization_fn
|
31
|
+
|
26
32
|
_EditRule = namedtuple('EditRule', 'filter action')
|
27
33
|
|
28
34
|
|
@@ -58,13 +64,15 @@ class BaseAction(ABC):
|
|
58
64
|
"""
|
59
65
|
|
60
66
|
@abstractmethod
|
61
|
-
def apply(self, node: BaseNode, graph):
|
67
|
+
def apply(self, node: BaseNode, graph, fw_info):
|
62
68
|
"""
|
63
69
|
Apply an action on the node after matching the node with a node filter.
|
64
70
|
|
65
71
|
Args:
|
66
72
|
node: Node to apply the action on.
|
67
73
|
graph: Graph to apply the action on.
|
74
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
75
|
+
groups of layers by how they should be quantized, etc.)
|
68
76
|
|
69
77
|
Returns:
|
70
78
|
Node after action is applied.
|
@@ -87,13 +95,15 @@ class ChangeCandidatesWeightsQuantConfigAttr(BaseAction):
|
|
87
95
|
self.kwargs = kwargs
|
88
96
|
self.attr_name = attr_name
|
89
97
|
|
90
|
-
def apply(self, node: BaseNode, graph):
|
98
|
+
def apply(self, node: BaseNode, graph, fw_info):
|
91
99
|
"""
|
92
100
|
Change the attribute 'attr_name' in weights quantization config candidates with 'attr_value'.
|
93
101
|
|
94
102
|
Args:
|
95
103
|
node: Node object to change its quant_config.
|
96
104
|
graph: Graph to apply the action on.
|
105
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
106
|
+
groups of layers by how they should be quantized, etc.)
|
97
107
|
Returns:
|
98
108
|
The node after its weights' quantization config candidates have been modified.
|
99
109
|
"""
|
@@ -118,7 +128,7 @@ class ChangeFinalWeightsQuantConfigAttr(BaseAction):
|
|
118
128
|
self.kwargs = kwargs
|
119
129
|
self.attr_name = attr_name
|
120
130
|
|
121
|
-
def apply(self, node: BaseNode, graph):
|
131
|
+
def apply(self, node: BaseNode, graph, fw_info):
|
122
132
|
if node.final_weights_quantization_cfg is not None:
|
123
133
|
for parameter_name, parameter_value in self.kwargs.items():
|
124
134
|
node.final_weights_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value,
|
@@ -137,13 +147,17 @@ class ChangeCandidatesActivationQuantConfigAttr(BaseAction):
|
|
137
147
|
"""
|
138
148
|
self.kwargs = kwargs
|
139
149
|
|
140
|
-
def apply(self, node: BaseNode, graph):
|
150
|
+
def apply(self, node: BaseNode, graph, fw_info):
|
141
151
|
"""
|
142
152
|
Change the attribute 'attr_name' in activation quantization configuration candidates with 'attr_value'.
|
143
153
|
|
144
154
|
Args:
|
145
155
|
node: Node object to change its quant_config.
|
146
156
|
graph: Graph to apply the action on.
|
157
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
158
|
+
groups of layers by how they should be quantized, etc.)
|
159
|
+
Returns:
|
160
|
+
The node after its activation quantization configuration candidates have been modified.
|
147
161
|
"""
|
148
162
|
for nqc in node.candidates_quantization_cfg:
|
149
163
|
for parameter_name, parameter_value in self.kwargs.items():
|
@@ -162,12 +176,55 @@ class ChangeFinalActivationQuantConfigAttr(BaseAction):
|
|
162
176
|
"""
|
163
177
|
self.kwargs = kwargs
|
164
178
|
|
165
|
-
def apply(self, node: BaseNode, graph):
|
179
|
+
def apply(self, node: BaseNode, graph, fw_info):
|
166
180
|
if node.final_activation_quantization_cfg is not None:
|
167
181
|
for parameter_name, parameter_value in self.kwargs.items():
|
168
182
|
node.final_activation_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value)
|
169
183
|
|
170
184
|
|
185
|
+
class ChangeQuantizationParamFunction(BaseAction):
|
186
|
+
"""
|
187
|
+
Class ChangeQuantizationParamFunction to change a node's weights/activations quantization params function.
|
188
|
+
"""
|
189
|
+
|
190
|
+
def __init__(self,
|
191
|
+
attr_name: str = None,
|
192
|
+
activation_quantization_params_fn: Callable = None,
|
193
|
+
weights_quantization_params_fn: Callable = None):
|
194
|
+
"""
|
195
|
+
Init a ChangeQuantizationParamFunction object.
|
196
|
+
|
197
|
+
Args:
|
198
|
+
attr_name: The weights attribute's name to set the weights quantization params function for (if setting weights params).
|
199
|
+
activation_quantization_params_fn: a params function for a node's activations.
|
200
|
+
weights_quantization_params_fn: a params function for a node's weights.
|
201
|
+
"""
|
202
|
+
self.activation_quantization_params_fn = activation_quantization_params_fn
|
203
|
+
self.weights_quantization_params_fn = weights_quantization_params_fn
|
204
|
+
self.attr_name = attr_name
|
205
|
+
|
206
|
+
def apply(self, node: BaseNode, graph, fw_info):
|
207
|
+
"""
|
208
|
+
Change the node's weights/activations quantization params function.
|
209
|
+
|
210
|
+
Args:
|
211
|
+
node: Node object to change its quantization params function.
|
212
|
+
graph: Graph to apply the action on.
|
213
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
214
|
+
groups of layers by how they should be quantized, etc.)
|
215
|
+
|
216
|
+
Returns:
|
217
|
+
The node after its quantization params function has been modified.
|
218
|
+
"""
|
219
|
+
for nqc in node.candidates_quantization_cfg:
|
220
|
+
if self.activation_quantization_params_fn is not None:
|
221
|
+
nqc.activation_quantization_cfg.set_activation_quantization_params_fn(
|
222
|
+
self.activation_quantization_params_fn)
|
223
|
+
if self.weights_quantization_params_fn is not None:
|
224
|
+
(nqc.weights_quantization_cfg.get_attr_config(self.attr_name)
|
225
|
+
.set_weights_quantization_params_fn(self.weights_quantization_params_fn))
|
226
|
+
|
227
|
+
|
171
228
|
class ChangeFinalActivationQuantizationMethod(BaseAction):
|
172
229
|
"""
|
173
230
|
Class ChangeFinalActivationQuantizationMethod to change a node's weights/activations quantizer function.
|
@@ -183,19 +240,31 @@ class ChangeFinalActivationQuantizationMethod(BaseAction):
|
|
183
240
|
|
184
241
|
self.activation_quantization_method = activation_quantization_method
|
185
242
|
|
186
|
-
def apply(self, node: BaseNode, graph):
|
243
|
+
def apply(self, node: BaseNode, graph, fw_info):
|
187
244
|
"""
|
188
245
|
Change the node's activations quantization function.
|
189
246
|
|
190
247
|
Args:
|
191
248
|
node: Node object to change its threshold selection function.
|
192
249
|
graph: Graph to apply the action on.
|
250
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
251
|
+
groups of layers by how they should be quantized, etc.)
|
193
252
|
|
194
253
|
Returns:
|
195
254
|
The node after its quantization function has been modified.
|
196
255
|
"""
|
197
256
|
|
198
257
|
if self.activation_quantization_method is not None and node.final_activation_quantization_cfg is not None:
|
258
|
+
|
259
|
+
activation_quantization_params_fn = get_activation_quantization_params_fn(
|
260
|
+
self.activation_quantization_method)
|
261
|
+
|
262
|
+
node.final_activation_quantization_cfg.set_activation_quantization_params_fn(
|
263
|
+
activation_quantization_params_fn)
|
264
|
+
|
265
|
+
activation_quantization_fn = fw_info.activation_quantizer_mapping.get(self.activation_quantization_method)
|
266
|
+
|
267
|
+
node.final_activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
|
199
268
|
node.final_activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
|
200
269
|
|
201
270
|
|
@@ -213,23 +282,38 @@ class ChangeCandidatesActivationQuantizationMethod(BaseAction):
|
|
213
282
|
"""
|
214
283
|
self.activation_quantization_method = activation_quantization_method
|
215
284
|
|
216
|
-
def apply(self, node: BaseNode, graph):
|
285
|
+
def apply(self, node: BaseNode, graph, fw_info):
|
217
286
|
"""
|
218
287
|
Change the node's activations quantization function.
|
219
288
|
|
220
289
|
Args:
|
221
290
|
node: Node object to change its threshold selection function.
|
222
291
|
graph: Graph to apply the action on.
|
292
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
293
|
+
groups of layers by how they should be quantized, etc.)
|
223
294
|
|
295
|
+
Returns:
|
296
|
+
The node after its quantization function has been modified.
|
224
297
|
"""
|
225
298
|
if self.activation_quantization_method is not None:
|
226
299
|
for qc in node.candidates_quantization_cfg:
|
300
|
+
activation_quantization_params_fn = get_activation_quantization_params_fn(
|
301
|
+
self.activation_quantization_method)
|
302
|
+
|
303
|
+
qc.activation_quantization_cfg.set_activation_quantization_params_fn(activation_quantization_params_fn)
|
304
|
+
activation_quantization_fn = fw_info.activation_quantizer_mapping.get(
|
305
|
+
self.activation_quantization_method)
|
306
|
+
|
307
|
+
if activation_quantization_fn is None:
|
308
|
+
Logger.critical('Unknown activation quantization method specified.') # pragma: no cover
|
309
|
+
|
310
|
+
qc.activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
|
227
311
|
qc.activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
|
228
312
|
|
229
313
|
|
230
314
|
class ChangeFinalWeightsQuantizationMethod(BaseAction):
|
231
315
|
"""
|
232
|
-
Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer
|
316
|
+
Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer function.
|
233
317
|
"""
|
234
318
|
|
235
319
|
def __init__(self, attr_name: str, weights_quantization_method=None):
|
@@ -244,19 +328,36 @@ class ChangeFinalWeightsQuantizationMethod(BaseAction):
|
|
244
328
|
self.weights_quantization_method = weights_quantization_method
|
245
329
|
self.attr_name = attr_name
|
246
330
|
|
247
|
-
def apply(self, node: BaseNode, graph):
|
331
|
+
def apply(self, node: BaseNode, graph, fw_info):
|
248
332
|
"""
|
249
333
|
Change the node's weights quantization function.
|
250
334
|
|
251
335
|
Args:
|
252
336
|
node: Node object to change its threshold selection function.
|
253
337
|
graph: Graph to apply the action on.
|
338
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
339
|
+
groups of layers by how they should be quantized, etc.)
|
254
340
|
|
341
|
+
Returns:
|
342
|
+
The node after its quantization function has been modified.
|
255
343
|
"""
|
256
344
|
|
257
345
|
if self.weights_quantization_method is not None and node.final_weights_quantization_cfg is not None:
|
258
|
-
|
259
|
-
|
346
|
+
|
347
|
+
weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
|
348
|
+
|
349
|
+
(node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
|
350
|
+
.set_weights_quantization_params_fn(weights_quantization_params_fn))
|
351
|
+
|
352
|
+
weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
|
353
|
+
|
354
|
+
if weights_quantization_fn is None:
|
355
|
+
Logger.critical('Unknown weights quantization method specified.') # pragma: no cover
|
356
|
+
|
357
|
+
(node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
|
358
|
+
.set_weights_quantization_fn(weights_quantization_fn))
|
359
|
+
node.final_weights_quantization_cfg.get_attr_config(self.attr_name).weights_quantization_method = \
|
360
|
+
self.weights_quantization_method
|
260
361
|
|
261
362
|
|
262
363
|
class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
|
@@ -275,13 +376,15 @@ class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
|
|
275
376
|
self.weights_quantization_method = weights_quantization_method
|
276
377
|
self.attr_name = attr_name
|
277
378
|
|
278
|
-
def apply(self, node: BaseNode, graph: Graph):
|
379
|
+
def apply(self, node: BaseNode, graph: Graph, fw_info: FrameworkInfo):
|
279
380
|
"""
|
280
381
|
Change the node's weights quantization function.
|
281
382
|
|
282
383
|
Args:
|
283
384
|
node: Node object to change its threshold selection function.
|
284
385
|
graph: Graph to apply the action on.
|
386
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
387
|
+
groups of layers by how they should be quantized, etc.)
|
285
388
|
|
286
389
|
Returns:
|
287
390
|
The node after its quantization function has been modified.
|
@@ -289,7 +392,18 @@ class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
|
|
289
392
|
|
290
393
|
if self.weights_quantization_method is not None:
|
291
394
|
for qc in node.candidates_quantization_cfg:
|
395
|
+
|
396
|
+
weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
|
397
|
+
|
292
398
|
attr_qc = qc.weights_quantization_cfg.get_attr_config(self.attr_name)
|
399
|
+
attr_qc.set_weights_quantization_params_fn(weights_quantization_params_fn)
|
400
|
+
|
401
|
+
weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
|
402
|
+
|
403
|
+
if weights_quantization_fn is None:
|
404
|
+
Logger.critical('Unknown weights quantization method specified.') # pragma: no cover
|
405
|
+
|
406
|
+
attr_qc.set_weights_quantization_fn(weights_quantization_fn)
|
293
407
|
attr_qc.weights_quantization_method = self.weights_quantization_method
|
294
408
|
|
295
409
|
|
@@ -308,13 +422,15 @@ class ReplaceLayer(BaseAction):
|
|
308
422
|
self.layer_type = layer_type
|
309
423
|
self.get_params_and_weights_fn = get_params_and_weights_fn
|
310
424
|
|
311
|
-
def apply(self, node: BaseNode, graph: Graph):
|
425
|
+
def apply(self, node: BaseNode, graph: Graph, fw_info: FrameworkInfo):
|
312
426
|
"""
|
313
427
|
Replacing node's layer type and configurations
|
314
428
|
|
315
429
|
Args:
|
316
430
|
node: Node object to replace or modify
|
317
431
|
graph: Graph to apply the action on.
|
432
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
433
|
+
groups of layers by how they should be quantized, etc.)
|
318
434
|
|
319
435
|
Returns:
|
320
436
|
The node after its layer functionality has been modified.
|
@@ -14,17 +14,20 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
from typing import List
|
16
16
|
|
17
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
17
18
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
18
19
|
from model_compression_toolkit.core.common.network_editors import EditRule
|
19
20
|
|
20
21
|
|
21
22
|
def edit_network_graph(graph: Graph,
|
23
|
+
fw_info: FrameworkInfo,
|
22
24
|
network_editor: List[EditRule]):
|
23
25
|
"""
|
24
26
|
Apply a list of edit rules on a graph.
|
25
27
|
|
26
28
|
Args:
|
27
29
|
graph: The graph to edit.
|
30
|
+
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
28
31
|
groups of layers by how they should be quantized, etc.)
|
29
32
|
network_editor: List of edit rules to apply to the graph.
|
30
33
|
|
@@ -35,5 +38,5 @@ def edit_network_graph(graph: Graph,
|
|
35
38
|
for edit_rule in network_editor:
|
36
39
|
filtered_nodes = graph.filter(edit_rule.filter)
|
37
40
|
for node in filtered_nodes:
|
38
|
-
edit_rule.action.apply(node, graph)
|
41
|
+
edit_rule.action.apply(node, graph, fw_info)
|
39
42
|
# return graph
|
@@ -26,14 +26,18 @@ class ChannelGrouping:
|
|
26
26
|
based on their importance scores and SIMD group sizes.
|
27
27
|
"""
|
28
28
|
|
29
|
-
def __init__(self,
|
29
|
+
def __init__(self,
|
30
|
+
prunable_nodes: List[BaseNode],
|
31
|
+
fw_info: FrameworkInfo):
|
30
32
|
"""
|
31
33
|
Initializes the ChannelGrouping with necessary information.
|
32
34
|
|
33
35
|
Args:
|
34
36
|
prunable_nodes: List of nodes that can be pruned.
|
37
|
+
fw_info: Framework-specific information and utilities.
|
35
38
|
"""
|
36
39
|
self.prunable_nodes = prunable_nodes
|
40
|
+
self.fw_info = fw_info
|
37
41
|
# Store for each node a list of numpy arrays. Each numpy array represents the
|
38
42
|
# indices of the channels in an SIMD group.
|
39
43
|
self._simd_groups_indices = {}
|
@@ -38,6 +38,7 @@ class GreedyMaskCalculator:
|
|
38
38
|
"""
|
39
39
|
def __init__(self,
|
40
40
|
prunable_nodes: List[BaseNode],
|
41
|
+
fw_info: FrameworkInfo,
|
41
42
|
simd_groups_scores: Dict[BaseNode, np.ndarray],
|
42
43
|
target_resource_utilization: ResourceUtilization,
|
43
44
|
graph: Graph,
|
@@ -47,6 +48,7 @@ class GreedyMaskCalculator:
|
|
47
48
|
"""
|
48
49
|
Args:
|
49
50
|
prunable_nodes (List[BaseNode]): Nodes that are eligible for pruning.
|
51
|
+
fw_info (FrameworkInfo): Framework-specific information and utilities.
|
50
52
|
simd_groups_scores (Dict[BaseNode, np.ndarray]): Importance scores for each SIMD group in a prunable node.
|
51
53
|
target_resource_utilization (ResourceUtilization): The target resource utilization to achieve.
|
52
54
|
graph (Graph): The computational graph of the model.
|
@@ -55,6 +57,7 @@ class GreedyMaskCalculator:
|
|
55
57
|
simd_groups_indices (Dict[BaseNode, List[List[int]]]): Indices of SIMD groups in each node.
|
56
58
|
"""
|
57
59
|
self.prunable_nodes = prunable_nodes
|
60
|
+
self.fw_info = fw_info
|
58
61
|
self.target_resource_utilization = target_resource_utilization
|
59
62
|
self.graph = graph
|
60
63
|
self.fw_impl = fw_impl
|
@@ -64,11 +67,14 @@ class GreedyMaskCalculator:
|
|
64
67
|
self.simd_groups_scores = simd_groups_scores
|
65
68
|
|
66
69
|
self.oc_pruning_mask = PerSIMDGroupMask(prunable_nodes=prunable_nodes,
|
70
|
+
fw_info=fw_info,
|
67
71
|
simd_groups_indices=simd_groups_indices)
|
68
72
|
|
69
73
|
self.memory_calculator = MemoryCalculator(graph=graph,
|
74
|
+
fw_info=fw_info,
|
70
75
|
fw_impl=fw_impl)
|
71
76
|
|
77
|
+
|
72
78
|
def get_mask(self) -> Dict[BaseNode, np.ndarray]:
|
73
79
|
"""
|
74
80
|
Retrieves the current pruning mask for each prunable node.
|
@@ -38,7 +38,8 @@ class LFHImportanceMetric(BaseImportanceMetric):
|
|
38
38
|
graph: Graph,
|
39
39
|
representative_data_gen: Callable,
|
40
40
|
fw_impl: PruningFrameworkImplementation,
|
41
|
-
pruning_config: PruningConfig
|
41
|
+
pruning_config: PruningConfig,
|
42
|
+
fw_info: FrameworkInfo):
|
42
43
|
"""
|
43
44
|
Initialize the LFHImportanceMetric instance.
|
44
45
|
|
@@ -47,11 +48,13 @@ class LFHImportanceMetric(BaseImportanceMetric):
|
|
47
48
|
representative_data_gen (Callable): Function to generate representative data.
|
48
49
|
fw_impl (PruningFrameworkImplementation): Implementation of pruning for the framework.
|
49
50
|
pruning_config (PruningConfig): Configuration for pruning.
|
51
|
+
fw_info (FrameworkInfo): Framework-specific information.
|
50
52
|
"""
|
51
53
|
self.float_graph = graph
|
52
54
|
self.representative_data_gen = representative_data_gen
|
53
55
|
self.fw_impl = fw_impl
|
54
56
|
self.pruning_config = pruning_config
|
57
|
+
self.fw_info = fw_info
|
55
58
|
|
56
59
|
# Initialize internal dictionaries for storing intermediate computations.
|
57
60
|
self._entry_node_to_hessian_score = {}
|
@@ -155,7 +158,8 @@ class LFHImportanceMetric(BaseImportanceMetric):
|
|
155
158
|
Dict[BaseNode, List[np.ndarray]]: Dictionary of entry nodes mapped to their SIMD group indices.
|
156
159
|
"""
|
157
160
|
# Initialize channel grouping utility.
|
158
|
-
channel_grouping = ChannelGrouping(prunable_nodes=list(entry_node_to_score.keys())
|
161
|
+
channel_grouping = ChannelGrouping(prunable_nodes=list(entry_node_to_score.keys()),
|
162
|
+
fw_info=self.fw_info)
|
159
163
|
|
160
164
|
channel_grouping.group_scores_by_simd_groups(entry_node_to_score)
|
161
165
|
grouped_indices = channel_grouping.simd_groups_indices
|
@@ -245,14 +249,20 @@ class LFHImportanceMetric(BaseImportanceMetric):
|
|
245
249
|
Returns:
|
246
250
|
tuple: A tuple containing the kernel attribute, the number of output channels, and the axis of the output channels.
|
247
251
|
"""
|
252
|
+
kernel_attr = self.fw_info.get_kernel_op_attributes(entry_node.type)
|
253
|
+
# Ensure only one kernel attribute exists for the given node.
|
254
|
+
if len(kernel_attr) != 1:
|
255
|
+
Logger.critical(f"Expected a single attribute but found multiple attributes ({len(kernel_attr)}) for node {entry_node}.")
|
256
|
+
kernel_attr = kernel_attr[0]
|
257
|
+
|
248
258
|
# Retrieve and validate the axis index for the output channels.
|
249
|
-
oc_axis = entry_node.
|
259
|
+
oc_axis, _ = self.fw_info.kernel_channels_mapping.get(entry_node.type)
|
250
260
|
if oc_axis is None or int(oc_axis) != oc_axis:
|
251
261
|
Logger.critical(f"Invalid output channel axis type for node {entry_node}: expected integer but got {oc_axis}.")
|
252
262
|
|
253
263
|
# Get the number of output channels based on the kernel attribute and axis.
|
254
|
-
num_oc = entry_node.get_weights_by_keys(
|
255
|
-
return
|
264
|
+
num_oc = entry_node.get_weights_by_keys(kernel_attr[0]).shape[oc_axis]
|
265
|
+
return kernel_attr, num_oc, oc_axis
|
256
266
|
|
257
267
|
def _concatenate_tensors_by_indices(self,
|
258
268
|
channels: List[np.ndarray],
|
@@ -35,8 +35,9 @@ class MaskIndicator(Enum):
|
|
35
35
|
REMAINED = 1
|
36
36
|
|
37
37
|
|
38
|
+
|
38
39
|
class PerChannelMask:
|
39
|
-
def __init__(self, prunable_nodes: List[BaseNode]):
|
40
|
+
def __init__(self, prunable_nodes: List[BaseNode], fw_info: FrameworkInfo):
|
40
41
|
"""
|
41
42
|
Initializes the PerChannelMask with prunable nodes and framework information.
|
42
43
|
This class is responsible for maintaining and updating the pruning masks for each
|
@@ -45,8 +46,10 @@ class PerChannelMask:
|
|
45
46
|
|
46
47
|
Args:
|
47
48
|
prunable_nodes: List of nodes in the model that are subject to pruning.
|
49
|
+
fw_info: Framework-specific information required for pruning operations.
|
48
50
|
"""
|
49
51
|
self.prunable_nodes = prunable_nodes
|
52
|
+
self.fw_info = fw_info
|
50
53
|
self._mask = None # Initialize the mask dictionary
|
51
54
|
self._init_masks() # Call to initialize masks for each prunable node
|
52
55
|
|
@@ -103,7 +106,8 @@ class PerChannelMask:
|
|
103
106
|
Returns:
|
104
107
|
int: Number of output channels for the node.
|
105
108
|
"""
|
106
|
-
|
107
|
-
|
109
|
+
kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)[0]
|
110
|
+
oc_axis = self.fw_info.kernel_channels_mapping.get(node.type)[0]
|
111
|
+
num_oc = node.get_weights_by_keys(kernel_attr).shape[oc_axis]
|
108
112
|
return num_oc
|
109
113
|
|
@@ -24,10 +24,10 @@ from model_compression_toolkit.core.common.pruning.memory_calculator import Memo
|
|
24
24
|
from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation
|
25
25
|
from model_compression_toolkit.logger import Logger
|
26
26
|
|
27
|
-
|
28
27
|
class PerSIMDGroupMask:
|
29
28
|
def __init__(self,
|
30
29
|
prunable_nodes: List[BaseNode],
|
30
|
+
fw_info: FrameworkInfo,
|
31
31
|
simd_groups_indices: Dict[BaseNode, List[List[int]]]):
|
32
32
|
"""
|
33
33
|
Initializes a mask calculator for SIMD groups in prunable nodes.
|
@@ -35,11 +35,13 @@ class PerSIMDGroupMask:
|
|
35
35
|
|
36
36
|
Args:
|
37
37
|
prunable_nodes: List of nodes that can be pruned.
|
38
|
+
fw_info: Framework-specific information.
|
38
39
|
simd_groups_indices: A dictionary mapping each node to its SIMD groups' indices.
|
39
40
|
"""
|
40
41
|
# Initialize the per-channel mask
|
41
|
-
self.per_channel_mask = PerChannelMask(prunable_nodes=prunable_nodes)
|
42
|
+
self.per_channel_mask = PerChannelMask(prunable_nodes=prunable_nodes, fw_info=fw_info)
|
42
43
|
self.prunable_nodes = prunable_nodes
|
44
|
+
self.fw_info = fw_info
|
43
45
|
self.simd_groups_indices = simd_groups_indices
|
44
46
|
self._mask_simd = None # Initialize the SIMD group mask dictionary
|
45
47
|
self._init_masks() # Initialize masks for each prunable node
|
@@ -34,16 +34,18 @@ class MemoryCalculator:
|
|
34
34
|
which is crucial for deploying models on memory-constrained devices or optimizing for computational efficiency.
|
35
35
|
"""
|
36
36
|
|
37
|
-
def __init__(self, graph: Graph, fw_impl: PruningFrameworkImplementation):
|
37
|
+
def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: PruningFrameworkImplementation):
|
38
38
|
"""
|
39
39
|
Initializes the MemoryCalculator with necessary information about the model's graph,
|
40
40
|
framework-specific details, and pruning implementation.
|
41
41
|
|
42
42
|
Args:
|
43
43
|
graph (Graph): Computational graph of the model.
|
44
|
+
fw_info (FrameworkInfo): Contains framework-specific information.
|
44
45
|
fw_impl (PruningFrameworkImplementation): Implementation details for pruning.
|
45
46
|
"""
|
46
47
|
self.graph = graph
|
48
|
+
self.fw_info = fw_info
|
47
49
|
self.fw_impl = fw_impl
|
48
50
|
|
49
51
|
def get_pruned_graph_memory(self,
|
@@ -202,13 +204,19 @@ class MemoryCalculator:
|
|
202
204
|
if node == section.exit_node:
|
203
205
|
return masks.get(section.entry_node)
|
204
206
|
|
207
|
+
kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)
|
208
|
+
# Ensure only one kernel attribute exists for the given node.
|
209
|
+
if len(kernel_attr) != 1:
|
210
|
+
Logger.critical(f"Expected a single attribute, but found {len(kernel_attr)} attributes for node '{node}'. Ensure the node configuration is correct.")
|
211
|
+
kernel_attr = kernel_attr[0]
|
212
|
+
|
205
213
|
# Retrieve and validate the axis index for the input channels.
|
206
|
-
ic_axis = node.
|
214
|
+
_, ic_axis = self.fw_info.kernel_channels_mapping.get(node.type)
|
207
215
|
if ic_axis is None or int(ic_axis) != ic_axis:
|
208
216
|
Logger.critical(f"Invalid input channel axis type for node '{node}': expected integer but got '{ic_axis}'.")
|
209
217
|
|
210
218
|
# Get the number of input channels based on the kernel attribute and axis.
|
211
|
-
num_ic = node.get_weights_by_keys(
|
219
|
+
num_ic = node.get_weights_by_keys(kernel_attr).shape[ic_axis]
|
212
220
|
mask = np.ones(num_ic, dtype=bool)
|
213
221
|
return mask
|
214
222
|
|
@@ -281,7 +289,7 @@ class MemoryCalculator:
|
|
281
289
|
int: The total number of parameters in the node after pruning.
|
282
290
|
"""
|
283
291
|
total_params = 0
|
284
|
-
attributes_and_oc_axis = self.fw_impl.attrs_oi_channels_info_for_pruning(node)
|
292
|
+
attributes_and_oc_axis = self.fw_impl.attrs_oi_channels_info_for_pruning(node, self.fw_info)
|
285
293
|
|
286
294
|
# Iterate over the node's weights and apply pruning based on the masks.
|
287
295
|
for w_attr, w in node.weights.items():
|
@@ -303,7 +311,7 @@ class MemoryCalculator:
|
|
303
311
|
num_oc = np.sum(output_mask)
|
304
312
|
else:
|
305
313
|
# Get the node channel axis from framework info
|
306
|
-
channel_axis = self.
|
314
|
+
channel_axis = self.fw_info.out_channel_axis_mapping.get(node.type)
|
307
315
|
if channel_axis is None:
|
308
316
|
Logger.critical(f"The channel axis is undefined. Please ensure the channel axis is explicitly defined for node {node.type} in the framework info.")
|
309
317
|
|
@@ -27,6 +27,7 @@ from model_compression_toolkit.logger import Logger
|
|
27
27
|
|
28
28
|
def build_pruned_graph(graph: Graph,
|
29
29
|
masks: Dict[BaseNode, np.ndarray],
|
30
|
+
fw_info: FrameworkInfo,
|
30
31
|
fw_impl: FrameworkImplementation) -> Graph:
|
31
32
|
"""
|
32
33
|
Prunes the provided graph according to the given pruning output-channels masks.
|
@@ -34,6 +35,7 @@ def build_pruned_graph(graph: Graph,
|
|
34
35
|
Args:
|
35
36
|
graph: The original computational graph to be pruned.
|
36
37
|
masks: A dictionary mapping each prunable node to its pruning mask.
|
38
|
+
fw_info: Framework-specific information object.
|
37
39
|
fw_impl: Framework-specific implementation object.
|
38
40
|
|
39
41
|
Returns:
|
@@ -64,7 +66,8 @@ def build_pruned_graph(graph: Graph,
|
|
64
66
|
section_mask = PruningSectionMask(entry_node_oc_mask=mask,
|
65
67
|
exit_node_ic_mask=mask)
|
66
68
|
pruning_section.apply_inner_section_mask(section_mask,
|
67
|
-
fw_impl
|
69
|
+
fw_impl,
|
70
|
+
fw_info)
|
68
71
|
|
69
72
|
return graph_to_prune
|
70
73
|
|
@@ -40,6 +40,7 @@ class Pruner:
|
|
40
40
|
"""
|
41
41
|
def __init__(self,
|
42
42
|
float_graph: Graph,
|
43
|
+
fw_info: FrameworkInfo,
|
43
44
|
fw_impl: PruningFrameworkImplementation,
|
44
45
|
target_resource_utilization: ResourceUtilization,
|
45
46
|
representative_data_gen: Callable,
|
@@ -48,6 +49,7 @@ class Pruner:
|
|
48
49
|
"""
|
49
50
|
Args:
|
50
51
|
float_graph (Graph): The floating-point representation of the model's computation graph.
|
52
|
+
fw_info (FrameworkInfo): Contains metadata and helper functions for the framework.
|
51
53
|
fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.
|
52
54
|
target_resource_utilization (ResourceUtilization): The target resource utilization to be achieved after pruning.
|
53
55
|
representative_data_gen (Callable): Generator function for representative dataset used in pruning analysis.
|
@@ -55,6 +57,7 @@ class Pruner:
|
|
55
57
|
target_platform_capabilities (FrameworkQuantizationCapabilities): Object encapsulating the capabilities of the target hardware platform.
|
56
58
|
"""
|
57
59
|
self.float_graph = float_graph
|
60
|
+
self.fw_info = fw_info
|
58
61
|
self.fw_impl = fw_impl
|
59
62
|
self.target_resource_utilization = target_resource_utilization
|
60
63
|
self.representative_data_gen = representative_data_gen
|
@@ -81,6 +84,7 @@ class Pruner:
|
|
81
84
|
# Apply Greedy strategy to compute masks based on importance scores.
|
82
85
|
if self.pruning_config.channels_filtering_strategy == ChannelsFilteringStrategy.GREEDY:
|
83
86
|
mask_calculator = GreedyMaskCalculator(entry_nodes,
|
87
|
+
self.fw_info,
|
84
88
|
self.simd_scores,
|
85
89
|
self.target_resource_utilization,
|
86
90
|
self.float_graph,
|
@@ -95,6 +99,7 @@ class Pruner:
|
|
95
99
|
Logger.info("Start pruning graph...")
|
96
100
|
_pruned_graph = build_pruned_graph(self.float_graph,
|
97
101
|
self.per_oc_mask,
|
102
|
+
self.fw_info,
|
98
103
|
self.fw_impl)
|
99
104
|
return _pruned_graph
|
100
105
|
|
@@ -111,7 +116,7 @@ class Pruner:
|
|
111
116
|
# Retrieve and initialize the importance metric.
|
112
117
|
im = get_importance_metric(self.pruning_config.importance_metric, graph=self.float_graph,
|
113
118
|
representative_data_gen=self.representative_data_gen, fw_impl=self.fw_impl,
|
114
|
-
pruning_config=self.pruning_config)
|
119
|
+
pruning_config=self.pruning_config, fw_info=self.fw_info)
|
115
120
|
entry_node_to_simd_score, simd_groups_indices = im.get_entry_node_to_simd_score(entry_nodes)
|
116
121
|
return entry_node_to_simd_score, simd_groups_indices
|
117
122
|
|