mct-nightly 2.4.0.20250630.629__py3-none-any.whl → 2.4.0.20250702.605__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/METADATA +16 -16
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/RECORD +75 -72
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
- model_compression_toolkit/core/common/framework_info.py +5 -32
- model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
- model_compression_toolkit/core/common/graph/base_graph.py +20 -37
- model_compression_toolkit/core/common/graph/base_node.py +13 -106
- model_compression_toolkit/core/common/graph/functional_node.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
- model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -96
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
- model_compression_toolkit/core/graph_prep_runner.py +31 -20
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
- model_compression_toolkit/core/keras/default_framework_info.py +0 -11
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
- model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/runner.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
- model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
- model_compression_toolkit/quantization_preparation/__init__.py +14 -0
- model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250702.605.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
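Two structural changes stand out in the list above: the framework-specific quantizer packages were renamed to quantization, and a new quantization_preparation package (with load_fqc.py, +223 lines) was introduced. Any downstream code that imported the moved internal modules directly would need a path update; a hedged compatibility sketch (these are internal, unsupported import paths, shown only to illustrate the rename):

# Sketch only: the fake-quant builders moved from `quantizer` to `quantization`
# in this release (see the renames above). A shim for code that must run
# against both nightlies could look like this.
try:
    from model_compression_toolkit.core.keras.quantization import fake_quant_builder
except ImportError:  # nightlies up to 2.4.0.20250630.629 used the old package name
    from model_compression_toolkit.core.keras.quantizer import fake_quant_builder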
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
CHANGED

@@ -12,73 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import copy
-from typing import List, Tuple, Dict, Optional
+from typing import List, Tuple, Dict
 
-from mct_quantizers.common.constants import WEIGHTS_N_BITS, ACTIVATION_N_BITS
-from model_compression_toolkit.constants import WEIGHTS, ACTIVATION
 from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
+from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.
-from model_compression_toolkit.core.common.graph.base_graph import Graph
-from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
-    CandidateNodeQuantizationConfig
-from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig, \
-    ActivationQuantizationMode
-from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
-    QuantizationErrorMethod
-from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
-    get_activation_quantization_params_fn
-from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
+from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
     QuantizationConfigOptions
+from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
     FrameworkQuantizationCapabilities
-from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
-
-
-def set_quantization_configuration_to_graph(graph: Graph,
-                                            quant_config: QuantizationConfig,
-                                            bit_width_config: BitWidthConfig = None,
-                                            mixed_precision_enable: bool = False,
-                                            running_gptq: bool = False) -> Graph:
-    """
-    Add quantization configuration for each graph node.
-
-    Args:
-        graph (Graph): Graph for which to add quantization info to each node.
-        quant_config (QuantizationConfig): Quantization configuration containing parameters for how the graph should be quantized.
-        bit_width_config (BitWidthConfig): Configuration for manual bit width selection. Defaults to None.
-        mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
-        running_gptq (bool): Whether or not a GPTQ optimization is planned to run after the PTQ process. Defaults to False.
-
-    Returns:
-        Graph: The graph with quantization configurations attached to each node in it.
-    """
-
-    if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
-        if not running_gptq:
-            raise ValueError(f"The HMSE error method for parameters selection is only supported when running GPTQ "
-                             f"optimization due to long execution time that is not suitable for basic PTQ.")
-        Logger.warning("Using the HMSE error method for weights quantization parameters search. "
-                       "Note: This method may significantly increase runtime during the parameter search process.")
-
-    nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
-    nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_weights_bit_widths(graph)
-
-    for n in graph.get_topo_sorted_nodes():
-        manual_bit_width_override = {ACTIVATION: nodes_to_manipulate_activation_bit_widths.get(n),
-                                     WEIGHTS: nodes_to_manipulate_weights_bit_widths.get(n)}
-        set_quantization_configs_to_node(node=n,
-                                         graph=graph,
-                                         quant_config=quant_config,
-                                         fqc=graph.fqc,
-                                         mixed_precision_enable=mixed_precision_enable,
-                                         manual_bit_width_override=manual_bit_width_override)
-    return graph
 
 
+# TODO irena refactor (if needed) and move to load_fqc
 def filter_node_qco_by_graph(node: BaseNode,
                              fqc: FrameworkQuantizationCapabilities,
                              graph: Graph,
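The removed set_quantization_configuration_to_graph above was the old per-node configuration entry point; per the file list, the new quantization_preparation/load_fqc.py now carries this flow. One behavior worth noting from the removed code is the HMSE guard, restated here as a standalone sketch (the enum below is a toy stand-in; MCT's real QuantizationErrorMethod lives in quantization_config):

from enum import Enum

class ErrorMethod(Enum):  # stand-in for QuantizationErrorMethod
    MSE = 'mse'
    HMSE = 'hmse'

def check_hmse_allowed(error_method: ErrorMethod, running_gptq: bool) -> None:
    # Same rule as the removed code: HMSE parameter search is too slow for
    # plain PTQ, so it is rejected unless a GPTQ run is planned afterwards.
    if error_method is ErrorMethod.HMSE and not running_gptq:
        raise ValueError('The HMSE error method for parameters selection is only '
                         'supported when running GPTQ optimization.')

check_hmse_allowed(ErrorMethod.HMSE, running_gptq=True)  # passes; without GPTQ it raises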
@@ -101,6 +50,8 @@ def filter_node_qco_by_graph(node: BaseNode,
         that are compatible with next nodes supported input bit-widths.
 
     """
+    from model_compression_toolkit.quantization_preparation.load_fqc import fetch_qc_options_for_node
+
     # Filter quantization config options that don't match the graph.
     _base_config = node_qc_options.base_config
     _node_qc_options = node_qc_options.quantization_configurations
@@ -110,7 +61,7 @@ def filter_node_qco_by_graph(node: BaseNode,
     next_nodes = []
     while len(_next_nodes):
         n = _next_nodes.pop(0)
-        qco = n
+        qco = fetch_qc_options_for_node(n, fqc)
         qp = [qc.quantization_preserving for qc in qco.quantization_configurations]
         if not all(qp) and any(qp):
             Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizaionConfigOptions in {n}.')
@@ -120,7 +71,8 @@ def filter_node_qco_by_graph(node: BaseNode,
 
     if len(next_nodes) == 0:
         return _base_config, _node_qc_options
-
+
+    next_nodes_qc_options = [fetch_qc_options_for_node(_node, fqc) for _node in next_nodes]
     all_next_nodes_supported_input_bitwidth = [max_input_activation_n_bits(op_cfg)
                                                for qc_opts in next_nodes_qc_options
                                                for op_cfg in qc_opts.quantization_configurations
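Both hunks above replace direct option lookup with fetch_qc_options_for_node from the new quantization_preparation.load_fqc module. The surviving all-or-none rule for quantization_preserving can be restated as a small standalone predicate (toy list-of-flags input; the real objects are MCT's QuantizationConfigOptions):

def uniform_preserving(flags: list) -> bool:
    # Mirrors the check in filter_node_qco_by_graph: quantization_preserving
    # must be identical across all of a node's quantization config options.
    if any(flags) and not all(flags):
        raise ValueError('"quantization_preserving" should be the same for all options.')
    return bool(flags) and all(flags)

assert uniform_preserving([True, True]) is True
assert uniform_preserving([False, False]) is False  # consistent, just not preserving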
@@ -150,355 +102,98 @@ def filter_node_qco_by_graph(node: BaseNode,
     return _base_config, _node_qc_options
 
 
-def set_quantization_configs_to_node(node: BaseNode,
-                                     graph: Graph,
-                                     quant_config: QuantizationConfig,
-                                     fqc: FrameworkQuantizationCapabilities,
-                                     mixed_precision_enable: bool = False,
-                                     manual_bit_width_override: Optional[Dict] = None):
+def set_manual_bitwidth_config(graph, bit_width_config: BitWidthConfig):
     """
-
+    Filters candidates per manual bit-width config.
 
     Args:
-
-
-        quant_config (QuantizationConfig): Quantization configuration to generate the node's configurations from.
-        fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to get default OpQuantizationConfig.
-        mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
-        manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
+        graph: graph after candidates have been set on nodes.
+        bit_width_config: bit-width config.
     """
-
-
+    manual_activation_bitwidths = bit_width_config.get_nodes_activation_bit_widths(graph)
+    manual_weights_bitwidths = bit_width_config.get_nodes_weights_bit_widths(graph)
 
-
-
-    if manual_bit_width_override is None:
-        manual_bit_width_override = {ACTIVATION: None, WEIGHTS: None}
-
-    base_config, node_qc_options_list = filter_qc_options_with_manual_bit_width(
-        node=node,
-        node_qc_options_list=node_qc_options_list,
-        base_config=base_config,
-        manual_bit_width_override=manual_bit_width_override,
-        mixed_precision_enable=mixed_precision_enable)
+    if manual_activation_bitwidths:
+        _set_manual_activation_bitwidths(manual_activation_bitwidths)
 
-
-
-                                                      node.channel_axis,
-                                                      node_qc_options_list,
-                                                      base_config,
-                                                      node,
-                                                      mixed_precision_enable=mixed_precision_enable)
+    if manual_weights_bitwidths:
+        _set_manual_weights_bitwidths(manual_weights_bitwidths)
 
-    # sorting the candidates by kernel attribute weights number of bits first and then by activation number of bits
-    # (in reversed order). since only kernel attribute is quantized in weights mixed precision,
-    # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
-    node.sort_node_candidates()
 
-
-
-                not node.get_has_activation():
-            candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
-        elif candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.PRESERVE_QUANT:
-            prev_nodes = graph.get_prev_nodes(node)
-            if len(prev_nodes) != 1:
-                # Preserving the quantization of more than 1 previous node is ambiguous, so disable it.
-                Logger.info(f"Disabling Quantization-Preserving for node {node.name} because it has more than 1 input activations.")
-                candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
-            elif not prev_nodes[0].is_quantization_preserving() and not prev_nodes[0].is_activation_quantization_enabled():
-                # Preserving the quantization of an unquantized node isn't possible, so disable it.
-                Logger.info(f"Disabling Quantization-Preserving for node {node.name} because previous node activation quantization is disabled.")
-                candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
-
-
-def create_node_activation_qc(qc: QuantizationConfig,
-                              op_cfg: OpQuantizationConfig) -> NodeActivationQuantizationConfig:
+# TODO irena: check coverage and add missing tests
+def _set_manual_activation_bitwidths(manual_activation_bitwidths: Dict[BaseNode, int]):
     """
-
+    Filters out candidates that don't match the requested manual activation bitwidths, and updates the
+    activation bitwidth in the base quantization config.
 
     Args:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        manual_activation_bitwidths: nodes' manual activation bitwidth.
+
+    Raises:
+        ValueError: if the manual bitwidth is requested for un-quantized node.
+                    if the manual bitwidth is not compatible with any candidate.
+    """
+    for n, a_nbits in manual_activation_bitwidths.items():
+        quant_mode = n.quantization_cfg.get_activation_quant_mode()
+        # TODO irena: for FLN I think it should be ignored with warning for layer filter, and error for name filter
+        if quant_mode != ActivationQuantizationMode.QUANT:
+            raise ValueError(f'Cannot apply manual activation bit-width for node {n} with activation quantization mode'
+                             f'{quant_mode}, as it does not have its own quantization configuration.')
+        candidates = [qc for qc in n.candidates_quantization_cfg
+                      if qc.activation_quantization_cfg.activation_n_bits == a_nbits]
+        if not candidates:
+            valid_nbits = sorted(list({qc.activation_quantization_cfg.activation_n_bits
+                                       for qc in n.candidates_quantization_cfg}))
+            raise ValueError(
+                f'Manually selected activation bit-width {a_nbits} is invalid for node {n}. '
+                f'Valid bit-widths: {valid_nbits}.')
+        n.quantization_cfg.candidates_quantization_cfg = candidates
+        n.quantization_cfg.base_quantization_cfg.activation_quantization_cfg.activation_n_bits = a_nbits
+
+
+# TODO irena: check coverage
+def _set_manual_weights_bitwidths(manual_weights_bitwidths: Dict[BaseNode, Dict[str, int]]):
+    """
+    Filters out candidates that don't match the requested weight attributes manual bitwidths, and updates the bitwidths
+    in the base quantization config.
 
     Args:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                                             node: BaseNode,
-                                             mixed_precision_enable: bool = False) -> List[CandidateNodeQuantizationConfig]:
-    """
-    Create a list of candidates of weights and activation quantization configurations for a node.
-
-    Args:
-        qc (QuantizationConfig): Quantization configuration the quantization process should follow.
-        weight_channel_axis (ChannelAxisMapping): (Output, Input) channel index of the node's kernel.
-        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs of node.
-        base_config (OpQuantizationConfig): Base quantization config for node.
-        node (BaseNode): A node to set quantization configuration candidates to.
-        mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
-
-    Returns:
-        List[CandidateNodeQuantizationConfig]: List of candidates of weights quantization configurations to set for a node.
-    """
-
-    candidates = []
-    node_attrs_list = node.get_node_weights_attributes()
-
-    if mixed_precision_enable:
-        for op_cfg in node_qc_options_list:
-            candidate_qc = copy.deepcopy(qc)
-            candidates.append(_create_node_single_candidate_qc(candidate_qc,
-                                                               weight_channel_axis,
-                                                               op_cfg,
-                                                               node_attrs_list))
-
-    else:
-        candidates.append(_create_node_single_candidate_qc(qc,
-                                                           weight_channel_axis,
-                                                           base_config,
-                                                           node_attrs_list))
-
-    return candidates
-
-
-def filter_qc_options_with_manual_bit_width(
-        node: BaseNode,
-        node_qc_options_list: List[OpQuantizationConfig],
-        base_config: OpQuantizationConfig,
-        manual_bit_width_override: Optional[Dict],
-        mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
-    """
-    Update the quantization configurations for a node, allowing manual bit-width overrides if specified.
-
-    Args:
-        node (BaseNode): A node to set quantization configuration candidates to.
-        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
-        base_config (OpQuantizationConfig): Base quantization config for the node.
-        manual_bit_width_override (Optional[Dict]): Specifies a custom bit-width to override the node's activation and weights bit-width.
-        mixed_precision_enable (bool): Whether mixed precision is enabled.
-
-    Returns:
-        Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
-    """
-    base_config, node_qc_options_list = filter_activation_qc_options_with_manual_bit_width(node,
-                                                                                           node_qc_options_list,
-                                                                                           base_config,
-                                                                                           manual_bit_width_override.get(ACTIVATION),
-                                                                                           mixed_precision_enable)
-
-    base_config, node_qc_options_list = filter_weights_qc_options_with_manual_bit_width(node,
-                                                                                        node_qc_options_list,
-                                                                                        base_config,
-                                                                                        manual_bit_width_override.get(WEIGHTS),
-                                                                                        mixed_precision_enable)
-    return base_config, node_qc_options_list
-
-
-def filter_activation_qc_options_with_manual_bit_width(
-        node: BaseNode,
-        node_qc_options_list: List[OpQuantizationConfig],
-        base_config: OpQuantizationConfig,
-        activation_manual_bit_width_override: Optional[int],
-        mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
-    """
-    Update the activation quantization configurations for a node, allowing manual bit-width overrides if specified.
-
-    Args:
-        node (BaseNode): A node to set quantization configuration candidates to.
-        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
-        base_config (OpQuantizationConfig): Base quantization config for the node.
-        activation_manual_bit_width_override (Optional[Dict]): Specifies a custom bit-width to override the node's activation bit-width.
-        mixed_precision_enable (bool): Whether mixed precision is enabled.
-
-    Returns:
-        Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
-    """
-    if activation_manual_bit_width_override is None:
-        return base_config, node_qc_options_list
-
-    # Filter node_qc_options_list to retain only the options with activation bits equal to activation_manual_bit_width_override.
-    node_qc_options_list = [op_cfg for op_cfg in node_qc_options_list if
-                            activation_manual_bit_width_override == op_cfg.activation_n_bits]
-    if len(node_qc_options_list) == 0:
-        Logger.critical(f"Manually selected activation bit-width {activation_manual_bit_width_override} is invalid for node {node}.")
-    else:
-        # Update the base_config to one of the values from the filtered node_qc_options_list.
-        # First, check if a configuration similar to the original base_config but with activation bits equal to activation_manual_bit_width_override exists.
-        # If it does, use it as the base_config. If not, choose a different configuration from node_qc_options_list.
-        Logger.info(f"Setting node {node} bit-width to manually selected bit-width: {activation_manual_bit_width_override} bits.")
-        updated_base_config = base_config.clone_and_edit({ACTIVATION_N_BITS, activation_manual_bit_width_override})
-        if updated_base_config in node_qc_options_list:
-            # If a base_config with the specified activation_manual_bit_width_override exists in the node_qc_options_list,
-            # point the base_config to this option.
-            base_config = node_qc_options_list[node_qc_options_list.index(updated_base_config)]
-        else:
-            # Choose a different configuration from node_qc_options_list. If multiple options exist, issue a warning.
-            base_config = node_qc_options_list[0]
-            if len(node_qc_options_list) > 0 and not mixed_precision_enable:
-                Logger.info(
-                    f"Request received to select {activation_manual_bit_width_override} activation bits. However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
-                    f" Overriding base_config with an option that uses {activation_manual_bit_width_override} bit activations.")  # pragma: no cover
-
-    return base_config, node_qc_options_list
-
-
-def filter_weights_qc_options_with_manual_bit_width(
-        node: BaseNode,
-        node_qc_options_list: List[OpQuantizationConfig],
-        base_config: OpQuantizationConfig,
-        weights_manual_bit_width_override: Optional[Tuple[int, WeightAttrT]],
-        mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
-    """
-    Update the weights quantization configurations for a node, allowing manual bit-width overrides if specified.
-
-    Args:
-        node (BaseNode): A node to set quantization configuration candidates to.
-        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
-        base_config (OpQuantizationConfig): Base quantization config for the node.
-        weights_manual_bit_width_override (Optional[[int, WeightAttrT]]): Specifies a custom bit-width to override the node's weights bit-width.
-        mixed_precision_enable (bool): Whether mixed precision is enabled.
-
-    Returns:
-        Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
-    """
-    if not weights_manual_bit_width_override:
-        return base_config, node_qc_options_list
-
-    # Filter node_qc_options_list to retain only the options with weights bits equal to weights_manual_bit_width_override.
-    node_qc_options_weights_list = _filter_options(node_qc_options_list, weights_manual_bit_width_override)
-
-    if len(node_qc_options_weights_list) == 0:
-        Logger.critical(f"Manually selected weights bit-width {weights_manual_bit_width_override} is invalid for node {node}.")
-    else:
-        # Update the base_config to one of the values from the filtered node_qc_options_list.
-        # First, check if a configuration similar to the original base_config but with weights bits equal to weights_manual_bit_width_override exists.
-        # If it does, use it as the base_config. If not, choose a different configuration from node_qc_options_list.
-        updated_base_config = base_config.clone_and_edit()
-
-        for bit_width, attr in weights_manual_bit_width_override:
-            Logger.info(f"Setting node {node} bit-width to manually selected {attr} bit-width: {bit_width} bits.")
-            updated_base_config = updated_base_config.clone_and_edit(attr_to_edit={attr : {WEIGHTS_N_BITS: bit_width}})
-
-        if updated_base_config in node_qc_options_weights_list:
-            # If a base_config with the specified weights_manual_bit_width_override exists in the node_qc_options_list,
-            # point the base_config to this option.
-            base_config = node_qc_options_weights_list[node_qc_options_weights_list.index(updated_base_config)]
-        else:
-            # Choose a different configuration from node_qc_options_list. If multiple options exist, issue a warning.
-            base_config = node_qc_options_weights_list[0]
-            if len(node_qc_options_weights_list) > 0 and not mixed_precision_enable:
-                Logger.info(
-                    f"Request received to select weights bit-widths {weights_manual_bit_width_override}."
-                    f"However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
-                    f" Overriding base_config with an option that uses manually selected weights bit-widths {weights_manual_bit_width_override}.")  # pragma: no cover
-
-    return base_config, node_qc_options_weights_list
-
-
-def _is_valid_option(
-        op_cfg: OpQuantizationConfig,
-        attr: WeightAttrT,
-        bit_width: int) -> bool:
-    """
-    Judge whether the specified option is valid based on the specified attribute and bit width.
-
-    Args:
-        op_cfg (OpQuantizationConfig): The quantization configuration to be judged.
-        attr (WeightAttrT): The filtered node's attributes to apply bit-width manipulation to.
-        bit_width (int): The bit width to be applied to the selected nodes.
-
-    Returns:
-        Result to judge whether the specified option is valid based on the specified attribute and bit width
-    """
-    weights_attrs = op_cfg.attr_weights_configs_mapping.keys()
-
-    if attr not in weights_attrs:
-        return False
-
-    weights_n_bits = op_cfg.attr_weights_configs_mapping[attr].weights_n_bits
-    return weights_n_bits == bit_width
-
-
-def _filter_options(
-        node_qc_options_list: List[OpQuantizationConfig],
-        weights_manual_bit_width_override: Tuple[int, WeightAttrT]) -> List[OpQuantizationConfig]:
-    """
-    Filter the options based on the specified bit width and attribute.
-
-    Args:
-        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
-        weights_manual_bit_width_override (Tuple[int, WeightAttrT])): Specifies a custom bit-width to override the node's weights bit-width.
-
-    Returns:
-        List[OpQuantizationConfig]: Filtered the options based on the specified bit width and attribute.
-    """
-    filtered_options = []
-
-    for bit_width, attr in weights_manual_bit_width_override:
-        for op_cfg in node_qc_options_list:
-            if _is_valid_option(op_cfg, attr, bit_width):
-                filtered_options.append(op_cfg)
-
-    return filtered_options
+        manual_activation_bitwidths: nodes' manual activation bitwidth.
+
+    Raises:
+        ValueError: if the manual bitwidth is requested for non-existing attribute.
+                    if the manual bitwidth is requested for un-quantized weights attribute.
+                    if the manual bitwidth is not compatible with any candidate.
+    """
+    def qc_attr_nbits(qc, attr, n):
+        if attr == POSITIONAL_ATTR:
+            pos_attrs = qc.weights_quantization_cfg.pos_attributes_config_mapping
+            if not pos_attrs:
+                raise ValueError('Unexpected positional attribute in manual weights bit-width for node {n}.')
+            if any(cfg.enable_weights_quantization is False for cfg in pos_attrs.values()):
+                raise ValueError(f'Cannot apply manual bit-width configuration for positional attribute of node {n} as '
+                                 f'the attribute is not quantized.')
+            assert len({cfg.weights_n_bits for cfg in pos_attrs.values()}) == 1
+            return list(pos_attrs.values())[0].weights_n_bits
+        if attr not in qc.weights_quantization_cfg.all_weight_attrs:
+            raise ValueError(f'Unexpected attribute {attr} in manual weights bit-width configuration for node {n}.')
+        attr_cfg = qc.weights_quantization_cfg.get_attr_config(attr)
+        if not attr_cfg.enable_weights_quantization:
+            raise ValueError(f'Cannot apply manual bit-width configuration for weights attribute {attr} of node {n} as '
+                             f'the attribute is not quantized.')
+        return qc.weights_quantization_cfg.get_attr_config(attr).weights_n_bits
+
+    for n, manual_wbits in manual_weights_bitwidths.items():
+        candidates = [qc for qc in n.candidates_quantization_cfg
+                      if all(qc_attr_nbits(qc, attr, n) == w_nbits for attr, w_nbits in manual_wbits.items())]
+        if not candidates:
+            raise ValueError(f'Cannot apply manual weights bit-width configuration {manual_wbits} for node {n} as it '
+                             f'does not match any of the quantization candidates.')
+        n.quantization_cfg.candidates_quantization_cfg = candidates
+        for attr, w_nbits in manual_wbits.items():
+            base_weights_cfg = n.quantization_cfg.base_quantization_cfg.weights_quantization_cfg
+            if attr == POSITIONAL_ATTR:
+                for pos_attr in base_weights_cfg.pos_attributes_config_mapping:
+                    base_weights_cfg.get_attr_config(pos_attr).weights_n_bits = w_nbits
+            else:
+                base_weights_cfg.get_attr_config(attr).weights_n_bits = w_nbits
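The added _set_manual_activation_bitwidths narrows each node's candidate list to the manually requested bit-width and rejects requests that match no candidate. A minimal self-contained sketch of that selection rule (the dataclass below is an illustrative stand-in for CandidateNodeQuantizationConfig, not MCT's API):

from dataclasses import dataclass
from typing import List

@dataclass
class Candidate:  # stand-in for CandidateNodeQuantizationConfig
    activation_n_bits: int

def select_activation_candidates(candidates: List[Candidate], requested: int) -> List[Candidate]:
    # Keep only candidates whose activation bit-width matches the manual request;
    # like the real code, an impossible request is an error, not a silent no-op.
    kept = [c for c in candidates if c.activation_n_bits == requested]
    if not kept:
        valid = sorted({c.activation_n_bits for c in candidates})
        raise ValueError(f'Manually selected activation bit-width {requested} is invalid. '
                         f'Valid bit-widths: {valid}.')
    return kept

cands = [Candidate(8), Candidate(4), Candidate(2)]
assert [c.activation_n_bits for c in select_activation_candidates(cands, 4)] == [4]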
model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py
CHANGED

@@ -42,13 +42,12 @@ def apply_activation_bias_correction_to_graph(graph: Graph,
             n.final_activation_quantization_cfg.activation_bias_correction_term is not None:
         # If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was
         # calculated during model preparation, and is used now in the node's bias term.
-        _apply_activation_bias_correction_to_node(n, fw_impl
+        _apply_activation_bias_correction_to_node(n, fw_impl)
     return graph
 
 
 def _apply_activation_bias_correction_to_node(node: BaseNode,
-                                              fw_impl: FrameworkImplementation
-                                              qc: QuantizationConfig):
+                                              fw_impl: FrameworkImplementation):
     """
     Set new bias to node using the activation bias correction term that is stored in the
     final activation quantization configuration.
@@ -56,7 +55,6 @@ def _apply_activation_bias_correction_to_node(node: BaseNode,
     Args:
         node: Node to set its corrected bias after activation bias correction.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
-        qc: QuantizationConfig containing parameters of how the model should be quantized.
 
     """
     correction = node.final_activation_quantization_cfg.activation_bias_correction_term
@@ -72,7 +70,6 @@ def _apply_activation_bias_correction_to_node(node: BaseNode,
         # Configure the quantization of the bias as disabled.
         node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
                                                             WeightsAttrQuantizationConfig(
-                                                                qc,
                                                                 AttributeQuantizationConfig(
                                                                     enable_weights_quantization=False)))
     else:
model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py
CHANGED

@@ -77,7 +77,5 @@ def _apply_bias_correction_to_node(node: BaseNode,
         node.set_weights_by_keys(fw_impl.constants.BIAS, - correction)
         node.framework_attr[fw_impl.constants.USE_BIAS] = True  # Mark the use_bias attribute of the node.
         node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
-                                                            WeightsAttrQuantizationConfig(
-
-                                                                AttributeQuantizationConfig(
-                                                                    enable_weights_quantization=False)))
+                                                            WeightsAttrQuantizationConfig(AttributeQuantizationConfig(
+                                                                enable_weights_quantization=False)))
model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py
CHANGED

@@ -24,7 +24,7 @@ from model_compression_toolkit.core.common.model_builder_mode import ModelBuilde
 from model_compression_toolkit.core.common.model_collector import ModelCollector
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
-    import
+    import compute_activation_qparams
 from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
 from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
 
@@ -50,12 +50,11 @@ def _collect_and_assign_act_threshold(graph: Graph,
     for _data in tqdm(representative_data_gen()):
         mi.infer(_data)
 
-    for n in
+    for n in graph.nodes:
         if n.is_activation_quantization_enabled():
-            activation_params =
-
-
-                out_stats_container=graph.get_out_stats_collector(n))
+            activation_params = compute_activation_qparams(activation_quant_cfg=n.final_activation_quantization_cfg,
+                                                           node_prior_info=n.prior_info,
+                                                           out_stats_container=graph.get_out_stats_collector(n))
             n.final_activation_quantization_cfg.set_activation_quantization_param(activation_params)
 
 
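The second-moment correction now recomputes activation parameters through compute_activation_qparams, keyed off each node's final activation quantization config, prior info, and output statistics collector. A toy rendering of the loop's shape (all names below are illustrative stand-ins, not MCT's classes or real qparams math):

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class ToyNode:  # illustrative stand-in for a graph node
    name: str
    quantize_activation: bool
    out_stats: List[float]
    activation_params: Dict = field(default_factory=dict)

def toy_compute_activation_qparams(stats: List[float]) -> Dict:
    # Placeholder for compute_activation_qparams: derive a threshold
    # from the collected output statistics.
    return {'threshold': max(abs(min(stats)), abs(max(stats)))}

def assign_activation_thresholds(nodes: List[ToyNode]) -> None:
    # Mirrors the rewritten loop: only nodes with activation quantization
    # enabled get fresh parameters from their output statistics.
    for n in nodes:
        if n.quantize_activation:
            n.activation_params = toy_compute_activation_qparams(n.out_stats)

nodes = [ToyNode('conv', True, [-1.5, 2.0]), ToyNode('reshape', False, [0.0])]
assign_activation_thresholds(nodes)
assert nodes[0].activation_params['threshold'] == 2.0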