mct-nightly 2.4.0.20250629.706__py3-none-any.whl → 2.4.0.20250701.185106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/METADATA +16 -16
- {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/RECORD +75 -72
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
- model_compression_toolkit/core/common/framework_info.py +5 -32
- model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
- model_compression_toolkit/core/common/graph/base_graph.py +20 -37
- model_compression_toolkit/core/common/graph/base_node.py +13 -106
- model_compression_toolkit/core/common/graph/functional_node.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
- model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -96
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
- model_compression_toolkit/core/graph_prep_runner.py +31 -20
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
- model_compression_toolkit/core/keras/default_framework_info.py +0 -11
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
- model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/runner.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
- model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
- model_compression_toolkit/quantization_preparation/__init__.py +14 -0
- model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
- {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/fusion/graph_fuser.py
CHANGED

@@ -14,12 +14,12 @@
 # ==============================================================================
 
 import copy
-from typing import
+from typing import Tuple
 
 from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
 from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
-from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import
-
+from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
+    CandidateNodeQuantizationConfig, NodeQuantizationConfig
 
 
 class FusedLayerType:
@@ -30,6 +30,7 @@ class FusedLayerType:
     def __init__(self):
         self.__name__ = 'FusedLayer'
 
+
 class GraphFuser:
     def apply_node_fusion(self, graph: Graph) -> Graph:
         """
@@ -64,7 +65,6 @@ class GraphFuser:
 
         return graph_copy
 
-
    @staticmethod
    def _create_fused_node(fused_node_id: str, nodes: Tuple[BaseNode]) -> BaseNode:
        """
@@ -86,10 +86,15 @@ class GraphFuser:
                              weights={},
                              layer_class=FusedLayerType)
 
+        base_cfg = CandidateNodeQuantizationConfig(
+            activation_quantization_cfg=nodes[-1].quantization_cfg.base_quantization_cfg.activation_quantization_cfg,
+            weights_quantization_cfg=None
+        )
         activation_cfgs = [c.activation_quantization_cfg for c in nodes[-1].candidates_quantization_cfg]
-
-
-
+        candidates = [CandidateNodeQuantizationConfig(weights_quantization_cfg=None, activation_quantization_cfg=a)
+                      for a in activation_cfgs]
+        fused_node.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=base_cfg,
+                                                             candidates_quantization_cfg=candidates)
 
         # Keep the final configurations if they were set already.
         fused_node.final_weights_quantization_cfg = nodes[0].final_weights_quantization_cfg
@@ -158,5 +163,3 @@ class GraphFuser:
 
         # Finally, add the new fused node to the graph
         graph.add_node(fused_node)
-
-
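The net effect of the hunks above is that a fused node now carries a single `NodeQuantizationConfig`, a base config plus a candidate list, derived from the last node in the fused group. Below is a minimal runnable sketch of that construction; `ActivationCfg`, `Candidate`, and `NodeQuantizationCfg` are simplified stand-ins for MCT's classes, not the toolkit's real definitions.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ActivationCfg:
    activation_n_bits: int

@dataclass
class Candidate:
    activation_quantization_cfg: Optional[ActivationCfg]
    weights_quantization_cfg: Optional[object] = None

@dataclass
class NodeQuantizationCfg:
    base_quantization_cfg: Optional[Candidate]
    candidates_quantization_cfg: List[Candidate]

def build_fused_node_cfg(last_node_cfg: NodeQuantizationCfg) -> NodeQuantizationCfg:
    """Mirror of the fuser logic in the hunk: the fused node inherits the last
    inner node's activation configs and carries no weight quantization itself."""
    base = Candidate(
        activation_quantization_cfg=last_node_cfg.base_quantization_cfg.activation_quantization_cfg,
        weights_quantization_cfg=None,
    )
    candidates = [Candidate(activation_quantization_cfg=c.activation_quantization_cfg,
                            weights_quantization_cfg=None)
                  for c in last_node_cfg.candidates_quantization_cfg]
    return NodeQuantizationCfg(base_quantization_cfg=base,
                               candidates_quantization_cfg=candidates)

# Usage: an 8-bit base config with 8/4-bit activation candidates.
last = NodeQuantizationCfg(Candidate(ActivationCfg(8)),
                           [Candidate(ActivationCfg(8)), Candidate(ActivationCfg(4))])
fused = build_fused_node_cfg(last)
assert [c.activation_quantization_cfg.activation_n_bits
        for c in fused.candidates_quantization_cfg] == [8, 4]
```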
model_compression_toolkit/core/common/graph/base_graph.py
CHANGED

@@ -39,6 +39,7 @@
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
     FrameworkQuantizationCapabilities
 
+
 def validate_graph_after_change(method: Callable) -> Callable:
     """
     Decorator for graph-mutating methods. After the decorated method executes,
@@ -120,28 +121,13 @@ class Graph(nx.MultiDiGraph, GraphSearches):
     def fusing_info(self, fusing_info: FusingInfo):
         self._fusing_info = fusing_info
 
-    def set_fqc(self,
-                fqc: FrameworkQuantizationCapabilities):
+    def set_fqc(self, fqc: FrameworkQuantizationCapabilities):
         """
         Set the graph's FQC.
         Args:
             fqc: FrameworkQuantizationCapabilities object.
         """
-        #
-        # Validate graph nodes are either built-in layers from the framework or custom layers defined in the FQC
-        fqc_layers = fqc.op_sets_to_layers.get_layers()
-        fqc_filtered_layers = [layer for layer in fqc_layers if isinstance(layer, LayerFilterParams)]
-        for n in self.nodes:
-            is_node_in_fqc = any([n.is_match_type(_type) for _type in fqc_layers]) or \
-                             any([n.is_match_filter_params(filtered_layer) for filtered_layer in fqc_filtered_layers])
-            if n.is_custom:
-                if not is_node_in_fqc:
-                    Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
-                                    ' Please add the custom layer to Framework Quantization Capabilities (FQC), or file a feature '
-                                    'request or an issue if you believe this should be supported.')  # pragma: no cover
-                if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(fqc).quantization_configurations]):
-                    Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.')  # pragma: no cover
-
+        # TODO irena: this is only passed for negative shift activation.
         self.fqc = fqc
 
     def get_topo_sorted_nodes(self):
@@ -578,7 +564,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             A list of nodes that their weights can be configured (namely, has one or more weight qc candidate).
         """
         # configurability is only relevant for kernel attribute quantization
-        potential_conf_nodes = [n for n in
+        potential_conf_nodes = [n for n in self.nodes if n.kernel_attr]
 
         def is_configurable(n):
             return n.is_configurable_weight(n.kernel_attr) and (not n.reuse or include_reused_nodes)
@@ -693,10 +679,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         """
         Gets the final number of bits for quantization of each weights' configurable layer.
 
-
-
-
-        Returns: A list of pairs of (node type, node's weights quantization bitwidth).
+        Returns:
+            A list of pairs of (node type, node's weights quantization bitwidth).
 
         """
         sorted_conf_weights = self.get_sorted_weights_configurable_nodes()
@@ -876,32 +860,31 @@ class Graph(nx.MultiDiGraph, GraphSearches):
 
         return intermediate_nodes, next_node
 
+    # TODO irena move to load_fqc and clean up tests (currently tests_pytest/common_tests/unit_tests/core/graph/test_base_graph.py)
     def override_fused_node_activation_quantization_candidates(self):
         """
         Override fused node activation quantization candidates for all nodes in fused operations,
         except for the last node in each fused group.
         Update the value of quantization_config with the value of op_quaitization_cfg from FusingInfo.
         """
-        from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
-
         nodes_in_fln = self.fusing_info.get_inner_fln_nodes()
         for node in nodes_in_fln:
             fused_node_op_id = self.fusing_info.get_fused_op_id_for_node(node.name)
-
-
-
-
-
-
-                activation_quantization_fn=org_candidate.activation_quantization_cfg.activation_quantization_fn,
-                activation_quantization_params_fn=org_candidate.activation_quantization_cfg.activation_quantization_params_fn)
-            activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
-            for qc in node.candidates_quantization_cfg:
-                qc.activation_quantization_cfg = activation_quantization_cfg
+            fusing_op_quantization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
+            if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization:
+                def update(qc):
+                    qc.activation_quantization_cfg = NodeActivationQuantizationConfig(fusing_op_quantization_cfg)
+                    qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
+                node.quantization_cfg.update_all(update, remove_duplicates=True)
             else:
-
+                node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.FLN_NO_QUANT)
+                # Remove duplicate candidates. We cannot compare whole candidates since activation configs might not
+                # be identical, but we do want to treat them as such. So we only check duplication by weight configs.
+                uniq_qcs = []
                for qc in node.candidates_quantization_cfg:
-                    qc.
+                    if not any(qc.weights_quantization_cfg == uqc.weights_quantization_cfg for uqc in uniq_qcs):
+                        uniq_qcs.append(qc)
+                node.quantization_cfg.candidates_quantization_cfg = uniq_qcs
 
     def validate(self):
         """
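The rewritten `override_fused_node_activation_quantization_candidates` deduplicates candidates by weight config only, since the activation configs inside a fused group were just forced to a common FLN mode and no longer distinguish candidates. A small self-contained sketch of that deduplication, with a stand-in `Candidate` class rather than MCT's:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Candidate:
    weights_quantization_cfg: Optional[tuple]  # stand-in, e.g. (n_bits, per_channel)
    activation_quantization_cfg: object = None

def dedup_by_weights(candidates: List[Candidate]) -> List[Candidate]:
    """Keep the first candidate per distinct weights config, mirroring the
    loop in the hunk above."""
    uniq: List[Candidate] = []
    for qc in candidates:
        if not any(qc.weights_quantization_cfg == uqc.weights_quantization_cfg for uqc in uniq):
            uniq.append(qc)
    return uniq

cands = [Candidate((8, True)), Candidate((8, True)), Candidate((4, True))]
assert len(dedup_by_weights(cands)) == 2  # the duplicate 8-bit weight config is dropped
```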
model_compression_toolkit/core/common/graph/base_node.py
CHANGED

@@ -21,15 +21,11 @@ import numpy as np
 from model_compression_toolkit.core.common.framework_info import get_fw_info, ChannelAxisMapping
 from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \
     ACTIVATION_N_BITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
+from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import NodeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
     ActivationQuantizationMode
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
-    OpQuantizationConfig
-from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
-from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
-    FrameworkQuantizationCapabilities
 
 
 WeightAttrT = Union[str, int]
@@ -43,7 +39,6 @@ class NodeFrameworkInfo(NamedTuple):
     out_channel_axis: int
     minmax: Tuple[float, float]
     kernel_attr: str
-    is_kernel_op: bool
 
 
 class BaseNode:
@@ -95,7 +90,7 @@ class BaseNode:
         self.inputs_as_list = inputs_as_list
         self.final_weights_quantization_cfg = None
         self.final_activation_quantization_cfg = None
-        self.
+        self.quantization_cfg: NodeQuantizationConfig = None
         self.prior_info = None
         self.has_activation = has_activation
         self.is_custom = is_custom
@@ -108,7 +103,6 @@ class BaseNode:
             fw_info.get_out_channel_axis(node_type),
             fw_info.get_layer_min_max(node_type, framework_attr),
             fw_info.get_kernel_op_attribute(node_type),
-            fw_info.is_kernel_op(node_type)
         )
 
     def _assert_fw_info_exists(self):
@@ -162,15 +156,9 @@ class BaseNode:
         return self.node_fw_info.kernel_attr
 
     @property
-    def
-
-
-
-        Returns:
-            Whether the node has a kernel or not.
-        """
-        self._assert_fw_info_exists()
-        return self.node_fw_info.is_kernel_op
+    def candidates_quantization_cfg(self):
+        assert self.quantization_cfg
+        return self.quantization_cfg.candidates_quantization_cfg
 
     @property
     def type(self):
@@ -181,15 +169,6 @@ class BaseNode:
         """
         return self.layer_class
 
-    def get_has_activation(self):
-        """
-        Returns has_activation attribute.
-
-        Returns: Whether the node has activation to quantize.
-
-        """
-        return self.has_activation
-
     @property
     def has_positional_weights(self):
         """
@@ -646,8 +625,9 @@ class BaseNode:
         Returns: True if the node has at list one quantization configuration candidate with activation quantization enabled.
         """
 
-        return len(self.candidates_quantization_cfg) > 0 and
-
+        return (len(self.candidates_quantization_cfg) > 0 and
+                any([c.activation_quantization_cfg.enable_activation_quantization
+                     for c in self.candidates_quantization_cfg]))
 
     def get_all_weights_attr_candidates(self, attr: str) -> List[WeightsAttrQuantizationConfig]:
         """
@@ -663,79 +643,6 @@ class BaseNode:
         # the inner method would log an exception.
         return [c.weights_quantization_cfg.get_attr_config(attr) for c in self.candidates_quantization_cfg]
 
-    def get_qco(self, fqc: FrameworkQuantizationCapabilities) -> QuantizationConfigOptions:
-        """
-        Get the QuantizationConfigOptions of the node according
-        to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformCapabilities.
-
-        Args:
-            fqc: FQC to extract the QuantizationConfigOptions for the node.
-
-        Returns:
-            QuantizationConfigOptions of the node.
-        """
-
-        if fqc is None:
-            Logger.critical(f'Can not retrieve QC options for None FQC')  # pragma: no cover
-
-        for fl, qco in fqc.filterlayer2qco.items():
-            if self.is_match_filter_params(fl):
-                return qco
-        # Extract qco with is_match_type to overcome mismatch of function types in TF 2.15
-        matching_qcos = [_qco for _type, _qco in fqc.layer2qco.items() if self.is_match_type(_type)]
-        if matching_qcos:
-            if all([_qco == matching_qcos[0] for _qco in matching_qcos]):
-                return matching_qcos[0]
-            else:
-                Logger.critical(f"Found duplicate qco types for node '{self.name}' of type '{self.type}'!")  # pragma: no cover
-        return fqc.tpc.default_qco
-
-    def filter_node_qco_by_graph(self, fqc: FrameworkQuantizationCapabilities,
-                                 next_nodes: List, node_qc_options: QuantizationConfigOptions
-                                 ) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
-        """
-        Filter quantization config options that don't match the graph.
-        A node may have several quantization config options with 'activation_n_bits' values, and
-        the next nodes in the graph may support different bit-width as input activation. This function
-        filters out quantization config that don't comply to these attributes.
-
-        Args:
-            fqc: FQC to extract the QuantizationConfigOptions for the next nodes.
-            next_nodes: Output nodes of current node.
-            node_qc_options: Node's QuantizationConfigOptions.
-
-        Returns:
-
-        """
-        # Filter quantization config options that don't match the graph.
-        _base_config = node_qc_options.base_config
-        _node_qc_options = node_qc_options.quantization_configurations
-        if len(next_nodes):
-            next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
-            next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
-                                                       for qc_opts in next_nodes_qc_options
-                                                       for op_cfg in qc_opts.quantization_configurations])
-
-            # Filter node's QC options that match next nodes input bit-width.
-            _node_qc_options = [_option for _option in _node_qc_options
-                                if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
-            if len(_node_qc_options) == 0:
-                Logger.critical(f"Graph doesn't match FQC bit configurations: {self} -> {next_nodes}.")  # pragma: no cover
-
-            # Verify base config match
-            if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
-                    for qc_opt in next_nodes_qc_options]):
-                # base_config activation bits doesn't match next node supported input bit-width -> replace with
-                # a qco from quantization_configurations with maximum activation bit-width.
-                if len(_node_qc_options) > 0:
-                    output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
-                    _base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
-                    Logger.warning(f"Node {self} base quantization config changed to match Graph and FQC configuration.\nCause: {self} -> {next_nodes}.")
-                else:
-                    Logger.critical(f"Graph doesn't match FQC bit configurations: {self} -> {next_nodes}.")  # pragma: no cover
-
-        return _base_config, _node_qc_options
-
     def is_match_type(self, _type: Type) -> bool:
         """
         Check if input type matches the node type, either in instance type or in type name.
@@ -768,7 +675,7 @@ class BaseNode:
         return False
 
         # Get attributes from node to filter
-        layer_config = self.framework_attr
+        layer_config = self.framework_attr.copy()
         if hasattr(self, "op_call_kwargs"):
             layer_config.update(self.op_call_kwargs)
 
@@ -812,11 +719,11 @@ class BaseNode:
         the candidates in descending order.
         The operation is done inplace.
         """
-        if self.candidates_quantization_cfg is not None:
+        if self.quantization_cfg.candidates_quantization_cfg is not None:
             if self.kernel_attr is not None:
-                self.candidates_quantization_cfg.sort(
+                self.quantization_cfg.candidates_quantization_cfg.sort(
                     key=lambda c: (c.weights_quantization_cfg.get_attr_config(self.kernel_attr).weights_n_bits,
                                    c.activation_quantization_cfg.activation_n_bits), reverse=True)
             else:
-                self.candidates_quantization_cfg.sort(
-
+                self.quantization_cfg.candidates_quantization_cfg.sort(
+                    key=lambda c: c.activation_quantization_cfg.activation_n_bits, reverse=True)
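With `candidates_quantization_cfg` demoted to a read-only property over `quantization_cfg`, call sites that mutate the candidate list (like the sort above) must go through the underlying object. A toy sketch of the delegation pattern, with simplified classes and assumed behavior:

```python
class NodeQuantizationCfg:
    """Stand-in for MCT's NodeQuantizationConfig: owns the candidate list."""
    def __init__(self, candidates):
        self.candidates_quantization_cfg = candidates

class Node:
    def __init__(self, candidates):
        self.quantization_cfg = NodeQuantizationCfg(candidates)

    @property
    def candidates_quantization_cfg(self):
        # Read-only view; writes must target self.quantization_cfg, which is
        # why the sort in the diff addresses the underlying list directly.
        assert self.quantization_cfg
        return self.quantization_cfg.candidates_quantization_cfg

n = Node(candidates=[3, 1, 2])
n.quantization_cfg.candidates_quantization_cfg.sort(reverse=True)
assert n.candidates_quantization_cfg == [3, 2, 1]  # in-place sort visible via the property
```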
model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py
CHANGED

@@ -19,9 +19,8 @@ from model_compression_toolkit.constants import VIRTUAL_ACTIVATION_WEIGHTS_NODE_
     VIRTUAL_WEIGHTS_SUFFIX, VIRTUAL_ACTIVATION_SUFFIX, FLOAT_BITWIDTH
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
-    CandidateNodeQuantizationConfig
+    CandidateNodeQuantizationConfig, NodeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
-from model_compression_toolkit.core.common.framework_info import DEFAULT_KERNEL_ATTRIBUTE
 
 
 class VirtualNode(BaseNode, abc.ABC):
@@ -76,8 +75,11 @@ class VirtualSplitWeightsNode(VirtualSplitNode):
 
         self.name = origin_node.name + VIRTUAL_WEIGHTS_SUFFIX
 
-        self.
-
+        self.quantization_cfg = NodeQuantizationConfig(
+            candidates_quantization_cfg=origin_node.get_unique_weights_candidates(kernel_attr),
+            base_quantization_cfg=None, validate=False
+        )
+        for c in self.quantization_cfg.candidates_quantization_cfg:
             c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
             c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
 
@@ -106,10 +108,9 @@ class VirtualSplitActivationNode(VirtualSplitNode):
         self.weights = {}
         self.layer_class = activation_class
 
-        self.
-
-
-            c.weights_quantization_cfg.weights_n_bits = FLOAT_BITWIDTH
+        self.quantization_cfg = NodeQuantizationConfig(candidates_quantization_cfg=origin_node.get_unique_activation_candidates(),
+                                                       base_quantization_cfg=None, validate=False)
+        self.quantization_cfg.disable_weights_quantization()
 
 
 class VirtualActivationWeightsNode(VirtualNode):
@@ -143,7 +144,7 @@ class VirtualActivationWeightsNode(VirtualNode):
         weights = weights_node.weights.copy()
         act_node_w_rename = {}
         if act_node.weights:
-            if act_node.kernel_attr
+            if act_node.kernel_attr:
                 raise NotImplementedError(f'Node {act_node} with kernel cannot be used as activation for '
                                           f'VirtualActivationWeightsNode.')
             if act_node.has_any_configurable_weight():
@@ -200,4 +201,5 @@ class VirtualActivationWeightsNode(VirtualNode):
         v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(weights_node.kernel_attr).weights_n_bits,
                                          c.activation_quantization_cfg.activation_n_bits), reverse=True)
 
-        self.
+        self.quantization_cfg = NodeQuantizationConfig(candidates_quantization_cfg=v_candidates,
+                                                       base_quantization_cfg=None, validate=False)
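The virtual split nodes now wrap their candidate lists in `NodeQuantizationConfig` with `validate=False`, taking unique weight (or activation) candidates from the origin node and disabling the other side. A hedged sketch of what "unique weight candidates with activation disabled" plausibly means here; `unique_weights_candidates` is a stand-in for `origin_node.get_unique_weights_candidates`, and the 32-bit width is an assumption mirroring `FLOAT_BITWIDTH`:

```python
from dataclasses import dataclass, replace
from typing import List

@dataclass
class ActivationCfg:
    quant_mode: str = 'QUANT'
    activation_n_bits: int = 8

@dataclass
class Candidate:
    weights_n_bits: int
    activation_quantization_cfg: ActivationCfg

def unique_weights_candidates(candidates: List[Candidate]) -> List[Candidate]:
    """One candidate per distinct weights config, with activation quantization
    turned off, as the VirtualSplitWeightsNode constructor does in the diff."""
    seen, out = set(), []
    for c in candidates:
        if c.weights_n_bits not in seen:
            seen.add(c.weights_n_bits)
            out.append(replace(c, activation_quantization_cfg=ActivationCfg(
                quant_mode='NO_QUANT', activation_n_bits=32)))  # assumed FLOAT_BITWIDTH
    return out

cands = [Candidate(8, ActivationCfg()),
         Candidate(8, ActivationCfg(activation_n_bits=4)),
         Candidate(4, ActivationCfg())]
split = unique_weights_candidates(cands)
assert [c.weights_n_bits for c in split] == [8, 4]
assert all(c.activation_quantization_cfg.quant_mode == 'NO_QUANT' for c in split)
```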
model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py
CHANGED

@@ -18,6 +18,8 @@ import numpy as np
 
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
+from model_compression_toolkit.core.common.quantization.quantization_fn_selection import (get_activation_quantization_fn,
+                                                                                          get_weights_quantization_fn)
 
 
 def verify_candidates_descending_order(node_q_cfg: List[CandidateNodeQuantizationConfig],
@@ -77,20 +79,21 @@ def init_quantized_weights(node_q_cfg: List[CandidateNodeQuantizationConfig],
     quantized_weights = []
     for qc in node_q_cfg:
         qc_weights_attr = qc.weights_quantization_cfg.get_attr_config(kernel_attr)
-
-
-
-
-
-
-
+        weights_quantization_fn = get_weights_quantization_fn(qc_weights_attr.weights_quantization_method)
+        q_weight = weights_quantization_fn(float_weights,
+                                           qc_weights_attr.weights_n_bits,
+                                           True,
+                                           qc_weights_attr.weights_quantization_params,
+                                           qc_weights_attr.weights_per_channel_threshold,
+                                           qc_weights_attr.weights_channels_axis[0])  # output channel axis
 
         quantized_weights.append(fw_tensor_convert_func(q_weight))
 
     return quantized_weights
 
 
-def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]
+def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig],
+                               get_activation_quantization_fn_factory: Callable) -> List:
     """
     Builds a list of quantizers for each of the bitwidth candidates for activation quantization,
     to be stored and used during MP search.
@@ -98,6 +101,7 @@ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]
     Args:
         node_q_cfg: Quantization configuration candidates of the node that generated the layer that will
             use this quantizer.
+        get_activation_quantization_fn_factory: activation quantization functions factory.
 
     Returns: a list of activation quantizers - for each bitwidth and layer's attribute to be quantized.
     """
@@ -105,6 +109,7 @@ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]
     activation_quantizers = []
     for index, qc in enumerate(node_q_cfg):
         q_activation = node_q_cfg[index].activation_quantization_cfg
-
+        quantizer = get_activation_quantization_fn(q_activation, get_activation_quantization_fn_factory)
+        activation_quantizers.append(quantizer)
 
     return activation_quantizers
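`init_activation_quantizers` now receives the framework-specific factory as an argument instead of resolving quantizers itself, which is what lets the common code drop direct Keras/PyTorch imports (note the new `activation_quantization_fn_factory.py` files in the file list above). A toy illustration of this inversion-of-control wiring; `fake_factory` and the quantizer math are invented for the example, not MCT's API:

```python
from typing import Callable, List

def fake_factory(method: str) -> Callable:
    """Pretend framework factory: maps a quantization method name to a builder."""
    def build(n_bits: int, params: dict) -> Callable:
        scale = params.get('threshold', 1.0) / (2 ** (n_bits - 1))
        return lambda x: round(x / scale) * scale  # toy symmetric fake-quant
    return build

def init_activation_quantizers(n_bits_candidates: List[int], factory: Callable) -> List[Callable]:
    # The common code only knows the factory's shape, not the framework behind it.
    return [factory('POWER_OF_TWO')(n, {'threshold': 8.0}) for n in n_bits_candidates]

quantizers = init_activation_quantizers([8, 4, 2], fake_factory)
print([q(3.3) for q in quantizers])  # one quantizer per bit-width candidate
```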
model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py
CHANGED

@@ -12,17 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
-from model_compression_toolkit.core import ResourceUtilization, FrameworkInfo
+from model_compression_toolkit.core import ResourceUtilization
 from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
-    FrameworkQuantizationCapabilities
 
 
 def filter_candidates_for_mixed_precision(graph: Graph,
-                                          target_resource_utilization: ResourceUtilization
-                                          fqc: FrameworkQuantizationCapabilities):
+                                          target_resource_utilization: ResourceUtilization):
     """
     Filters out candidates in case of mixed precision search for only weights or activation compression.
     For instance, if running only weights compression - filters out candidates of activation configurable nodes
@@ -34,8 +29,6 @@ def filter_candidates_for_mixed_precision(graph: Graph,
     Args:
         graph: A graph representation of the model to be quantized.
         target_resource_utilization: The resource utilization of the target device.
-        fqc: FrameworkQuantizationCapabilities object that describes the desired inference target platform.
-
     """
 
     tru = target_resource_utilization
@@ -47,20 +40,21 @@ def filter_candidates_for_mixed_precision(graph: Graph,
         # filter out candidates activation only configurable node
         activation_configurable_nodes = [n for n in graph.get_activation_configurable_nodes()]
         for n in activation_configurable_nodes:
-            base_cfg_nbits = n.
-
+            base_cfg_nbits = n.quantization_cfg.base_quantization_cfg.activation_quantization_cfg.activation_n_bits
+            filtered_cfgs = [c for c in n.candidates_quantization_cfg if
                              c.activation_quantization_cfg.enable_activation_quantization and
                              c.activation_quantization_cfg.activation_n_bits == base_cfg_nbits]
 
-            n.candidates_quantization_cfg =
+            n.quantization_cfg.candidates_quantization_cfg = filtered_cfgs
 
     elif tru.activation_restricted() and not tru.weight_restricted():
         # Running mixed precision for activation compression only -
         # filter out candidates weights only configurable node
         weight_configurable_nodes = [n for n in graph.get_weights_configurable_nodes()]
         for n in weight_configurable_nodes:
-            base_cfg_nbits = n.
-
+            base_cfg_nbits = (n.quantization_cfg.base_quantization_cfg.weights_quantization_cfg.
+                              get_attr_config(n.kernel_attr).weights_n_bits)
+            filtered_cfgs = [c for c in n.candidates_quantization_cfg if
                              c.weights_quantization_cfg.get_attr_config(n.kernel_attr).enable_weights_quantization and
                              c.weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_n_bits == base_cfg_nbits]
-            n.candidates_quantization_cfg =
+            n.quantization_cfg.candidates_quantization_cfg = filtered_cfgs
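The filter pins activation-configurable nodes to the base config's activation bit-width when only weights are being compressed, and does the symmetric thing for weight-configurable nodes when only activations are restricted. A compact runnable sketch of the activation branch, with a stand-in `Candidate`:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Candidate:
    activation_n_bits: int
    enable_activation_quantization: bool = True

def filter_to_base_nbits(candidates: List[Candidate], base_nbits: int) -> List[Candidate]:
    """Keep only candidates whose activation bit-width equals the base config's,
    mirroring the diff's filtering logic for weights-only compression."""
    return [c for c in candidates
            if c.enable_activation_quantization and c.activation_n_bits == base_nbits]

cands = [Candidate(8), Candidate(4), Candidate(2)]
assert [c.activation_n_bits for c in filter_to_base_nbits(cands, base_nbits=8)] == [8]
```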
model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py
CHANGED

@@ -392,9 +392,8 @@ class DistanceMetricCalculator(MetricCalculator):
         """
 
         return [n.node for n in graph.get_outputs()
-                if (n.node.
-
-                    n.node.is_activation_quantization_enabled()]
+                if (n.node.kernel_attr and n.node.is_weights_quantization_enabled(n.node.kernel_attr))
+                or n.node.is_activation_quantization_enabled()]
 
     @staticmethod
     def bound_num_interest_points(sorted_ip_list: List[BaseNode], num_ip_factor: float) -> List[BaseNode]:
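The interest-point predicate for output nodes is now explicit: an output qualifies if its kernel weights are quantized or its activation is. A minimal sketch of the same condition; the stub node is illustrative:

```python
from types import SimpleNamespace

def is_interest_point(node) -> bool:
    # Mirrors the rewritten list-comprehension condition from the hunk above.
    return ((node.kernel_attr is not None and node.is_weights_quantization_enabled(node.kernel_attr))
            or node.is_activation_quantization_enabled())

n = SimpleNamespace(kernel_attr=None,
                    is_weights_quantization_enabled=lambda attr: False,
                    is_activation_quantization_enabled=lambda: True)
assert is_interest_point(n)  # activation quantization alone is enough
```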
model_compression_toolkit/core/common/network_editors/__init__.py
CHANGED

@@ -13,7 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
-from model_compression_toolkit.core.common.network_editors.actions import
+from model_compression_toolkit.core.common.network_editors.actions import (
+    ChangeCandidatesWeightsQuantConfigAttr,
+    ChangeFinalWeightsQuantConfigAttr,
+    ChangeCandidatesActivationQuantConfigAttr,
+    ChangeCandidatesActivationQuantizationMethod,
+    ChangeFinalWeightsQuantizationMethod,
+    ChangeCandidatesWeightsQuantizationMethod,
+    ChangeFinalActivationQuantConfigAttr)
 from model_compression_toolkit.core.common.network_editors.actions import EditRule
 from model_compression_toolkit.core.common.network_editors.node_filters import NodeTypeFilter, NodeNameScopeFilter, \
     NodeNameFilter