mct-nightly 2.4.0.20250629.706-py3-none-any.whl → 2.4.0.20250701.185106-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/METADATA +16 -16
- {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/RECORD +75 -72
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
- model_compression_toolkit/core/common/framework_info.py +5 -32
- model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
- model_compression_toolkit/core/common/graph/base_graph.py +20 -37
- model_compression_toolkit/core/common/graph/base_node.py +13 -106
- model_compression_toolkit/core/common/graph/functional_node.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
- model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -96
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
- model_compression_toolkit/core/graph_prep_runner.py +31 -20
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
- model_compression_toolkit/core/keras/default_framework_info.py +0 -11
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
- model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/runner.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
- model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
- model_compression_toolkit/quantization_preparation/__init__.py +14 -0
- model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
- {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250629.706.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
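The hunks visible below cover four of the files listed above: `network_editors/actions.py`, `quantization/bit_width_config.py`, `quantization/candidate_node_quantization_config.py`, and `quantization/filter_nodes_candidates.py`.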

model_compression_toolkit/core/common/network_editors/actions.py (+4 -96)

```diff
@@ -20,14 +20,8 @@ from typing import Callable
 from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.logger import Logger
-
-
-from model_compression_toolkit.core.common.framework_info import get_fw_info
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
-from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
-    get_activation_quantization_params_fn, get_weights_quantization_params_fn
-from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
-    get_weights_quantization_fn
+
 
 _EditRule = namedtuple('EditRule', 'filter action')
 
@@ -174,47 +168,6 @@ class ChangeFinalActivationQuantConfigAttr(BaseAction):
         node.final_activation_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value)
 
 
-class ChangeQuantizationParamFunction(BaseAction):
-    """
-    Class ChangeQuantizationParamFunction to change a node's weights/activations quantization params function.
-    """
-
-    def __init__(self,
-                 attr_name: str = None,
-                 activation_quantization_params_fn: Callable = None,
-                 weights_quantization_params_fn: Callable = None):
-        """
-        Init a ChangeQuantizationParamFunction object.
-
-        Args:
-            attr_name: The weights attribute's name to set the weights quantization params function for (if setting weights params).
-            activation_quantization_params_fn: a params function for a node's activations.
-            weights_quantization_params_fn: a params function for a node's weights.
-        """
-        self.activation_quantization_params_fn = activation_quantization_params_fn
-        self.weights_quantization_params_fn = weights_quantization_params_fn
-        self.attr_name = attr_name
-
-    def apply(self, node: BaseNode, graph):
-        """
-        Change the node's weights/activations quantization params function.
-
-        Args:
-            node: Node object to change its quantization params function.
-            graph: Graph to apply the action on.
-
-        Returns:
-            The node after its quantization params function has been modified.
-        """
-        for nqc in node.candidates_quantization_cfg:
-            if self.activation_quantization_params_fn is not None:
-                nqc.activation_quantization_cfg.set_activation_quantization_params_fn(
-                    self.activation_quantization_params_fn)
-            if self.weights_quantization_params_fn is not None:
-                (nqc.weights_quantization_cfg.get_attr_config(self.attr_name)
-                 .set_weights_quantization_params_fn(self.weights_quantization_params_fn))
-
-
 class ChangeFinalActivationQuantizationMethod(BaseAction):
     """
     Class ChangeFinalActivationQuantizationMethod to change a node's weights/activations quantizer function.
@@ -243,16 +196,6 @@ class ChangeFinalActivationQuantizationMethod(BaseAction):
         """
 
         if self.activation_quantization_method is not None and node.final_activation_quantization_cfg is not None:
-
-            activation_quantization_params_fn = get_activation_quantization_params_fn(
-                self.activation_quantization_method)
-
-            node.final_activation_quantization_cfg.set_activation_quantization_params_fn(
-                activation_quantization_params_fn)
-
-            activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(self.activation_quantization_method)
-
-            node.final_activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
             node.final_activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
 
 
@@ -281,23 +224,12 @@ class ChangeCandidatesActivationQuantizationMethod(BaseAction):
         """
         if self.activation_quantization_method is not None:
             for qc in node.candidates_quantization_cfg:
-                activation_quantization_params_fn = get_activation_quantization_params_fn(
-                    self.activation_quantization_method)
-
-                qc.activation_quantization_cfg.set_activation_quantization_params_fn(activation_quantization_params_fn)
-                activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(
-                    self.activation_quantization_method)
-
-                if activation_quantization_fn is None:
-                    Logger.critical('Unknown activation quantization method specified.')  # pragma: no cover
-
-                qc.activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
                 qc.activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
 
 
 class ChangeFinalWeightsQuantizationMethod(BaseAction):
     """
-    Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer
+    Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer method.
     """
 
     def __init__(self, attr_name: str, weights_quantization_method=None):
@@ -323,21 +255,8 @@ class ChangeFinalWeightsQuantizationMethod(BaseAction):
         """
 
         if self.weights_quantization_method is not None and node.final_weights_quantization_cfg is not None:
-
-            weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
-
-            (node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
-             .set_weights_quantization_params_fn(weights_quantization_params_fn))
-
-            weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
-
-            if weights_quantization_fn is None:
-                Logger.critical('Unknown weights quantization method specified.')  # pragma: no cover
-
-            (node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
-             .set_weights_quantization_fn(weights_quantization_fn))
-            node.final_weights_quantization_cfg.get_attr_config(self.attr_name).weights_quantization_method = \
-                self.weights_quantization_method
+            attr_config = node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
+            attr_config.weights_quantization_method = self.weights_quantization_method
 
 
 class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
@@ -370,18 +289,7 @@ class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
 
         if self.weights_quantization_method is not None:
             for qc in node.candidates_quantization_cfg:
-
-                weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
-
                 attr_qc = qc.weights_quantization_cfg.get_attr_config(self.attr_name)
-                attr_qc.set_weights_quantization_params_fn(weights_quantization_params_fn)
-
-                weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
-
-                if weights_quantization_fn is None:
-                    Logger.critical('Unknown weights quantization method specified.')  # pragma: no cover
-
-                attr_qc.set_weights_quantization_fn(weights_quantization_fn)
                 attr_qc.weights_quantization_method = self.weights_quantization_method
 
 
```
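With `ChangeQuantizationParamFunction` removed (its backing module `quantization_params_fn_selection.py` is deleted in this release, per the file list above), the edit actions that change a quantization method now only record the requested `QuantizationMethod` on the relevant config; the matching quantizer and parameter functions are resolved from that method later in the pipeline (see `quantization_fn_selection.py` and the new `quantization_preparation/load_fqc.py`). A minimal sketch of wiring one of the remaining actions into an edit rule; the node filter here is a hypothetical stand-in, not an MCT class, and `attr_name='kernel'` assumes a Keras-style layer:

```python
from mct_quantizers import QuantizationMethod
from model_compression_toolkit.core.common.network_editors.actions import (
    EditRule, ChangeCandidatesWeightsQuantizationMethod)


class NameContainsFilter:
    """Hypothetical node filter: matches graph nodes whose name contains a substring."""
    def __init__(self, substring: str):
        self.substring = substring

    def match(self, node) -> bool:
        return self.substring in node.name


# After this change the action only rewrites weights_quantization_method on each
# candidate's attribute config; quantizer/params functions are derived from it later.
rule = EditRule(filter=NameContainsFilter('conv'),
                action=ChangeCandidatesWeightsQuantizationMethod(
                    attr_name='kernel',
                    weights_quantization_method=QuantizationMethod.POWER_OF_TWO))
```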

model_compression_toolkit/core/common/quantization/bit_width_config.py (+10 -10)

```diff
@@ -19,8 +19,8 @@ from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher
 from model_compression_toolkit.logger import Logger
 
-from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
-from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR
+from model_compression_toolkit.core.common.graph.base_node import WeightAttrT, BaseNode
+from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
 
 
 @dataclass
@@ -95,7 +95,7 @@ class BitWidthConfig:
         for attr, bit_width, filter in zip(attrs, bit_widths, filters):
             self.manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection(filter, bit_width, attr)]
 
-    def get_nodes_activation_bit_widths(self, graph: Graph) -> Dict:
+    def get_nodes_activation_bit_widths(self, graph: Graph) -> Dict[BaseNode, int]:
         """
         Retrieve nodes from the graph that need their bit-widths for activation changed according to the manual bit-width selections.
 
@@ -108,7 +108,7 @@ class BitWidthConfig:
         activation_nodes_to_change_bit_width = self._construct_node_to_new_activation_bit_mapping(graph)
         return activation_nodes_to_change_bit_width
 
-    def get_nodes_weights_bit_widths(self, graph: Graph) -> Dict:
+    def get_nodes_weights_bit_widths(self, graph: Graph) -> Dict[BaseNode, Dict[str, int]]:
         """
         Retrieve nodes from the graph that need their bit-widths for weights changed according to the manual bit-width selections.
 
@@ -166,7 +166,7 @@ class BitWidthConfig:
         attrs = BitWidthConfig._expand_to_list_core(filters, attrs)
         return attrs, bit_widths, filters
 
-    def _construct_node_to_new_activation_bit_mapping(self, graph) -> Dict:
+    def _construct_node_to_new_activation_bit_mapping(self, graph) -> Dict[BaseNode, int]:
         """
         Retrieve nodes from the graph that need their activation bit-widths changed according to the manual bit-width selections.
 
@@ -192,7 +192,7 @@ class BitWidthConfig:
             unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
         return unit_nodes_to_change_bit_width
 
-    def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict:
+    def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict[BaseNode, Dict[str, int]]:
         """
         Retrieve nodes from the graph that need their weights bit-widths changed according to the manual bit-width selections.
 
@@ -212,7 +212,7 @@ class BitWidthConfig:
                                f"to change their bit width to {manual_bit_width_selection.bit_width}.")
 
             for n in filtered_nodes:
-                attr_to_change_bit_width =
+                attr_to_change_bit_width = {}
 
                 attrs_str = n.get_node_weights_attributes()
                 if len(attrs_str) == 0:
@@ -225,8 +225,8 @@ class BitWidthConfig:
                         attr.append(attr_str)
                 # this is a positional attribute, so it needs to be handled separately.
                 # Search manual_bit_width_selection's attribute that contain the POS_ATTR string.
-                elif isinstance(attr_str, int) and POS_ATTR in manual_bit_width_selection.attr:
-                    attr.append(POS_ATTR)
+                elif isinstance(attr_str, int) and POSITIONAL_ATTR in manual_bit_width_selection.attr:
+                    attr.append(POSITIONAL_ATTR)
                 if len(attr) == 0:
                     Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.')
 
@@ -239,7 +239,7 @@ class BitWidthConfig:
                         f"Node {n} has an existing manual bit width configuration of {manual_bit_width_selection.attr}."
                         f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
 
-                attr_to_change_bit_width
+                attr_to_change_bit_width[manual_bit_width_selection.attr] = manual_bit_width_selection.bit_width
                 unit_nodes_to_change_bit_width.update({n: attr_to_change_bit_width})
 
         return unit_nodes_to_change_bit_width
```
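The getters above now advertise concrete return types instead of a bare `Dict`. A small sketch of consuming them; the reporting helper and its name are ours, not an MCT API:

```python
from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig


def report_manual_bit_widths(cfg: BitWidthConfig, graph) -> None:
    """Illustrative helper: print manual bit-width overrides resolved against a prepared graph."""
    # Dict[BaseNode, int]: one activation bit-width per matched node.
    for node, nbits in cfg.get_nodes_activation_bit_widths(graph).items():
        print(f'{node}: activation -> {nbits} bits')
    # Dict[BaseNode, Dict[str, int]]: per-node mapping of weight attribute -> bit-width.
    for node, per_attr in cfg.get_nodes_weights_bit_widths(graph).items():
        for attr, nbits in per_attr.items():
            print(f'{node}: weights[{attr}] -> {nbits} bits')
```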

model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py (+116 -56)

```diff
@@ -12,73 +12,133 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+import copy
+from dataclasses import dataclass, InitVar
+from typing import Callable, List, Optional
 
-from model_compression_toolkit.core import QuantizationConfig
-from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
 from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
-    NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig, \
-    OpQuantizationConfig
-from model_compression_toolkit.logger import Logger
+    NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig, ActivationQuantizationMode
 
 
-
-# Every node holds a quantization configuration
-# for its weights quantization, and a different quantization
-# configuration for its activation quantization configuration.
-##########################################
-
+@dataclass(eq=True)
 class CandidateNodeQuantizationConfig(BaseNodeQuantizationConfig):
     """
-
+    Candidate quantization configuration for a node.
     """
+    activation_quantization_cfg: NodeActivationQuantizationConfig
+    # TODO irena: None is passed in several places, need to check if it's handled properly or it's only passed in cases
+    # that do not affect anything (my guess is it's the second).
+    # I think in general it makes more sense to set it to None when there are no weights, and maybe when all weights
+    # are unquantized, and handle it properly everywhere.
+    weights_quantization_cfg: Optional[NodeWeightsQuantizationConfig]
+
+
+# TODO irena: currently all code still looks at candidates_quantization_cfg as previously, so this is just an initial
+# implementation. For now base config is completely separated from candidates (base config must be equal to one of the
+# candidates, but we create a separate copy), and updating in place is allowed. Also we require quantization mode to
+# be identical between all configs.
+@dataclass
+class NodeQuantizationConfig:
+    # quantization config for single precision
+    base_quantization_cfg: CandidateNodeQuantizationConfig
+    # quantization candidate configs for mixed precision
+    candidates_quantization_cfg: List[CandidateNodeQuantizationConfig]
+
+    validate: InitVar[bool] = True
+
+    def update_all(self, update_fn: Callable[[CandidateNodeQuantizationConfig], None], remove_duplicates: bool = True):
+        """
+        Apply update function on the base config and all candidates configs.
+
+        Args:
+            update_fn: function to apply.
+            remove_duplicates: remove duplicate candidates.
+        """
+        if self.base_quantization_cfg:
+            update_fn(self.base_quantization_cfg)
+        for cfg in self.candidates_quantization_cfg:
+            update_fn(cfg)
+        if remove_duplicates:
+            self.remove_duplicates()
 
-    def __init__(self,
-                 qc: QuantizationConfig = None,
-                 op_cfg: OpQuantizationConfig = None,
-                 activation_quantization_cfg: NodeActivationQuantizationConfig = None,
-                 activation_quantization_fn: Callable = None,
-                 activation_quantization_params_fn: Callable = None,
-                 weights_quantization_cfg: NodeWeightsQuantizationConfig = None,
-                 weights_channels_axis: ChannelAxisMapping = None,
-                 node_attrs_list: List[str] = None):
+    def update_activation_quantization_mode(self, mode: ActivationQuantizationMode):
         """
+        Update activation quantization mode for the base config and all candidates configs.
 
         Args:
-
-
-
-
-
-
-
-
+            mode: quantization mode.
+        """
+        def fn(c):
+            c.activation_quantization_cfg.quant_mode = mode
+
+        self.update_all(fn)
+
+    def disable_weights_quantization(self):
+        """
+        Disable all weights quantization for the base config and all candidates configs.
+        """
+        self.update_all(lambda c: c.weights_quantization_cfg.disable_all_weights_quantization())
+
+    def get_activation_quant_mode(self) -> ActivationQuantizationMode:
+        """
+        Retrieve activation quantization mode.
+
+        Returns:
+            Activation quantization mode.
+
+        Raises:
+            ValueError if not all candidates contain the same mode.
+        """
+        self._validate_consistent_activation_quant_mode()
+        return self.base_quantization_cfg.activation_quantization_cfg.quant_mode
+
+    def remove_duplicates(self):
+        """
+        Remove duplicate candidates. First candidate among duplicates is kept, and the order is preserved.
         """
+        uniq_qcs = []
+        for qc in self.candidates_quantization_cfg:
+            if qc not in uniq_qcs:
+                uniq_qcs.append(qc)
+        self.candidates_quantization_cfg = uniq_qcs
 
-
-
-
-
-
-
-
-
-
-
-                                            op_cfg=op_cfg,
-                                            activation_quantization_fn=activation_quantization_fn,
-                                            activation_quantization_params_fn=activation_quantization_params_fn))
-
-        if weights_quantization_cfg is not None:
-            self.weights_quantization_cfg = weights_quantization_cfg
-        elif all(v is not None for v in (qc, op_cfg, node_attrs_list)):
-            self.weights_quantization_cfg = NodeWeightsQuantizationConfig(qc=qc,
-                                                                          op_cfg=op_cfg,
-                                                                          weights_channels_axis=weights_channels_axis,
-                                                                          node_attrs_list=node_attrs_list)
-        else:
-            self.weights_quantization_cfg = None
-            Logger.debug("Setting weights quantization config as None during CandidateNodeQuantizationConfig creation."
-                         "Notice, this should happen only for FLN nodes.")
+    def __post_init__(self, validate=True):
+        if validate:
+            if not any(self.base_quantization_cfg == qc for qc in self.candidates_quantization_cfg):
+                raise ValueError('Candidates should contain the base config.')
+            self._validate_consistent_activation_quant_mode()
+            self._validate_consistent_weights_quant_mode()
+        # TODO irena
+        # for now make sure they are separate objects so that one doesnt inadvertently modify the other
+        if any(self.base_quantization_cfg is qc for qc in self.candidates_quantization_cfg):
+            self.base_quantization_cfg = copy.deepcopy(self.base_quantization_cfg)
 
+    def _validate_consistent_activation_quant_mode(self):
+        """
+        Validate that base config and all candidates configs contain identical activation quantization mode.
+
+        Raises:
+            ValueError if activation quantization mode is not consistent.
+        """
+        activation_quant_mode = self.base_quantization_cfg.activation_quantization_cfg.quant_mode
+        if any(qc.activation_quantization_cfg.quant_mode != activation_quant_mode
+               for qc in self.candidates_quantization_cfg):
+            raise ValueError('Quantization candidates with different quantization modes are not currently supported.')
+
+    def _validate_consistent_weights_quant_mode(self):
+        """
+        Validate that base config and all candidates configs contain identical weights quantization mode per attribute,
+        i.e. quantization for each attribute should either be enabled in all configs, or disabled in all configs.
+
+        Raises:
+            ValueError if weights quantization is not consistent.
+        """
+        def get_weights_mode(qc):
+            # in graph fuser weights_quantization_cfg is set to None
+            if qc.weights_quantization_cfg is None:
+                return None
+            return {attr: attr_cfg.enable_weights_quantization for attr, attr_cfg
+                    in qc.weights_quantization_cfg.get_all_weight_attrs_configs().items()}
+        if any(get_weights_mode(self.base_quantization_cfg) != get_weights_mode(qc)
+               for qc in self.candidates_quantization_cfg):
+            raise ValueError('Quantization candidates with different quantization modes are not currently supported.')
```
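The new `NodeQuantizationConfig` pairs a single-precision base config with the mixed-precision candidate list and, on construction, checks that the base equals one of the candidates, that quantization modes agree across configs, and that the base is a separate object from every candidate. A self-contained sketch of those semantics using toy config classes (the real ones live in `node_quantization_config.py`):

```python
import copy
from dataclasses import dataclass, InitVar
from typing import List


@dataclass(eq=True)
class ToyCandidateCfg:
    """Stand-in for CandidateNodeQuantizationConfig: value-compared via dataclass eq."""
    activation_nbits: int
    weights_nbits: int


@dataclass
class ToyNodeQuantizationConfig:
    base_quantization_cfg: ToyCandidateCfg
    candidates_quantization_cfg: List[ToyCandidateCfg]
    validate: InitVar[bool] = True

    def __post_init__(self, validate=True):
        if validate and not any(self.base_quantization_cfg == qc
                                for qc in self.candidates_quantization_cfg):
            raise ValueError('Candidates should contain the base config.')
        # Keep the base config a separate object (mirrors the deepcopy in the diff).
        if any(self.base_quantization_cfg is qc for qc in self.candidates_quantization_cfg):
            self.base_quantization_cfg = copy.deepcopy(self.base_quantization_cfg)

    def remove_duplicates(self):
        """First candidate among duplicates is kept; order is preserved."""
        uniq = []
        for qc in self.candidates_quantization_cfg:
            if qc not in uniq:
                uniq.append(qc)
        self.candidates_quantization_cfg = uniq


base = ToyCandidateCfg(activation_nbits=8, weights_nbits=8)
cfg = ToyNodeQuantizationConfig(base, [base, ToyCandidateCfg(8, 4), ToyCandidateCfg(8, 4)])
cfg.remove_duplicates()
assert cfg.base_quantization_cfg is not cfg.candidates_quantization_cfg[0]  # separate copy
assert len(cfg.candidates_quantization_cfg) == 2  # duplicate (8, 4) removed, order kept
```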

model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py (+1 -1)

```diff
@@ -33,7 +33,7 @@ def filter_nodes_candidates(graph: Graph):
     """
     nodes = list(graph.nodes)
     for n in nodes:
-        n.candidates_quantization_cfg = filter_node_candidates(node=n)
+        n.quantization_cfg.candidates_quantization_cfg = filter_node_candidates(node=n)
 
     return graph
 
```
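The accessor change above follows from the container introduced in `candidate_node_quantization_config.py`: candidate configs now live under the node's `quantization_cfg` (a `NodeQuantizationConfig`) rather than directly on the node, so call sites such as `filter_nodes_candidates` assign through `n.quantization_cfg.candidates_quantization_cfg`.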