mct-nightly 2.4.0.20250630.629__py3-none-any.whl → 2.4.0.20250701.185106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/METADATA +16 -16
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/RECORD +75 -72
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
- model_compression_toolkit/core/common/framework_info.py +5 -32
- model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
- model_compression_toolkit/core/common/graph/base_graph.py +20 -37
- model_compression_toolkit/core/common/graph/base_node.py +13 -106
- model_compression_toolkit/core/common/graph/functional_node.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
- model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -96
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
- model_compression_toolkit/core/graph_prep_runner.py +31 -20
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
- model_compression_toolkit/core/keras/default_framework_info.py +0 -11
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
- model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/runner.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
- model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
- model_compression_toolkit/quantization_preparation/__init__.py +14 -0
- model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/pytorch/default_framework_info.py
CHANGED
@@ -21,12 +21,8 @@ from torch.nn import Conv2d, ConvTranspose2d, Linear
 from torch import sigmoid
 
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo, set_fw_info, ChannelAxisMapping
-from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
 from model_compression_toolkit.core.pytorch.constants import KERNEL
-from model_compression_toolkit.core.pytorch.quantizer.fake_quant_builder import power_of_two_quantization, \
-    symmetric_quantization, uniform_quantization
-from model_compression_toolkit.core.pytorch.quantizer.lut_fake_quant import activation_lut_kmean_quantizer
 
 
 class PyTorchInfo(FrameworkInfo):
@@ -81,14 +77,6 @@ class PyTorchInfo(FrameworkInfo):
         SiLU: (-0.279, None),
     }
 
-    """
-    Mapping from a QuantizationMethod to an activation quantizer function.
-    """
-    activation_quantizer_mapping = {QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
-                                    QuantizationMethod.SYMMETRIC: symmetric_quantization,
-                                    QuantizationMethod.UNIFORM: uniform_quantization,
-                                    QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer}
-
     @classmethod
     def get_kernel_channels(cls, node_type: Any) -> ChannelAxisMapping:
         """
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py
CHANGED
@@ -95,11 +95,11 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
             else:
                 return graph
         elif non_linear_node.is_match_type(hardtanh):
-
-
-
-            scale_factor =
-            non_linear_node.functional_op.__defaults__ = (0.0, self.threshold,
+            kwargs = non_linear_node.op_call_kwargs
+            if (kwargs[HARDTANH_MIN_VAL] == 0.0) and not \
+                    (np.log2(kwargs[HARDTANH_MAX_VAL]).astype(int) - np.log2(kwargs[HARDTANH_MAX_VAL]) == 0):
+                scale_factor = kwargs[HARDTANH_MAX_VAL] / self.threshold
+                non_linear_node.functional_op.__defaults__ = (0.0, self.threshold, kwargs[INPLACE])
             else:
                 return graph
         else:
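Note: the rewritten hardtanh branch skips rescaling when the bound is already a power of two. A standalone sketch of that check in plain NumPy (illustration only, not MCT code):

import numpy as np

def max_val_is_power_of_two(max_val: float) -> bool:
    # log2(x) has a zero fractional part exactly when x is a power of two,
    # which is the condition the substitution uses to leave hardtanh untouched.
    log_val = np.log2(max_val)
    return log_val.astype(int) - log_val == 0

print(max_val_is_power_of_two(8.0))  # True: bound already a power of two, no rescale
print(max_val_is_power_of_two(6.0))  # False: hardtanh bound gets rescaled to the threshold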
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py
CHANGED
@@ -29,6 +29,7 @@ from model_compression_toolkit.core.common import BaseNode, Graph
 from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
 from model_compression_toolkit.core.common.substitutions.shift_negative_activation import apply_shift_negative_correction
+from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.core.pytorch.constants import PAD, VALUE, PADDING, BIAS, USE_BIAS
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 
@@ -239,4 +240,5 @@ def pytorch_apply_shift_negative_correction(graph: Graph,
                                            PADDING,
                                            BIAS,
                                            USE_BIAS,
+                                           get_activation_quantization_fn_factory,
                                            params_search_quantization_fn=params_search_quantization_fn)
model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py
CHANGED
@@ -91,7 +91,7 @@ class WeightsHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
         for i, ipt_node in enumerate(self.hessian_request.target_nodes):  # Per Interest point weights tensor
 
             # Check if the target node's layer type is supported.
-            if not ipt_node.
+            if not ipt_node.kernel_attr:
                 Logger.critical(f"Hessian information with respect to weights is not supported for "
                                 f"{ipt_node.type} layers.")  # pragma: no cover
 
model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py
CHANGED
@@ -20,6 +20,7 @@ from model_compression_toolkit.core.common.mixed_precision.configurable_quantizer_utils import \
     verify_candidates_descending_order, init_activation_quantizers
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
+from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.logger import Logger
 from mct_quantizers import QuantizationMethod
 from mct_quantizers import QuantizationTarget
@@ -67,7 +68,7 @@ class ConfigurableActivationQuantizer(BasePyTorchInferableQuantizer):
             Logger.critical("Unsupported configuration: Mixing candidates with differing activation quantization states (enabled/disabled).")  # pragma: no cover
 
         # Setting layer's activation
-        self.activation_quantizers = init_activation_quantizers(self.node_q_cfg)
+        self.activation_quantizers = init_activation_quantizers(self.node_q_cfg, get_activation_quantization_fn_factory)
         self.active_quantization_config_index = max_candidate_idx  # initialize with first config as default
 
     def set_active_activation_quantizer(self, index: Optional[int]):
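Note: the framework-agnostic helper now receives the framework-specific factory getter instead of resolving quantizers itself. A minimal sketch of that dependency-injection pattern (simplified names and attributes; not the exact MCT helper from configurable_quantizer_utils):

# Sketch only: illustrates the pattern, assuming config attributes like these exist.
def init_activation_quantizers(node_q_cfg, get_activation_quantization_fn_factory):
    quantizers = []
    for candidate in node_q_cfg:
        a_cfg = candidate.activation_quantization_cfg
        # Resolve the framework-specific quantizer factory per quantization method.
        factory = get_activation_quantization_fn_factory(a_cfg.activation_quantization_method)
        quantizers.append(factory(a_cfg.activation_n_bits, a_cfg.activation_quantization_params))
    return quantizers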
model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py
CHANGED
@@ -167,7 +167,7 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplementation):
         """
 
         attributes_with_axis = {}
-        if node.
+        if node.kernel_attr:
             attributes_with_axis[node.kernel_attr] = (node.channel_axis.output, node.channel_axis.input)
 
         # Bias is a vector at the length of the number of output channels.
model_compression_toolkit/core/pytorch/pytorch_implementation.py
CHANGED
@@ -26,7 +26,7 @@ from torch.nn import Module, Sigmoid, Softmax
 
 import model_compression_toolkit.core.pytorch.constants as pytorch_constants
 from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
-from model_compression_toolkit.core import QuantizationConfig,
+from model_compression_toolkit.core import QuantizationConfig, CoreConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py
ADDED
@@ -0,0 +1,45 @@
+# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from collections.abc import Callable
+
+from mct_quantizers import QuantizationMethod
+from model_compression_toolkit.core.pytorch.quantization.fake_quant_builder import power_of_two_quantization, \
+    symmetric_quantization, uniform_quantization
+from model_compression_toolkit.core.pytorch.quantization.lut_fake_quant import activation_lut_kmean_quantizer
+
+
+"""
+Mapping from a QuantizationMethod to an activation quantizer function.
+"""
+_activation_quantizer_factory_mapping = {
+    QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
+    QuantizationMethod.SYMMETRIC: symmetric_quantization,
+    QuantizationMethod.UNIFORM: uniform_quantization,
+    QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer
+}
+
+
+def get_activation_quantization_fn_factory(quantization_method: QuantizationMethod) -> Callable[[int, dict], Callable]:
+    """
+    Get factory for activation quantizer.
+
+    Args:
+        quantization_method: quantization method for activation.
+
+    Returns:
+        Factory that accepts activation bitwidth and a dict of quantization params, and returns the quantizer.
+    """
+    return _activation_quantizer_factory_mapping[quantization_method]
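Note: a hedged usage sketch of the new factory module. The docstring above fixes the factory signature (bitwidth plus a params dict); the concrete params-dict keys below are assumptions for illustration and depend on the chosen quantization method:

from mct_quantizers import QuantizationMethod
from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import \
    get_activation_quantization_fn_factory

# Resolve the per-method factory, then build a concrete quantizer from a
# bitwidth and a params dict (key names here are illustrative assumptions).
factory = get_activation_quantization_fn_factory(QuantizationMethod.POWER_OF_TWO)
quantize_fn = factory(8, {'threshold': 8.0, 'is_signed': True})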
model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py
CHANGED
@@ -18,7 +18,7 @@ from torch.nn import Conv2d, Linear, ConvTranspose2d
 from model_compression_toolkit.core import QuantizationConfig
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
-from model_compression_toolkit.core.
+from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
 from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
     compute_activation_bias_correction_of_graph
@@ -50,5 +50,6 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
                                                        fw_impl=fw_impl,
                                                        activation_bias_correction_node_matchers=
                                                        activation_bias_correction_node_matchers,
-                                                       kernel_size=KERNEL_SIZE
+                                                       kernel_size=KERNEL_SIZE,
+                                                       get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
     return graph
model_compression_toolkit/core/runner.py
CHANGED
@@ -118,7 +118,7 @@ def core_runner(in_model: Any,
     if core_config.is_mixed_precision_enabled:
         if core_config.mixed_precision_config.configuration_overwrite is None:
 
-            filter_candidates_for_mixed_precision(graph, target_resource_utilization
+            filter_candidates_for_mixed_precision(graph, target_resource_utilization)
             bit_widths_config = search_bit_width(tg,
                                                  fw_impl,
                                                  target_resource_utilization,
model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py
CHANGED
@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Callable
+from typing import Callable, Optional, List
 from io import BytesIO
 
 import torch.nn
 
 from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
 
-from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
 from model_compression_toolkit.verify_packages import FOUND_ONNX
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
@@ -65,11 +64,14 @@ if FOUND_ONNX:
             self._use_onnx_custom_quantizer_ops = use_onnx_custom_quantizer_ops
             self._onnx_opset_version = onnx_opset_version
 
-        def export(self, output_names=None) -> None:
+        def export(self, output_names: Optional[List[str]] = None) -> None:
             """
             Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
            (namely, weights that are in fake-quant format) and fake-quant layers for the activations.
 
+            Args:
+                output_names (Optional[List[str]]): Optional list of output node names for export compatibility.
+
             Returns:
                 Fake-quant PyTorch model.
             """
@@ -131,6 +133,8 @@ if FOUND_ONNX:
                 output_names = ['output']
                 dynamic_axes.update({'output': {0: 'batch_size'}})
             else:
+                assert isinstance(output_names, list), \
+                    f"`output_names` must be a list, but got {type(output_names).__name__}"
                 if isinstance(model_output, (list, tuple)):
                     num_of_outputs = len(model_output)
                 else:
model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py
CHANGED
@@ -49,7 +49,7 @@ class FakelyQuantTorchScriptPyTorchExporter(BasePyTorchExporter):
                          save_model_path,
                          repr_dataset)
 
-    def export(self) -> None:
+    def export(self, output_names=None) -> None:
        """
        Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
        (namely, weights that are in fake-quant format) and fake-quant layers for the activations.
model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Callable
+from typing import Callable, Optional, List
 from packaging import version
 
 from model_compression_toolkit.verify_packages import FOUND_TORCH
@@ -49,7 +49,8 @@ if FOUND_TORCH:
                              is_layer_exportable_fn: Callable = is_pytorch_layer_exportable,
                              serialization_format: PytorchExportSerializationFormat = PytorchExportSerializationFormat.ONNX,
                              quantization_format: QuantizationFormat = QuantizationFormat.MCTQ,
-                             onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION
+                             onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION,
+                             output_names: Optional[List[str]] = None) -> None:
         """
         Export a PyTorch quantized model to a torchscript or onnx model.
         The model will be saved to the path in save_model_path.
@@ -67,11 +68,19 @@ if FOUND_TORCH:
                PytorchExportSerializationFormat.ONNX).
            quantization_format: Format of how quantizers are exported (fakely-quant, int8, MCTQ quantizers).
            onnx_opset_version: ONNX opset version to use for exported ONNX model.
+           output_names (Optional[List[str]]): Optional list of output node names for export compatibility.
+               This argument is relevant only when using PytorchExportSerializationFormat.ONNX.
 
        """
        # Ensure 'metadata' is available directly on the model, if present in submodules
        find_and_assign_metadata_attr(model)
 
+       if output_names is not None and serialization_format != PytorchExportSerializationFormat.ONNX:
+           Logger.warning(
+               f'`output_names` is only applicable when exporting to ONNX. '
+               f'Current serialization format is {serialization_format}, so `output_names` will be ignored.'
+           )  # pragma: no cover
+
        if serialization_format == PytorchExportSerializationFormat.TORCHSCRIPT:
            if quantization_format in supported_serialization_quantization_export_dict[serialization_format]:
                exporter = FakelyQuantTorchScriptPyTorchExporter(model,
@@ -109,7 +118,7 @@ if FOUND_TORCH:
                f'Unsupported serialization {serialization_format} was used to export Pytorch model.'
                f' Please see API for supported formats.')  # pragma: no cover
 
-       exporter.export()
+       exporter.export(output_names=output_names)
 
 else:
     def pytorch_export_model(*args, **kwargs):
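Note: taken together, the exporter changes let callers name ONNX outputs end to end. A sketch of the updated facade call (the quantized model and representative dataset variables are assumed to exist):

import model_compression_toolkit as mct

# 'quantized_model' and 'representative_data_gen' are assumed to exist.
mct.exporter.pytorch_export_model(
    model=quantized_model,
    save_model_path='./qmodel.onnx',
    repr_dataset=representative_data_gen,
    serialization_format=mct.exporter.PytorchExportSerializationFormat.ONNX,
    output_names=['logits'])  # new in this version; ONNX-only, ignored (with a warning) otherwise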
model_compression_toolkit/pruning/keras/pruning_facade.py
CHANGED
@@ -17,6 +17,7 @@ from typing import Callable, Tuple, Union
 
 from model_compression_toolkit import get_target_platform_capabilities
 from model_compression_toolkit.constants import TENSORFLOW
+from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
 from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
 from model_compression_toolkit.verify_packages import FOUND_TF
@@ -24,10 +25,8 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.pruning.pruner import Pruner
 from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
 from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
-from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
 from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
 
 if FOUND_TF:
@@ -117,20 +116,17 @@ if FOUND_TF:
 
        target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
        # Attach tpc model to framework
-
-       target_platform_capabilities = attach2keras.attach(target_platform_capabilities)
+       framework_platform_capabilities = AttachTpcToKeras().attach(target_platform_capabilities)
 
        # Convert the original Keras model to an internal graph representation.
        float_graph = read_model_to_graph(model,
                                          representative_data_gen,
-
+                                         framework_platform_capabilities,
                                          fw_impl)
 
        # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
        # as it prepares the graph for the pruning process.
-       float_graph_with_compression_config =
-           quant_config=DEFAULTCONFIG,
-           mixed_precision_enable=False)
+       float_graph_with_compression_config = load_fqc_configuration(float_graph, framework_platform_capabilities)
 
        # Create a Pruner object with the graph and configuration.
        pruner = Pruner(float_graph_with_compression_config,
@@ -138,7 +134,7 @@ if FOUND_TF:
                        target_resource_utilization,
                        representative_data_gen,
                        pruning_config,
-
+                       framework_platform_capabilities)
 
        # Apply the pruning process.
        pruned_graph = pruner.prune_graph()
model_compression_toolkit/pruning/pytorch/pruning_facade.py
CHANGED
@@ -23,10 +23,9 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.pruning.pruner import Pruner
 from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
 from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
-from model_compression_toolkit.
+from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration
 from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
 
 
@@ -134,9 +133,7 @@ if FOUND_TORCH:
 
        # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
        # as it prepares the graph for the pruning process.
-       float_graph_with_compression_config =
-           quant_config=DEFAULTCONFIG,
-           mixed_precision_enable=False)
+       float_graph_with_compression_config = load_fqc_configuration(float_graph, framework_platform_capabilities)
 
        # Create a Pruner object with the graph and configuration.
        pruner = Pruner(float_graph_with_compression_config,
model_compression_toolkit/ptq/keras/quantization_facade.py
CHANGED
@@ -122,7 +122,7 @@ if FOUND_TF:
 
        >>> quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(model, repr_datagen, ru, core_config=config)
 
-       For more configuration options, please take a look at our `API documentation <https://
+       For more configuration options, please take a look at our `API documentation <https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
 
    """
 
model_compression_toolkit/qat/keras/quantization_facade.py
CHANGED
@@ -167,7 +167,7 @@ if FOUND_TF:
 
        >>> quantized_model = tf.keras.models.load_model(model_file, custom_objects=custom_objects)
 
-       For more configuration options, please take a look at our `API documentation <https://
+       For more configuration options, please take a look at our `API documentation <https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
 
    """
 
model_compression_toolkit/qat/pytorch/quantization_facade.py
CHANGED
@@ -136,7 +136,7 @@ if FOUND_TORCH:
 
        >>> quantized_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(model, repr_datagen, core_config=config)
 
-       For more configuration options, please take a look at our `API documentation <https://
+       For more configuration options, please take a look at our `API documentation <https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
 
    """
    Logger.warning(
model_compression_toolkit/quantization_preparation/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
model_compression_toolkit/quantization_preparation/load_fqc.py
ADDED
@@ -0,0 +1,223 @@
+# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import List, Optional
+
+from model_compression_toolkit.core.common import Graph, BaseNode
+from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
+from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
+from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
+    CandidateNodeQuantizationConfig, NodeQuantizationConfig
+from model_compression_toolkit.core.common.quantization.node_quantization_config import \
+    NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig, ActivationQuantizationMode
+from model_compression_toolkit.core.common.quantization.set_node_quantization_config import filter_node_qco_by_graph
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities import FrameworkQuantizationCapabilities, \
+    QuantizationConfigOptions, OpQuantizationConfig
+
+
+def load_fqc_configuration(graph: Graph, fqc: FrameworkQuantizationCapabilities):
+    """
+    Set-up graph for quantization per TPC.
+    Each node will contain quantization candidates for mixed precision and the base config for single precision.
+    The graph will contain the fusing info.
+
+    Args:
+        graph: graph.
+        fqc: framework quantization capabilities object.
+
+    Returns:
+        Updated graph.
+    """
+    graph = _set_nodes_quantization_configuration(graph, fqc)
+    graph = _set_fusion_info(graph, fqc)
+
+    return graph
+
+
+def set_quantization_configs_to_node(node: BaseNode,
+                                     graph: Graph,
+                                     fqc: FrameworkQuantizationCapabilities):
+    """
+    Create and set quantization configurations to a node (for both weights and activation).
+
+    Args:
+        node (BaseNode): Node to set its quantization configurations.
+        graph (Graph): Model's internal representation graph.
+        fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to get default OpQuantizationConfig.
+    """
+    qc_options = fetch_qc_options_for_node(node, fqc)
+    base_config, candidates_qcs = filter_node_qco_by_graph(node, fqc, graph, qc_options)
+
+    node_attrs_list = node.get_node_weights_attributes()
+    mp_candidates = [_create_candidate(node.channel_axis, qc, node_attrs_list)
+                     for qc in candidates_qcs]
+    sp_cfg = _create_candidate(node.channel_axis, base_config, node_attrs_list)
+
+    node.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=sp_cfg,
+                                                   candidates_quantization_cfg=mp_candidates)
+
+    # TODO is not needed anymore as find min/max candidate look for a real max/min, but some tests still count on it
+    node.sort_node_candidates()
+
+    if not node.has_activation:
+        node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.NO_QUANT)
+
+    _disable_unsupported_quant_preserving(node, graph)
+
+
+def fetch_qc_options_for_node(node: BaseNode,
+                              fqc: FrameworkQuantizationCapabilities,
+                              return_default=True) -> Optional[QuantizationConfigOptions]:
+    """
+    Get quantization configuration options for the node from TPC.
+
+    Args:
+        node: node for which to fetch quantization configuration.
+        fqc: framework quantization capabilities.
+        return_default: whether to return the default qco or None if node op is not in FQC.
+
+    Returns:
+        Quantization configuration options for the node.
+    """
+    # qcos by filters
+    filter_matches = [(fl, qco) for fl, qco in fqc.filterlayer2qco.items() if node.is_match_filter_params(fl)]
+    fls, filter_qcos = zip(*filter_matches) if filter_matches else (None, None)
+    if filter_qcos and any(qco != filter_qcos[0] for qco in filter_qcos[1:]):
+        raise ValueError(f'Cannot assign quantization configuration to {node} as it matches more than one filter with '
+                         f'conflicting configs: {fls}.')
+
+    # qco by opset
+    # must use is_match_type for functional op in TF2.15
+    matches = [(op_type, qco) for op_type, qco in fqc.layer2qco.items() if node.is_match_type(op_type)]
+    op_types, qcos = zip(*matches) if matches else (None, None)
+    if qcos and any(qco != qcos[0] for qco in qcos[1:]):
+        raise ValueError(f'Cannot assign quantization configuration to {node} as it matches more than one op type with '
+                         f'conflicting configs: {op_types}.')
+
+    # if node matches by both filter and opset, filter takes priority
+    if filter_qcos:
+        return filter_qcos[0]
+
+    if qcos:
+        return qcos[0]
+
+    return fqc.tpc.default_qco if return_default else None
+
+
+def _set_nodes_quantization_configuration(graph: Graph,
+                                          fqc: FrameworkQuantizationCapabilities) -> Graph:
+    """
+    Set quantization configuration for each graph node.
+
+    Args:
+        graph: graph to set with quantization configuration.
+        fqc: framework quantization capabilities.
+
+    Returns:
+        Graph: The graph with quantization configurations attached to each node in it.
+    """
+    _validate_custom_ops_have_qco(graph, fqc)
+
+    for n in graph.get_topo_sorted_nodes():
+        set_quantization_configs_to_node(node=n,
+                                         graph=graph,
+                                         fqc=fqc)
+    return graph
+
+
+def _set_fusion_info(graph: Graph, fqc: FrameworkQuantizationCapabilities) -> Graph:
+    """
+
+    Args:
+        graph: graph.
+        fqc: quantization capabilities with attached framework.
+
+    Returns:
+
+    """
+    # TODO fix the dict with const keys inside get_fusing_patterns. use named tuple or class
+    # TODO irena instead of storing fusion inside graph (including tpc objects) and then let graph convert tpc op config to
+    #  node config, do it here and only store in graph whatever is relevant after this stage.
+    fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(graph)
+    graph.fusing_info = fusing_info
+    graph.override_fused_node_activation_quantization_candidates()
+    return graph
+
+
+def _disable_unsupported_quant_preserving(node: BaseNode, graph: Graph):
+    """
+    Disable quantization for quantization preserving ops in cases it cannot be supported
+    (multiple inputs or un-quantized previous node).
+
+    Args:
+        node: current node.
+        graph: graph.
+    """
+    if not node.quantization_cfg.get_activation_quant_mode() == ActivationQuantizationMode.PRESERVE_QUANT:
+        return
+
+    prev_nodes = graph.get_prev_nodes(node)
+    if len(prev_nodes) != 1:
+        Logger.info(f'Disabling Quantization-Preserving for node {node.name} with {len(prev_nodes)} inputs.')
+        node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.NO_QUANT)
+    elif prev_nodes[0].quantization_cfg.get_activation_quant_mode() == ActivationQuantizationMode.NO_QUANT:
+        Logger.info(f'Disabling Quantization-Preserving for node {node.name} since previous node activation '
+                    f'quantization is disabled.')
+        node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.NO_QUANT)
+
+
+# TODO irena copied from graph.set_fqc as is. Why does it have Keras errors?
+def _validate_custom_ops_have_qco(graph, fqc):
+    custom_nodes = [n for n in graph.nodes if n.is_custom]
+    for n in custom_nodes:
+        qco = fetch_qc_options_for_node(n, fqc, return_default=False)
+        if not qco:
+            Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
+                            ' Please add the custom layer to Framework Quantization Capabilities (FQC), or file a feature '
+                            'request or an issue if you believe this should be supported.')  # pragma: no cover
+        if any([qc.default_weight_attr_config.enable_weights_quantization for qc in qco.quantization_configurations]):
+            Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.')  # pragma: no cover
+
+
+def _create_candidate(weight_channel_axis: ChannelAxisMapping,
+                      op_cfg: OpQuantizationConfig,
+                      node_attrs_list: List[str]) -> CandidateNodeQuantizationConfig:
+    """
+    Create quantization configuration candidate.
+
+    Args:
+        weight_channel_axis: channels axes of the node's kernel.
+        op_cfg: quantization config for the op.
+        node_attrs_list: A list of the node's weights attributes names.
+
+    Returns:
+        Candidate quantization config.
+    """
+
+    aqc = NodeActivationQuantizationConfig(op_cfg=op_cfg)
+
+    # TODO: remove this validation and warning once enabling all attributes quantization by default
+    attrs_with_enabled_quantization = [attr for attr, cfg in op_cfg.attr_weights_configs_mapping.items()
+                                       if cfg.enable_weights_quantization]
+    if len(attrs_with_enabled_quantization) > 1:
+        Logger.warning(f"Multiple weights attributes quantization is enabled via the provided FQC."
+                       f"Quantizing any attribute other than the kernel is experimental "
+                       f"and may be subject to unstable behavior."
+                       f"Attributes with enabled weights quantization: {attrs_with_enabled_quantization}.")
+    wqc = NodeWeightsQuantizationConfig(op_cfg=op_cfg,
+                                        weights_channels_axis=weight_channel_axis,
+                                        node_attrs_list=node_attrs_list)
+
+    return CandidateNodeQuantizationConfig(activation_quantization_cfg=aqc, weights_quantization_cfg=wqc)
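Note: for orientation, a sketch of how the new module is consumed by the updated pruning facades above (variables assumed to exist):

from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration

# 'float_graph' is MCT's internal graph; 'framework_platform_capabilities' is the
# framework-attached TPC (FQC), e.g. as produced by AttachTpcToKeras().attach(...).
graph = load_fqc_configuration(float_graph, framework_platform_capabilities)
# Each node now holds a NodeQuantizationConfig (base config + mixed-precision candidates),
# and graph.fusing_info reflects the TPC fusing patterns.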
model_compression_toolkit/target_platform_capabilities/constants.py
CHANGED
@@ -29,7 +29,7 @@ QNNPACK_TP_MODEL = 'qnnpack'
 # TP Attributes
 KERNEL_ATTR = "kernel_attr"
 BIAS_ATTR = "bias_attr"
-
+POSITIONAL_ATTR = "pos_attr"
 
 # TODO: this is duplicated from the core frameworks constants files, because the original consts can't be used here
 # duo to circular dependency. It might be best to extract the constants from the core file and put them here (in a