mct-nightly 2.4.0.20250617.613__py3-none-any.whl → 2.4.0.20250618.606__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/RECORD +120 -120
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +2 -5
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
- model_compression_toolkit/core/common/framework_implementation.py +10 -22
- model_compression_toolkit/core/common/framework_info.py +105 -68
- model_compression_toolkit/core/common/graph/base_graph.py +15 -42
- model_compression_toolkit/core/common/graph/base_node.py +103 -42
- model_compression_toolkit/core/common/graph/functional_node.py +18 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
- model_compression_toolkit/core/common/model_collector.py +10 -20
- model_compression_toolkit/core/common/model_validation.py +1 -4
- model_compression_toolkit/core/common/network_editors/actions.py +14 -38
- model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
- model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
- model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
- model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
- model_compression_toolkit/core/common/pruning/pruner.py +1 -6
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
- model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
- model_compression_toolkit/core/graph_prep_runner.py +2 -16
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/keras/default_framework_info.py +138 -87
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
- model_compression_toolkit/core/keras/keras_implementation.py +15 -35
- model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
- model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +0 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/quantization_prep_runner.py +4 -9
- model_compression_toolkit/core/runner.py +5 -15
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
- model_compression_toolkit/gptq/common/gptq_training.py +1 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
- model_compression_toolkit/gptq/keras/graph_info.py +4 -6
- model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/runner.py +1 -7
- model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
- model_compression_toolkit/ptq/runner.py +1 -4
- model_compression_toolkit/qat/common/qat_config.py +2 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
- model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
- model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
- model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
- model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
- model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
- model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/top_level.txt +0 -0
model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15
15
|
import numpy as np
|
16
16
|
from typing import runtime_checkable, Protocol, Callable, Any, List, Tuple
|
17
17
|
|
18
|
-
from model_compression_toolkit.core import
|
18
|
+
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, MpDistanceWeighting
|
19
19
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
20
20
|
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
|
21
21
|
HessianScoresGranularity
|
@@ -62,15 +62,12 @@ class DistanceMetricCalculator(MetricCalculator):
|
|
62
62
|
graph: Graph,
|
63
63
|
mp_config: MixedPrecisionQuantizationConfig,
|
64
64
|
representative_data_gen: Callable,
|
65
|
-
fw_info: FrameworkInfo,
|
66
65
|
fw_impl: Any,
|
67
66
|
hessian_info_service: HessianInfoService = None):
|
68
67
|
"""
|
69
68
|
Args:
|
70
69
|
graph: Graph to search for its MP configuration.
|
71
70
|
mp_config: MP Quantization configuration for how the graph should be quantized.
|
72
|
-
fw_info: FrameworkInfo object about the specific framework
|
73
|
-
(e.g., attributes of different layers' weights to quantize).
|
74
71
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
75
72
|
representative_data_gen: Dataset used for getting batches for inference.
|
76
73
|
hessian_info_service: HessianInfoService to fetch Hessian approximation information.
|
@@ -78,14 +75,13 @@ class DistanceMetricCalculator(MetricCalculator):
|
|
78
75
|
self.graph = graph
|
79
76
|
self.mp_config = mp_config
|
80
77
|
self.representative_data_gen = representative_data_gen
|
81
|
-
self.fw_info = fw_info
|
82
78
|
self.fw_impl = fw_impl
|
83
79
|
|
84
80
|
if self.mp_config.distance_weighting_method == MpDistanceWeighting.HESSIAN:
|
85
81
|
assert hessian_info_service is not None, ('Expected HessianInfoService object to be passed with Hessian '
|
86
82
|
'distance weighting')
|
87
83
|
|
88
|
-
self.sorted_configurable_nodes_names = graph.get_configurable_sorted_nodes_names(
|
84
|
+
self.sorted_configurable_nodes_names = graph.get_configurable_sorted_nodes_names()
|
89
85
|
|
90
86
|
# Get interest points and output points set for distance measurement and set other helper datasets
|
91
87
|
# We define a separate set of output nodes of the model for the purpose of sensitivity computation.
|
@@ -396,8 +392,8 @@ class DistanceMetricCalculator(MetricCalculator):
|
|
396
392
|
"""
|
397
393
|
|
398
394
|
return [n.node for n in graph.get_outputs()
|
399
|
-
if (
|
400
|
-
n.node.is_weights_quantization_enabled(
|
395
|
+
if (n.node.is_kernel_op and
|
396
|
+
n.node.is_weights_quantization_enabled(n.node.kernel_attr)) or
|
401
397
|
n.node.is_activation_quantization_enabled()]
|
402
398
|
|
403
399
|
@staticmethod
|
model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py
CHANGED
@@ -38,7 +38,6 @@ class SensitivityEvaluation:
|
|
38
38
|
graph: Graph,
|
39
39
|
mp_config: MixedPrecisionQuantizationConfig,
|
40
40
|
representative_data_gen: Callable,
|
41
|
-
fw_info: FrameworkInfo,
|
42
41
|
fw_impl: Any,
|
43
42
|
disable_activation_for_metric: bool = False,
|
44
43
|
hessian_info_service: HessianInfoService = None
|
@@ -46,8 +45,6 @@ class SensitivityEvaluation:
|
|
46
45
|
"""
|
47
46
|
Args:
|
48
47
|
graph: Graph to search for its MP configuration.
|
49
|
-
fw_info: FrameworkInfo object about the specific framework
|
50
|
-
(e.g., attributes of different layers' weights to quantize).
|
51
48
|
mp_config: MP Quantization configuration for how the graph should be quantized.
|
52
49
|
representative_data_gen: Dataset used for getting batches for inference.
|
53
50
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
@@ -57,14 +54,13 @@ class SensitivityEvaluation:
|
|
57
54
|
"""
|
58
55
|
self.mp_config = mp_config
|
59
56
|
self.representative_data_gen = representative_data_gen
|
60
|
-
self.fw_info = fw_info
|
61
57
|
self.fw_impl = fw_impl
|
62
58
|
|
63
59
|
if self.mp_config.custom_metric_fn:
|
64
60
|
self.metric_calculator = CustomMetricCalculator(graph, self.mp_config.custom_metric_fn)
|
65
61
|
else:
|
66
62
|
self.metric_calculator = DistanceMetricCalculator(graph, mp_config, representative_data_gen,
|
67
|
-
|
63
|
+
fw_impl=fw_impl,
|
68
64
|
hessian_info_service=hessian_info_service)
|
69
65
|
|
70
66
|
# Build a mixed-precision model which can be configured to use different bitwidth in different layers.
|
@@ -111,8 +107,7 @@ class SensitivityEvaluation:
|
|
111
107
|
|
112
108
|
model_mp, _, conf_node2layers = self.fw_impl.model_builder(evaluation_graph,
|
113
109
|
mode=ModelBuilderMode.MIXEDPRECISION,
|
114
|
-
append2output=outputs
|
115
|
-
fw_info=self.fw_info)
|
110
|
+
append2output=outputs)
|
116
111
|
|
117
112
|
# Disable all configurable quantizers. They will be activated one at a time during sensitivity evaluation.
|
118
113
|
for layer in itertools.chain(*conf_node2layers.values()):
|
@@ -18,7 +18,7 @@ import numpy as np
|
|
18
18
|
from typing import List, Union, Tuple, Optional
|
19
19
|
|
20
20
|
from networkx.algorithms.dag import topological_sort
|
21
|
-
from model_compression_toolkit.core import
|
21
|
+
from model_compression_toolkit.core import QuantizationErrorMethod
|
22
22
|
from model_compression_toolkit.core import common
|
23
23
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
24
24
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
@@ -30,7 +30,6 @@ from model_compression_toolkit.core.common.collectors.statistics_collector impor
|
|
30
30
|
|
31
31
|
|
32
32
|
def create_stats_collector_for_node(node: common.BaseNode,
|
33
|
-
fw_info: FrameworkInfo,
|
34
33
|
quant_node_in_fln: bool) -> BaseStatsCollector:
|
35
34
|
"""
|
36
35
|
Gets a node and a groups list and create and return a statistics collector for a node
|
@@ -39,7 +38,7 @@ def create_stats_collector_for_node(node: common.BaseNode,
|
|
39
38
|
|
40
39
|
Args:
|
41
40
|
node: Node to create its statistics collector.
|
42
|
-
|
41
|
+
quant_node_in_fln: Whether the node should be quantized as part of an FLN.
|
43
42
|
|
44
43
|
Returns:
|
45
44
|
Statistics collector for statistics collection for the node.
|
@@ -48,7 +47,7 @@ def create_stats_collector_for_node(node: common.BaseNode,
|
|
48
47
|
if node.is_activation_quantization_enabled() or quant_node_in_fln:
|
49
48
|
min_output = getattr(node.prior_info, 'min_output', None)
|
50
49
|
max_output = getattr(node.prior_info, 'max_output', None)
|
51
|
-
stats_collector = common.StatsCollector(out_channel_axis=
|
50
|
+
stats_collector = common.StatsCollector(out_channel_axis=node.out_channel_axis,
|
52
51
|
init_min_value=min_output,
|
53
52
|
init_max_value=max_output)
|
54
53
|
else:
|
@@ -58,21 +57,19 @@ def create_stats_collector_for_node(node: common.BaseNode,
|
|
58
57
|
|
59
58
|
|
60
59
|
def create_tensor2node(graph: common.Graph,
|
61
|
-
node: common.BaseNode
|
62
|
-
fw_info: common.FrameworkInfo):
|
60
|
+
node: common.BaseNode):
|
63
61
|
"""
|
64
62
|
Force statistic collector creation and assignment for a node.
|
65
63
|
Args:
|
66
64
|
graph: Graph of the node (for retrieving the current tensor).
|
67
65
|
node: Node to create a tensor for.
|
68
|
-
fw_info: Specific framework information (for example, output channels index).
|
69
66
|
|
70
67
|
"""
|
71
68
|
current_sc = graph.get_out_stats_collector(node)
|
72
69
|
is_list_nostat_collectors = isinstance(current_sc, list) and len(
|
73
70
|
[sc for sc in current_sc if not isinstance(sc, common.NoStatsCollector)]) == 0
|
74
71
|
if isinstance(current_sc, common.NoStatsCollector) or current_sc is None or is_list_nostat_collectors:
|
75
|
-
stats_collector = common.StatsCollector(
|
72
|
+
stats_collector = common.StatsCollector(node.out_channel_axis)
|
76
73
|
graph.set_out_stats_collector_to_node(node, stats_collector)
|
77
74
|
|
78
75
|
|
@@ -140,7 +137,6 @@ class ModelCollector:
|
|
140
137
|
|
141
138
|
def __init__(self, graph: Graph,
|
142
139
|
fw_impl: FrameworkImplementation,
|
143
|
-
fw_info: FrameworkInfo,
|
144
140
|
hessian_info_service: HessianInfoService = None,
|
145
141
|
qc: common.QuantizationConfig = common.DEFAULTCONFIG):
|
146
142
|
"""
|
@@ -149,12 +145,10 @@ class ModelCollector:
|
|
149
145
|
Args:
|
150
146
|
graph: Graph to build a model from it.
|
151
147
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
152
|
-
fw_info: FrameworkInfo object with a specific framework information.
|
153
148
|
qc: Quantization configuration containing parameters for how the graph should be quantized.
|
154
149
|
"""
|
155
150
|
|
156
151
|
self.fw_impl = fw_impl
|
157
|
-
self.fw_info = fw_info
|
158
152
|
self.hessian_service = hessian_info_service
|
159
153
|
self.qc = qc
|
160
154
|
self.model_outputs = [out.node for out in graph.get_outputs()]
|
@@ -162,17 +156,15 @@ class ModelCollector:
|
|
162
156
|
# Assign statistics collectors to nodes
|
163
157
|
for n in graph.get_topo_sorted_nodes():
|
164
158
|
quant_node_in_fln = n.is_fln_quantization() and graph.fusing_info.is_quantized_node_in_fln(n)
|
165
|
-
sc = create_stats_collector_for_node(n,
|
159
|
+
sc = create_stats_collector_for_node(n, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node
|
166
160
|
# If we use bias correction, and the node has kernel weights to quantize, we need to make sure
|
167
161
|
# its previous nodes' tensors are consistent with this node.
|
168
|
-
kernel_attr
|
169
|
-
|
170
|
-
kernel_attr):
|
162
|
+
if qc.weights_bias_correction and n.kernel_attr is not None and n.is_weights_quantization_enabled(
|
163
|
+
n.kernel_attr):
|
171
164
|
for ie in graph.incoming_edges(n):
|
172
165
|
input_node = ie.source_node
|
173
166
|
create_tensor2node(graph,
|
174
|
-
input_node
|
175
|
-
fw_info)
|
167
|
+
input_node)
|
176
168
|
if sc is not None:
|
177
169
|
graph.set_out_stats_collector_to_node(n, sc)
|
178
170
|
|
@@ -205,13 +197,11 @@ class ModelCollector:
|
|
205
197
|
# TODO: Add integration test for this case
|
206
198
|
append2output = outputs_nodes + [n for n in self.model_outputs if n not in outputs_nodes]
|
207
199
|
|
208
|
-
|
209
200
|
# Build a float model and output all layers' outputs
|
210
201
|
# (that should be collected) as the model's outputs
|
211
202
|
self.model, _ = self.fw_impl.model_builder(graph,
|
212
203
|
mode=ModelBuilderMode.FLOAT,
|
213
|
-
append2output=append2output
|
214
|
-
fw_info=self.fw_info)
|
204
|
+
append2output=append2output)
|
215
205
|
|
216
206
|
def infer(self, inputs_list: List[np.ndarray]):
|
217
207
|
"""
|
@@ -10,17 +10,14 @@ class ModelValidation:
|
|
10
10
|
"""
|
11
11
|
|
12
12
|
def __init__(self,
|
13
|
-
model: Any
|
14
|
-
fw_info:FrameworkInfo):
|
13
|
+
model: Any):
|
15
14
|
"""
|
16
15
|
Initialize a ModelValidation object.
|
17
16
|
|
18
17
|
Args:
|
19
18
|
model: Model to check its validity.
|
20
|
-
fw_info: Information about the specific framework of the model.
|
21
19
|
"""
|
22
20
|
self.model = model
|
23
|
-
self.fw_info = fw_info
|
24
21
|
|
25
22
|
@abstractmethod
|
26
23
|
def validate_output_channel_consistency(self):
|
@@ -22,7 +22,7 @@ from model_compression_toolkit.core.common import Graph
|
|
22
22
|
from model_compression_toolkit.logger import Logger
|
23
23
|
|
24
24
|
|
25
|
-
from model_compression_toolkit.core.common.framework_info import
|
25
|
+
from model_compression_toolkit.core.common.framework_info import get_fw_info
|
26
26
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
27
27
|
from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
|
28
28
|
get_activation_quantization_params_fn, get_weights_quantization_params_fn
|
@@ -64,15 +64,13 @@ class BaseAction(ABC):
|
|
64
64
|
"""
|
65
65
|
|
66
66
|
@abstractmethod
|
67
|
-
def apply(self, node: BaseNode, graph
|
67
|
+
def apply(self, node: BaseNode, graph):
|
68
68
|
"""
|
69
69
|
Apply an action on the node after matching the node with a node filter.
|
70
70
|
|
71
71
|
Args:
|
72
72
|
node: Node to apply the action on.
|
73
73
|
graph: Graph to apply the action on.
|
74
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
75
|
-
groups of layers by how they should be quantized, etc.)
|
76
74
|
|
77
75
|
Returns:
|
78
76
|
Node after action is applied.
|
@@ -95,15 +93,13 @@ class ChangeCandidatesWeightsQuantConfigAttr(BaseAction):
|
|
95
93
|
self.kwargs = kwargs
|
96
94
|
self.attr_name = attr_name
|
97
95
|
|
98
|
-
def apply(self, node: BaseNode, graph
|
96
|
+
def apply(self, node: BaseNode, graph):
|
99
97
|
"""
|
100
98
|
Change the attribute 'attr_name' in weights quantization config candidates with 'attr_value'.
|
101
99
|
|
102
100
|
Args:
|
103
101
|
node: Node object to change its quant_config.
|
104
102
|
graph: Graph to apply the action on.
|
105
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
106
|
-
groups of layers by how they should be quantized, etc.)
|
107
103
|
Returns:
|
108
104
|
The node after its weights' quantization config candidates have been modified.
|
109
105
|
"""
|
@@ -128,7 +124,7 @@ class ChangeFinalWeightsQuantConfigAttr(BaseAction):
|
|
128
124
|
self.kwargs = kwargs
|
129
125
|
self.attr_name = attr_name
|
130
126
|
|
131
|
-
def apply(self, node: BaseNode, graph
|
127
|
+
def apply(self, node: BaseNode, graph):
|
132
128
|
if node.final_weights_quantization_cfg is not None:
|
133
129
|
for parameter_name, parameter_value in self.kwargs.items():
|
134
130
|
node.final_weights_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value,
|
@@ -147,17 +143,13 @@ class ChangeCandidatesActivationQuantConfigAttr(BaseAction):
|
|
147
143
|
"""
|
148
144
|
self.kwargs = kwargs
|
149
145
|
|
150
|
-
def apply(self, node: BaseNode, graph
|
146
|
+
def apply(self, node: BaseNode, graph):
|
151
147
|
"""
|
152
148
|
Change the attribute 'attr_name' in activation quantization configuration candidates with 'attr_value'.
|
153
149
|
|
154
150
|
Args:
|
155
151
|
node: Node object to change its quant_config.
|
156
152
|
graph: Graph to apply the action on.
|
157
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
158
|
-
groups of layers by how they should be quantized, etc.)
|
159
|
-
Returns:q
|
160
|
-
The node after its activation quantization configuration candidates have been modified.
|
161
153
|
"""
|
162
154
|
for nqc in node.candidates_quantization_cfg:
|
163
155
|
for parameter_name, parameter_value in self.kwargs.items():
|
@@ -176,7 +168,7 @@ class ChangeFinalActivationQuantConfigAttr(BaseAction):
|
|
176
168
|
"""
|
177
169
|
self.kwargs = kwargs
|
178
170
|
|
179
|
-
def apply(self, node: BaseNode, graph
|
171
|
+
def apply(self, node: BaseNode, graph):
|
180
172
|
if node.final_activation_quantization_cfg is not None:
|
181
173
|
for parameter_name, parameter_value in self.kwargs.items():
|
182
174
|
node.final_activation_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value)
|
@@ -203,15 +195,13 @@ class ChangeQuantizationParamFunction(BaseAction):
|
|
203
195
|
self.weights_quantization_params_fn = weights_quantization_params_fn
|
204
196
|
self.attr_name = attr_name
|
205
197
|
|
206
|
-
def apply(self, node: BaseNode, graph
|
198
|
+
def apply(self, node: BaseNode, graph):
|
207
199
|
"""
|
208
200
|
Change the node's weights/activations quantization params function.
|
209
201
|
|
210
202
|
Args:
|
211
203
|
node: Node object to change its quantization params function.
|
212
204
|
graph: Graph to apply the action on.
|
213
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
214
|
-
groups of layers by how they should be quantized, etc.)
|
215
205
|
|
216
206
|
Returns:
|
217
207
|
The node after its quantization params function has been modified.
|
@@ -240,15 +230,13 @@ class ChangeFinalActivationQuantizationMethod(BaseAction):
|
|
240
230
|
|
241
231
|
self.activation_quantization_method = activation_quantization_method
|
242
232
|
|
243
|
-
def apply(self, node: BaseNode, graph
|
233
|
+
def apply(self, node: BaseNode, graph):
|
244
234
|
"""
|
245
235
|
Change the node's activations quantization function.
|
246
236
|
|
247
237
|
Args:
|
248
238
|
node: Node object to change its threshold selection function.
|
249
239
|
graph: Graph to apply the action on.
|
250
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
251
|
-
groups of layers by how they should be quantized, etc.)
|
252
240
|
|
253
241
|
Returns:
|
254
242
|
The node after its quantization function has been modified.
|
@@ -262,7 +250,7 @@ class ChangeFinalActivationQuantizationMethod(BaseAction):
|
|
262
250
|
node.final_activation_quantization_cfg.set_activation_quantization_params_fn(
|
263
251
|
activation_quantization_params_fn)
|
264
252
|
|
265
|
-
activation_quantization_fn =
|
253
|
+
activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(self.activation_quantization_method)
|
266
254
|
|
267
255
|
node.final_activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
|
268
256
|
node.final_activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
|
@@ -282,18 +270,14 @@ class ChangeCandidatesActivationQuantizationMethod(BaseAction):
|
|
282
270
|
"""
|
283
271
|
self.activation_quantization_method = activation_quantization_method
|
284
272
|
|
285
|
-
def apply(self, node: BaseNode, graph
|
273
|
+
def apply(self, node: BaseNode, graph):
|
286
274
|
"""
|
287
275
|
Change the node's activations quantization function.
|
288
276
|
|
289
277
|
Args:
|
290
278
|
node: Node object to change its threshold selection function.
|
291
279
|
graph: Graph to apply the action on.
|
292
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
293
|
-
groups of layers by how they should be quantized, etc.)
|
294
280
|
|
295
|
-
Returns:
|
296
|
-
The node after its quantization function has been modified.
|
297
281
|
"""
|
298
282
|
if self.activation_quantization_method is not None:
|
299
283
|
for qc in node.candidates_quantization_cfg:
|
@@ -301,7 +285,7 @@ class ChangeCandidatesActivationQuantizationMethod(BaseAction):
|
|
301
285
|
self.activation_quantization_method)
|
302
286
|
|
303
287
|
qc.activation_quantization_cfg.set_activation_quantization_params_fn(activation_quantization_params_fn)
|
304
|
-
activation_quantization_fn =
|
288
|
+
activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(
|
305
289
|
self.activation_quantization_method)
|
306
290
|
|
307
291
|
if activation_quantization_fn is None:
|
@@ -328,18 +312,14 @@ class ChangeFinalWeightsQuantizationMethod(BaseAction):
|
|
328
312
|
self.weights_quantization_method = weights_quantization_method
|
329
313
|
self.attr_name = attr_name
|
330
314
|
|
331
|
-
def apply(self, node: BaseNode, graph
|
315
|
+
def apply(self, node: BaseNode, graph):
|
332
316
|
"""
|
333
317
|
Change the node's weights quantization function.
|
334
318
|
|
335
319
|
Args:
|
336
320
|
node: Node object to change its threshold selection function.
|
337
321
|
graph: Graph to apply the action on.
|
338
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
339
|
-
groups of layers by how they should be quantized, etc.)
|
340
322
|
|
341
|
-
Returns:
|
342
|
-
The node after its quantization function has been modified.
|
343
323
|
"""
|
344
324
|
|
345
325
|
if self.weights_quantization_method is not None and node.final_weights_quantization_cfg is not None:
|
@@ -376,15 +356,13 @@ class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
|
|
376
356
|
self.weights_quantization_method = weights_quantization_method
|
377
357
|
self.attr_name = attr_name
|
378
358
|
|
379
|
-
def apply(self, node: BaseNode, graph: Graph
|
359
|
+
def apply(self, node: BaseNode, graph: Graph):
|
380
360
|
"""
|
381
361
|
Change the node's weights quantization function.
|
382
362
|
|
383
363
|
Args:
|
384
364
|
node: Node object to change its threshold selection function.
|
385
365
|
graph: Graph to apply the action on.
|
386
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
387
|
-
groups of layers by how they should be quantized, etc.)
|
388
366
|
|
389
367
|
Returns:
|
390
368
|
The node after its quantization function has been modified.
|
@@ -422,15 +400,13 @@ class ReplaceLayer(BaseAction):
|
|
422
400
|
self.layer_type = layer_type
|
423
401
|
self.get_params_and_weights_fn = get_params_and_weights_fn
|
424
402
|
|
425
|
-
def apply(self, node: BaseNode, graph: Graph
|
403
|
+
def apply(self, node: BaseNode, graph: Graph):
|
426
404
|
"""
|
427
405
|
Replacing node's layer type and configurations
|
428
406
|
|
429
407
|
Args:
|
430
408
|
node: Node object to replace or modify
|
431
409
|
graph: Graph to apply the action on.
|
432
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
433
|
-
groups of layers by how they should be quantized, etc.)
|
434
410
|
|
435
411
|
Returns:
|
436
412
|
The node after its layer functionality has been modified.
|
@@ -14,20 +14,17 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
from typing import List
|
16
16
|
|
17
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
18
17
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
19
18
|
from model_compression_toolkit.core.common.network_editors import EditRule
|
20
19
|
|
21
20
|
|
22
21
|
def edit_network_graph(graph: Graph,
|
23
|
-
fw_info: FrameworkInfo,
|
24
22
|
network_editor: List[EditRule]):
|
25
23
|
"""
|
26
24
|
Apply a list of edit rules on a graph.
|
27
25
|
|
28
26
|
Args:
|
29
27
|
graph: The graph to edit.
|
30
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
31
28
|
groups of layers by how they should be quantized, etc.)
|
32
29
|
network_editor: List of edit rules to apply to the graph.
|
33
30
|
|
@@ -38,5 +35,5 @@ def edit_network_graph(graph: Graph,
|
|
38
35
|
for edit_rule in network_editor:
|
39
36
|
filtered_nodes = graph.filter(edit_rule.filter)
|
40
37
|
for node in filtered_nodes:
|
41
|
-
edit_rule.action.apply(node, graph, fw_info)
|
38
|
+
edit_rule.action.apply(node, graph)
|
42
39
|
# return graph
|
@@ -26,18 +26,14 @@ class ChannelGrouping:
|
|
26
26
|
based on their importance scores and SIMD group sizes.
|
27
27
|
"""
|
28
28
|
|
29
|
-
def __init__(self,
|
30
|
-
prunable_nodes: List[BaseNode],
|
31
|
-
fw_info: FrameworkInfo):
|
29
|
+
def __init__(self, prunable_nodes: List[BaseNode]):
|
32
30
|
"""
|
33
31
|
Initializes the ChannelGrouping with necessary information.
|
34
32
|
|
35
33
|
Args:
|
36
34
|
prunable_nodes: List of nodes that can be pruned.
|
37
|
-
fw_info: Framework-specific information and utilities.
|
38
35
|
"""
|
39
36
|
self.prunable_nodes = prunable_nodes
|
40
|
-
self.fw_info = fw_info
|
41
37
|
# Store for each node a list of numpy arrays. Each numpy array represents the
|
42
38
|
# indices of the channels in an SIMD group.
|
43
39
|
self._simd_groups_indices = {}
|
@@ -38,7 +38,6 @@ class GreedyMaskCalculator:
|
|
38
38
|
"""
|
39
39
|
def __init__(self,
|
40
40
|
prunable_nodes: List[BaseNode],
|
41
|
-
fw_info: FrameworkInfo,
|
42
41
|
simd_groups_scores: Dict[BaseNode, np.ndarray],
|
43
42
|
target_resource_utilization: ResourceUtilization,
|
44
43
|
graph: Graph,
|
@@ -48,7 +47,6 @@ class GreedyMaskCalculator:
|
|
48
47
|
"""
|
49
48
|
Args:
|
50
49
|
prunable_nodes (List[BaseNode]): Nodes that are eligible for pruning.
|
51
|
-
fw_info (FrameworkInfo): Framework-specific information and utilities.
|
52
50
|
simd_groups_scores (Dict[BaseNode, np.ndarray]): Importance scores for each SIMD group in a prunable node.
|
53
51
|
target_resource_utilization (ResourceUtilization): The target resource utilization to achieve.
|
54
52
|
graph (Graph): The computational graph of the model.
|
@@ -57,7 +55,6 @@ class GreedyMaskCalculator:
|
|
57
55
|
simd_groups_indices (Dict[BaseNode, List[List[int]]]): Indices of SIMD groups in each node.
|
58
56
|
"""
|
59
57
|
self.prunable_nodes = prunable_nodes
|
60
|
-
self.fw_info = fw_info
|
61
58
|
self.target_resource_utilization = target_resource_utilization
|
62
59
|
self.graph = graph
|
63
60
|
self.fw_impl = fw_impl
|
@@ -67,14 +64,11 @@ class GreedyMaskCalculator:
|
|
67
64
|
self.simd_groups_scores = simd_groups_scores
|
68
65
|
|
69
66
|
self.oc_pruning_mask = PerSIMDGroupMask(prunable_nodes=prunable_nodes,
|
70
|
-
fw_info=fw_info,
|
71
67
|
simd_groups_indices=simd_groups_indices)
|
72
68
|
|
73
69
|
self.memory_calculator = MemoryCalculator(graph=graph,
|
74
|
-
fw_info=fw_info,
|
75
70
|
fw_impl=fw_impl)
|
76
71
|
|
77
|
-
|
78
72
|
def get_mask(self) -> Dict[BaseNode, np.ndarray]:
|
79
73
|
"""
|
80
74
|
Retrieves the current pruning mask for each prunable node.
|
@@ -38,8 +38,7 @@ class LFHImportanceMetric(BaseImportanceMetric):
|
|
38
38
|
graph: Graph,
|
39
39
|
representative_data_gen: Callable,
|
40
40
|
fw_impl: PruningFrameworkImplementation,
|
41
|
-
pruning_config: PruningConfig,
|
42
|
-
fw_info: FrameworkInfo):
|
41
|
+
pruning_config: PruningConfig):
|
43
42
|
"""
|
44
43
|
Initialize the LFHImportanceMetric instance.
|
45
44
|
|
@@ -48,13 +47,11 @@ class LFHImportanceMetric(BaseImportanceMetric):
|
|
48
47
|
representative_data_gen (Callable): Function to generate representative data.
|
49
48
|
fw_impl (PruningFrameworkImplementation): Implementation of pruning for the framework.
|
50
49
|
pruning_config (PruningConfig): Configuration for pruning.
|
51
|
-
fw_info (FrameworkInfo): Framework-specific information.
|
52
50
|
"""
|
53
51
|
self.float_graph = graph
|
54
52
|
self.representative_data_gen = representative_data_gen
|
55
53
|
self.fw_impl = fw_impl
|
56
54
|
self.pruning_config = pruning_config
|
57
|
-
self.fw_info = fw_info
|
58
55
|
|
59
56
|
# Initialize internal dictionaries for storing intermediate computations.
|
60
57
|
self._entry_node_to_hessian_score = {}
|
@@ -158,8 +155,7 @@ class LFHImportanceMetric(BaseImportanceMetric):
|
|
158
155
|
Dict[BaseNode, List[np.ndarray]]: Dictionary of entry nodes mapped to their SIMD group indices.
|
159
156
|
"""
|
160
157
|
# Initialize channel grouping utility.
|
161
|
-
channel_grouping = ChannelGrouping(prunable_nodes=list(entry_node_to_score.keys()),
|
162
|
-
fw_info=self.fw_info)
|
158
|
+
channel_grouping = ChannelGrouping(prunable_nodes=list(entry_node_to_score.keys()))
|
163
159
|
|
164
160
|
channel_grouping.group_scores_by_simd_groups(entry_node_to_score)
|
165
161
|
grouped_indices = channel_grouping.simd_groups_indices
|
@@ -249,20 +245,14 @@ class LFHImportanceMetric(BaseImportanceMetric):
|
|
249
245
|
Returns:
|
250
246
|
tuple: A tuple containing the kernel attribute, the number of output channels, and the axis of the output channels.
|
251
247
|
"""
|
252
|
-
kernel_attr = self.fw_info.get_kernel_op_attributes(entry_node.type)
|
253
|
-
# Ensure only one kernel attribute exists for the given node.
|
254
|
-
if len(kernel_attr) != 1:
|
255
|
-
Logger.critical(f"Expected a single attribute but found multiple attributes ({len(kernel_attr)}) for node {entry_node}.")
|
256
|
-
kernel_attr = kernel_attr[0]
|
257
|
-
|
258
248
|
# Retrieve and validate the axis index for the output channels.
|
259
|
-
oc_axis
|
249
|
+
oc_axis = entry_node.channel_axis.output
|
260
250
|
if oc_axis is None or int(oc_axis) != oc_axis:
|
261
251
|
Logger.critical(f"Invalid output channel axis type for node {entry_node}: expected integer but got {oc_axis}.")
|
262
252
|
|
263
253
|
# Get the number of output channels based on the kernel attribute and axis.
|
264
|
-
num_oc = entry_node.get_weights_by_keys(kernel_attr).shape[oc_axis]
|
265
|
-
return kernel_attr, num_oc, oc_axis
|
254
|
+
num_oc = entry_node.get_weights_by_keys(entry_node.kernel_attr).shape[oc_axis]
|
255
|
+
return entry_node.kernel_attr, num_oc, oc_axis
|
266
256
|
|
267
257
|
def _concatenate_tensors_by_indices(self,
|
268
258
|
channels: List[np.ndarray],
|
@@ -35,9 +35,8 @@ class MaskIndicator(Enum):
|
|
35
35
|
REMAINED = 1
|
36
36
|
|
37
37
|
|
38
|
-
|
39
38
|
class PerChannelMask:
|
40
|
-
def __init__(self, prunable_nodes: List[BaseNode], fw_info: FrameworkInfo):
|
39
|
+
def __init__(self, prunable_nodes: List[BaseNode]):
|
41
40
|
"""
|
42
41
|
Initializes the PerChannelMask with prunable nodes and framework information.
|
43
42
|
This class is responsible for maintaining and updating the pruning masks for each
|
@@ -46,10 +45,8 @@ class PerChannelMask:
|
|
46
45
|
|
47
46
|
Args:
|
48
47
|
prunable_nodes: List of nodes in the model that are subject to pruning.
|
49
|
-
fw_info: Framework-specific information required for pruning operations.
|
50
48
|
"""
|
51
49
|
self.prunable_nodes = prunable_nodes
|
52
|
-
self.fw_info = fw_info
|
53
50
|
self._mask = None # Initialize the mask dictionary
|
54
51
|
self._init_masks() # Call to initialize masks for each prunable node
|
55
52
|
|
@@ -106,8 +103,7 @@ class PerChannelMask:
|
|
106
103
|
Returns:
|
107
104
|
int: Number of output channels for the node.
|
108
105
|
"""
|
109
|
-
|
110
|
-
|
111
|
-
num_oc = node.get_weights_by_keys(kernel_attr).shape[oc_axis]
|
106
|
+
oc_axis = node.channel_axis.output
|
107
|
+
num_oc = node.get_weights_by_keys(node.kernel_attr).shape[oc_axis]
|
112
108
|
return num_oc
|
113
109
|
|
@@ -24,10 +24,10 @@ from model_compression_toolkit.core.common.pruning.memory_calculator import Memo
|
|
24
24
|
from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation
|
25
25
|
from model_compression_toolkit.logger import Logger
|
26
26
|
|
27
|
+
|
27
28
|
class PerSIMDGroupMask:
|
28
29
|
def __init__(self,
|
29
30
|
prunable_nodes: List[BaseNode],
|
30
|
-
fw_info: FrameworkInfo,
|
31
31
|
simd_groups_indices: Dict[BaseNode, List[List[int]]]):
|
32
32
|
"""
|
33
33
|
Initializes a mask calculator for SIMD groups in prunable nodes.
|
@@ -35,13 +35,11 @@ class PerSIMDGroupMask:
|
|
35
35
|
|
36
36
|
Args:
|
37
37
|
prunable_nodes: List of nodes that can be pruned.
|
38
|
-
fw_info: Framework-specific information.
|
39
38
|
simd_groups_indices: A dictionary mapping each node to its SIMD groups' indices.
|
40
39
|
"""
|
41
40
|
# Initialize the per-channel mask
|
42
|
-
self.per_channel_mask = PerChannelMask(prunable_nodes=prunable_nodes, fw_info=fw_info)
|
41
|
+
self.per_channel_mask = PerChannelMask(prunable_nodes=prunable_nodes)
|
43
42
|
self.prunable_nodes = prunable_nodes
|
44
|
-
self.fw_info = fw_info
|
45
43
|
self.simd_groups_indices = simd_groups_indices
|
46
44
|
self._mask_simd = None # Initialize the SIMD group mask dictionary
|
47
45
|
self._init_masks() # Initialize masks for each prunable node
|