mct-nightly 2.4.0.20250705.556__py3-none-any.whl → 2.4.0.20250707.643__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250705.556.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250705.556.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/RECORD +36 -38
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +4 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +7 -4
- model_compression_toolkit/core/common/model_collector.py +11 -0
- model_compression_toolkit/core/common/pruning/memory_calculator.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +22 -87
- model_compression_toolkit/core/common/quantization/quantization_config.py +0 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +23 -17
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +26 -48
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +12 -7
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +4 -14
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -1
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -13
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +5 -7
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +7 -5
- model_compression_toolkit/core/graph_prep_runner.py +1 -11
- model_compression_toolkit/core/keras/default_framework_info.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +21 -11
- model_compression_toolkit/core/keras/keras_implementation.py +2 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +8 -0
- model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +9 -1
- model_compression_toolkit/core/quantization_prep_runner.py +2 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +0 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +0 -3
- model_compression_toolkit/qat/keras/quantization_facade.py +0 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +0 -2
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +0 -6
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +2 -4
- model_compression_toolkit/core/common/model_validation.py +0 -41
- model_compression_toolkit/core/keras/keras_model_validation.py +0 -37
- {mct_nightly-2.4.0.20250705.556.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250705.556.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250705.556.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/top_level.txt +0 -0
|
@@ -18,21 +18,25 @@ from typing import Dict, Union, Optional, Tuple, Callable
|
|
|
18
18
|
from mct_quantizers import QuantizationMethod
|
|
19
19
|
|
|
20
20
|
import model_compression_toolkit.core.common.quantization.quantization_params_generation as qpg
|
|
21
|
+
from model_compression_toolkit.constants import MIN_THRESHOLD
|
|
21
22
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
|
|
22
23
|
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
|
23
24
|
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
|
|
24
25
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
|
|
25
|
-
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationErrorMethod
|
|
26
|
+
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationErrorMethod, \
|
|
27
|
+
QuantizationConfig
|
|
26
28
|
|
|
27
29
|
|
|
28
|
-
def compute_activation_qparams(
|
|
30
|
+
def compute_activation_qparams(quant_cfg: QuantizationConfig,
|
|
31
|
+
node_activation_quant_cfg: NodeActivationQuantizationConfig,
|
|
29
32
|
node_prior_info: NodePriorInfo,
|
|
30
33
|
out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
|
|
31
34
|
"""
|
|
32
35
|
Compute the activations params for a given node in a graph according to a params function.
|
|
33
36
|
|
|
34
37
|
Args:
|
|
35
|
-
|
|
38
|
+
quant_cfg: quantization config.
|
|
39
|
+
node_activation_quant_cfg: node's activation quantization configuration.
|
|
36
40
|
node_prior_info: Prior info collected for the node that is being quantized.
|
|
37
41
|
out_stats_container: Tensor containing output statistics of the node.
|
|
38
42
|
|
|
@@ -40,41 +44,43 @@ def compute_activation_qparams(activation_quant_cfg: NodeActivationQuantizationC
|
|
|
40
44
|
The computed activation quantization params.
|
|
41
45
|
"""
|
|
42
46
|
activation_quantization_params_fn = _get_activation_quantization_params_fn(
|
|
43
|
-
|
|
47
|
+
node_activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())
|
|
44
48
|
|
|
45
49
|
# Extract and filter histogram data from the statistics container.
|
|
46
|
-
bins_values, bins_counts = _get_histogram_data(
|
|
50
|
+
bins_values, bins_counts = _get_histogram_data(out_stats_container,
|
|
51
|
+
activation_error_method=quant_cfg.activation_error_method,
|
|
52
|
+
z_threshold=quant_cfg.z_threshold)
|
|
47
53
|
|
|
48
54
|
# Retrieve the minimum and maximum values from the statistics container.
|
|
49
55
|
min_value, max_value = out_stats_container.get_min_max_values()
|
|
50
56
|
|
|
51
57
|
# Determine if the activations should be considered signed.
|
|
52
|
-
signed = _determine_signedness(
|
|
58
|
+
signed = _determine_signedness(node_activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts)
|
|
53
59
|
|
|
54
60
|
# Compute and return the activation quantization parameters.
|
|
55
61
|
return activation_quantization_params_fn(
|
|
56
62
|
bins_values,
|
|
57
63
|
bins_counts,
|
|
58
|
-
|
|
59
|
-
|
|
64
|
+
quant_cfg.l_p_value,
|
|
65
|
+
node_activation_quant_cfg.activation_n_bits,
|
|
60
66
|
min_value,
|
|
61
67
|
max_value,
|
|
62
|
-
min_threshold=
|
|
63
|
-
quant_error_method=
|
|
68
|
+
min_threshold=MIN_THRESHOLD,
|
|
69
|
+
quant_error_method=quant_cfg.activation_error_method,
|
|
64
70
|
is_signed=signed
|
|
65
71
|
)
|
|
66
72
|
|
|
67
73
|
|
|
68
|
-
def _get_histogram_data(
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
|
|
74
|
+
def _get_histogram_data(out_stats_container: BaseStatsCollector,
|
|
75
|
+
activation_error_method: QuantizationErrorMethod,
|
|
76
|
+
z_threshold: float) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
|
|
72
77
|
"""
|
|
73
78
|
Extract and filter the histogram data from the statistics container.
|
|
74
79
|
|
|
75
80
|
Args:
|
|
76
|
-
activation_quant_cfg: Node's activation quantization configuration.
|
|
77
81
|
out_stats_container: Statistics container with histogram data.
|
|
82
|
+
activation_error_method: activation quantization error method.
|
|
83
|
+
z_threshold: z threshold for z-score filtering.
|
|
78
84
|
|
|
79
85
|
Returns:
|
|
80
86
|
A tuple containing the filtered bins_values and bins_counts.
|
|
@@ -83,12 +89,12 @@ def _get_histogram_data(
|
|
|
83
89
|
# If the statistics container collected the histogram, we start by filtering outliers using z threshold
|
|
84
90
|
# filtering, and then computing the threshold based on the filtered histogram.
|
|
85
91
|
if out_stats_container.require_collection():
|
|
86
|
-
if
|
|
92
|
+
if activation_error_method == QuantizationErrorMethod.HMSE:
|
|
87
93
|
bins_values, bins_counts = out_stats_container.weighted_hc.get_histogram()
|
|
88
94
|
else:
|
|
89
95
|
bins_values, bins_counts = out_stats_container.hc.get_histogram()
|
|
90
96
|
bins_counts = qpg.z_score_filter(
|
|
91
|
-
|
|
97
|
+
z_threshold,
|
|
92
98
|
bins_values,
|
|
93
99
|
bins_counts
|
|
94
100
|
)
|
|
@@ -18,7 +18,7 @@ from tqdm import tqdm
|
|
|
18
18
|
from typing import List, Callable, Generator
|
|
19
19
|
|
|
20
20
|
from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
|
|
21
|
-
from model_compression_toolkit.core import QuantizationErrorMethod
|
|
21
|
+
from model_compression_toolkit.core import QuantizationErrorMethod, QuantizationConfig
|
|
22
22
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
23
23
|
from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
|
|
24
24
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
@@ -31,29 +31,8 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
|
|
|
31
31
|
from model_compression_toolkit.logger import Logger
|
|
32
32
|
|
|
33
33
|
|
|
34
|
-
def _collect_nodes_for_hmse(nodes_list: List[BaseNode], graph: Graph) -> List[BaseNode]:
|
|
35
|
-
"""
|
|
36
|
-
Collects nodes that are compatiable for parameters selection search using HMSE,
|
|
37
|
-
that is, have a kernel attribute that is configured for HMSE error method.
|
|
38
|
-
|
|
39
|
-
Args:
|
|
40
|
-
nodes_list: A list of nodes to search quantization parameters for.
|
|
41
|
-
graph: Graph to compute its nodes' quantization parameters..
|
|
42
|
-
|
|
43
|
-
Returns: A (possibly empty) list of nodes.
|
|
44
|
-
|
|
45
|
-
"""
|
|
46
|
-
hmse_nodes = []
|
|
47
|
-
for n in nodes_list:
|
|
48
|
-
if n.kernel_attr is not None and n.is_weights_quantization_enabled(n.kernel_attr) and \
|
|
49
|
-
all([c.weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_error_method ==
|
|
50
|
-
QuantizationErrorMethod.HMSE for c in n.candidates_quantization_cfg]):
|
|
51
|
-
hmse_nodes.append(n)
|
|
52
|
-
|
|
53
|
-
return hmse_nodes
|
|
54
|
-
|
|
55
|
-
|
|
56
34
|
def calculate_quantization_params(graph: Graph,
|
|
35
|
+
quant_cfg: QuantizationConfig,
|
|
57
36
|
fw_impl: FrameworkImplementation,
|
|
58
37
|
repr_data_gen_fn: Callable[[], Generator],
|
|
59
38
|
nodes: List[BaseNode] = None,
|
|
@@ -68,6 +47,7 @@ def calculate_quantization_params(graph: Graph,
|
|
|
68
47
|
|
|
69
48
|
Args:
|
|
70
49
|
graph: Graph to compute its nodes' thresholds.
|
|
50
|
+
quant_cfg: quantization config.
|
|
71
51
|
fw_impl: FrameworkImplementation object.
|
|
72
52
|
repr_data_gen_fn: callable returning representative dataset generator.
|
|
73
53
|
nodes: List of nodes to compute their thresholds instead of computing it for all nodes in the graph.
|
|
@@ -85,15 +65,16 @@ def calculate_quantization_params(graph: Graph,
|
|
|
85
65
|
# Collecting nodes that are configured to search weights quantization parameters using HMSE optimization
|
|
86
66
|
# and computing required Hessian information to be used for HMSE parameters selection.
|
|
87
67
|
# The Hessian scores are computed and stored in the hessian_info_service object.
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
68
|
+
if quant_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
|
|
69
|
+
nodes_for_hmse = [n for n in nodes_list if n.kernel_attr and n.is_weights_quantization_enabled(n.kernel_attr)]
|
|
70
|
+
if nodes_for_hmse:
|
|
71
|
+
dataloader = fw_impl.convert_data_gen_to_dataloader(repr_data_gen_fn, batch_size=1)
|
|
72
|
+
request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
|
|
73
|
+
granularity=HessianScoresGranularity.PER_ELEMENT,
|
|
74
|
+
data_loader=dataloader,
|
|
75
|
+
n_samples=num_hessian_samples,
|
|
76
|
+
target_nodes=nodes_for_hmse)
|
|
77
|
+
hessian_info_service.fetch_hessian(request)
|
|
97
78
|
|
|
98
79
|
for n in tqdm(nodes_list, "Calculating quantization parameters"): # iterate only nodes that we should compute their thresholds
|
|
99
80
|
for candidate_qc in n.candidates_quantization_cfg:
|
|
@@ -101,28 +82,24 @@ def calculate_quantization_params(graph: Graph,
|
|
|
101
82
|
if n.is_weights_quantization_enabled(attr):
|
|
102
83
|
# If the node's weights attribute should be quantized, we compute its quantization parameters
|
|
103
84
|
attr_cfg = candidate_qc.weights_quantization_cfg.get_attr_config(attr)
|
|
104
|
-
|
|
105
|
-
if channels_axis is not None:
|
|
106
|
-
output_channels_axis = channels_axis[0]
|
|
107
|
-
else:
|
|
108
|
-
output_channels_axis = None
|
|
109
|
-
|
|
110
|
-
mod_attr_cfg = attr_cfg
|
|
85
|
+
output_channels_axis = attr_cfg.weights_channels_axis.output
|
|
111
86
|
|
|
112
|
-
|
|
87
|
+
weights_error_method = quant_cfg.weights_error_method
|
|
88
|
+
if weights_error_method == QuantizationErrorMethod.HMSE:
|
|
113
89
|
# Although we collected nodes for HMSE before running the loop, we keep this verification to
|
|
114
90
|
# notify the user in case of HMSE configured for node that is not compatible for this method
|
|
115
91
|
if n.kernel_attr is None or n.kernel_attr not in attr:
|
|
116
92
|
Logger.warning(f"The HMSE error method for parameters selection is only supported for "
|
|
117
93
|
f"kernel weights attributes. Running parameters selection for attribute "
|
|
118
94
|
f"'{attr}' in node '{n.name}' with the default MSE error method instead.")
|
|
119
|
-
|
|
120
|
-
mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE
|
|
95
|
+
weights_error_method = QuantizationErrorMethod.MSE
|
|
121
96
|
|
|
122
|
-
min_threshold = candidate_qc.weights_quantization_cfg.min_threshold
|
|
123
97
|
weights_params, output_channels_axis = compute_weights_qparams(n.get_weights_by_keys(attr),
|
|
124
|
-
|
|
125
|
-
|
|
98
|
+
attr_cfg,
|
|
99
|
+
weights_error_method,
|
|
100
|
+
quant_cfg.l_p_value,
|
|
101
|
+
output_channels_axis,
|
|
102
|
+
node=n,
|
|
126
103
|
hessian_info_service=hessian_info_service,
|
|
127
104
|
num_hessian_samples=num_hessian_samples)
|
|
128
105
|
attr_cfg.weights_channels_axis = ChannelAxisMapping(output_channels_axis, attr_cfg.weights_channels_axis.input)
|
|
@@ -130,8 +107,9 @@ def calculate_quantization_params(graph: Graph,
|
|
|
130
107
|
|
|
131
108
|
if n.is_activation_quantization_enabled():
|
|
132
109
|
# If node's activations should be quantized as well, we compute its activation quantization parameters
|
|
133
|
-
activation_params = compute_activation_qparams(
|
|
134
|
-
|
|
135
|
-
|
|
110
|
+
activation_params = compute_activation_qparams(quant_cfg=quant_cfg,
|
|
111
|
+
node_activation_quant_cfg=candidate_qc.activation_quantization_cfg,
|
|
112
|
+
node_prior_info=n.prior_info,
|
|
113
|
+
out_stats_container=graph.get_out_stats_collector(n))
|
|
136
114
|
# Create a NodeQuantizationConfig containing all quantization params and attach it to the node
|
|
137
115
|
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params)
|
|
@@ -18,7 +18,8 @@ from typing import Dict, Any, Tuple, Callable, TYPE_CHECKING
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
from mct_quantizers import QuantizationMethod
|
|
20
20
|
|
|
21
|
-
from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
|
|
21
|
+
from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES, MIN_THRESHOLD
|
|
22
|
+
from model_compression_toolkit.core import QuantizationErrorMethod
|
|
22
23
|
from model_compression_toolkit.core.common.hessian import HessianInfoService
|
|
23
24
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation import \
|
|
24
25
|
power_of_two_selection_tensor, lut_kmeans_tensor, symmetric_selection_tensor, uniform_selection_tensor
|
|
@@ -28,10 +29,12 @@ if TYPE_CHECKING:
|
|
|
28
29
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
def compute_weights_qparams(
|
|
32
|
+
def compute_weights_qparams(weights_attr_data: np.ndarray,
|
|
32
33
|
attr_quant_config: 'WeightsAttrQuantizationConfig',
|
|
34
|
+
weights_error_method: QuantizationErrorMethod,
|
|
35
|
+
l_p_value: int,
|
|
33
36
|
output_channels_axis: int,
|
|
34
|
-
min_threshold: float,
|
|
37
|
+
min_threshold: float = MIN_THRESHOLD,
|
|
35
38
|
node=None,
|
|
36
39
|
hessian_info_service: HessianInfoService = None,
|
|
37
40
|
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
|
|
@@ -40,8 +43,10 @@ def compute_weights_qparams(weights_attr_values: np.ndarray,
|
|
|
40
43
|
instance.
|
|
41
44
|
|
|
42
45
|
Args:
|
|
43
|
-
|
|
46
|
+
weights_attr_data: Weights attribute parameter to compute the quantization thresholds for.
|
|
44
47
|
attr_quant_config: A specific weights attribute quantization configuration to get its params.
|
|
48
|
+
weights_error_method: quantization error method.
|
|
49
|
+
l_p_value: p-norm to use for the Lp-norm distance.
|
|
45
50
|
output_channels_axis: Index of the kernel output channels dimension.
|
|
46
51
|
min_threshold: Minimal threshold to use if threshold is too small.
|
|
47
52
|
node: The node for which the quantization error is computed (used only with HMSE error method).
|
|
@@ -54,13 +59,13 @@ def compute_weights_qparams(weights_attr_values: np.ndarray,
|
|
|
54
59
|
"""
|
|
55
60
|
params_fn = _get_weights_quantization_params_fn(attr_quant_config.weights_quantization_method)
|
|
56
61
|
weights_params, output_channels_axis = params_fn(
|
|
57
|
-
|
|
58
|
-
p=
|
|
62
|
+
weights_attr_data,
|
|
63
|
+
p=l_p_value,
|
|
59
64
|
n_bits=attr_quant_config.weights_n_bits,
|
|
60
65
|
per_channel=attr_quant_config.weights_per_channel_threshold,
|
|
61
66
|
channel_axis=output_channels_axis,
|
|
62
67
|
min_threshold=min_threshold,
|
|
63
|
-
quant_error_method=
|
|
68
|
+
quant_error_method=weights_error_method,
|
|
64
69
|
node=node,
|
|
65
70
|
hessian_info_service=hessian_info_service,
|
|
66
71
|
num_hessian_samples=num_hessian_samples)
|
model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py
CHANGED
|
@@ -14,8 +14,6 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
import copy
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
|
|
18
|
-
from model_compression_toolkit.core import CoreConfig
|
|
19
17
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
20
18
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
21
19
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
|
|
@@ -23,7 +21,6 @@ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_s
|
|
|
23
21
|
|
|
24
22
|
|
|
25
23
|
def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
|
|
26
|
-
core_config: CoreConfig,
|
|
27
24
|
fw_impl: FrameworkImplementation) -> Graph:
|
|
28
25
|
"""
|
|
29
26
|
Get a graph, where each node has a final weights quantization configuration (with a bias
|
|
@@ -31,7 +28,6 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
|
|
|
31
28
|
|
|
32
29
|
Args:
|
|
33
30
|
graph_to_apply_bias_correction: Graph to apply bias correction to.
|
|
34
|
-
core_config: CoreConfig containing parameters of how the model should be quantized.
|
|
35
31
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
|
36
32
|
|
|
37
33
|
Returns:
|
|
@@ -40,20 +36,14 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
|
|
|
40
36
|
|
|
41
37
|
graph = copy.deepcopy(graph_to_apply_bias_correction)
|
|
42
38
|
for n in graph.nodes:
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
not n.final_weights_quantization_cfg.weights_second_moment_correction:
|
|
47
|
-
# If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg,
|
|
48
|
-
# a bias correction term was calculated during model preparation, and is used now in the node's bias term.
|
|
49
|
-
if n.final_weights_quantization_cfg.weights_bias_correction:
|
|
50
|
-
_apply_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
|
|
39
|
+
if (n.final_weights_quantization_cfg and n.final_weights_quantization_cfg.bias_corrected is not None and
|
|
40
|
+
not n.final_weights_quantization_cfg.weights_second_moment_correction):
|
|
41
|
+
_apply_bias_correction_to_node(n, fw_impl)
|
|
51
42
|
return graph
|
|
52
43
|
|
|
53
44
|
|
|
54
45
|
def _apply_bias_correction_to_node(node: BaseNode,
|
|
55
|
-
fw_impl: FrameworkImplementation
|
|
56
|
-
qc: QuantizationConfig):
|
|
46
|
+
fw_impl: FrameworkImplementation):
|
|
57
47
|
"""
|
|
58
48
|
Set new bias to node using the bias-correction term that is stored in the
|
|
59
49
|
final weights quantization configuration.
|
|
@@ -52,7 +52,8 @@ def _collect_and_assign_act_threshold(graph: Graph,
|
|
|
52
52
|
|
|
53
53
|
for n in graph.nodes:
|
|
54
54
|
if n.is_activation_quantization_enabled():
|
|
55
|
-
activation_params = compute_activation_qparams(
|
|
55
|
+
activation_params = compute_activation_qparams(quant_cfg=core_config.quantization_config,
|
|
56
|
+
node_activation_quant_cfg=n.final_activation_quantization_cfg,
|
|
56
57
|
node_prior_info=n.prior_info,
|
|
57
58
|
out_stats_container=graph.get_out_stats_collector(n))
|
|
58
59
|
n.final_activation_quantization_cfg.set_activation_quantization_param(activation_params)
|
model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py
CHANGED
|
@@ -43,17 +43,9 @@ def compute_bias_correction_of_graph(graph: Graph,
|
|
|
43
43
|
for n in graph.nodes:
|
|
44
44
|
# Bias correction is computed based on the quantized kernel, so we need to get the specific kernel attribute
|
|
45
45
|
# name out of all the weights attributes of the node.
|
|
46
|
-
if n.kernel_attr:
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
if n.has_positional_weights:
|
|
50
|
-
for candidate_qc in n.candidates_quantization_cfg:
|
|
51
|
-
candidate_qc.weights_quantization_cfg.weights_bias_correction = False
|
|
52
|
-
else:
|
|
53
|
-
_compute_bias_correction_per_candidate_qc(n,
|
|
54
|
-
n.kernel_attr,
|
|
55
|
-
graph.get_in_stats_collector(n),
|
|
56
|
-
fw_impl=fw_impl)
|
|
46
|
+
if n.kernel_attr and n.is_weights_quantization_enabled(n.kernel_attr) and not n.has_positional_weights:
|
|
47
|
+
_compute_bias_correction_per_candidate_qc(n, n.kernel_attr, graph.get_in_stats_collector(n),
|
|
48
|
+
fw_impl=fw_impl)
|
|
57
49
|
return graph
|
|
58
50
|
|
|
59
51
|
|
|
@@ -74,8 +66,7 @@ def _compute_bias_correction_per_candidate_qc(node: BaseNode,
|
|
|
74
66
|
"""
|
|
75
67
|
|
|
76
68
|
for candidate_qc in node.candidates_quantization_cfg:
|
|
77
|
-
if candidate_qc.weights_quantization_cfg.
|
|
78
|
-
candidate_qc.weights_quantization_cfg.weights_second_moment_correction:
|
|
69
|
+
if not candidate_qc.weights_quantization_cfg.weights_second_moment_correction:
|
|
79
70
|
|
|
80
71
|
quantized_kernel, io_channels_axes = get_quantized_weights_attr_by_qc(kernel_attr,
|
|
81
72
|
node,
|
|
@@ -56,8 +56,9 @@ def statistics_correction_runner(transformed_graph: Graph,
|
|
|
56
56
|
########################################################
|
|
57
57
|
# Compute bias correction to nodes' config candidates
|
|
58
58
|
########################################################
|
|
59
|
-
|
|
60
|
-
|
|
59
|
+
if core_config.quantization_config.weights_bias_correction:
|
|
60
|
+
tg_with_bias = compute_bias_correction_of_graph(tg_with_bias,
|
|
61
|
+
fw_impl)
|
|
61
62
|
|
|
62
63
|
if tb_w is not None:
|
|
63
64
|
tb_w.add_graph(tg_with_bias, 'statistics_computation')
|
|
@@ -96,7 +97,6 @@ def apply_statistics_correction(transformed_graph: Graph,
|
|
|
96
97
|
#############################################
|
|
97
98
|
if core_config.quantization_config.weights_bias_correction:
|
|
98
99
|
transformed_graph = apply_bias_correction_to_graph(transformed_graph,
|
|
99
|
-
core_config,
|
|
100
100
|
fw_impl=fw_impl)
|
|
101
101
|
if tb_w is not None:
|
|
102
102
|
tb_w.add_graph(transformed_graph, 'after_statistics_correction')
|
|
@@ -20,7 +20,6 @@ from typing import Callable
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
|
|
22
22
|
from model_compression_toolkit.core.common import Graph
|
|
23
|
-
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
|
|
24
23
|
from model_compression_toolkit.core import common
|
|
25
24
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
|
|
26
25
|
ActivationQuantizationMode
|
|
@@ -84,14 +83,10 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
|
|
|
84
83
|
# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
|
|
85
84
|
# we should skip the substitution.
|
|
86
85
|
if source_node.is_reused():
|
|
87
|
-
for qc in source_node.candidates_quantization_cfg:
|
|
88
|
-
qc.weights_quantization_cfg.weights_second_moment_correction = False
|
|
89
86
|
return graph
|
|
90
87
|
|
|
91
88
|
# We apply only on nodes with folded BatchNormalization.
|
|
92
89
|
if source_node.prior_info.std_output is None or source_node.prior_info.mean_output is None:
|
|
93
|
-
for qc in source_node.candidates_quantization_cfg:
|
|
94
|
-
qc.weights_quantization_cfg.weights_second_moment_correction = False
|
|
95
90
|
return graph
|
|
96
91
|
|
|
97
92
|
# This feature disabled for models with weights quantization method of Power of 2
|
|
@@ -103,10 +98,13 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
|
|
|
103
98
|
== QuantizationMethod.POWER_OF_TWO):
|
|
104
99
|
Logger.warning("Second moment statistics correction feature disabled for models with weights "
|
|
105
100
|
"quantization method of Power of 2")
|
|
106
|
-
for qc_inner in source_node.candidates_quantization_cfg:
|
|
107
|
-
qc_inner.weights_quantization_cfg.weights_second_moment_correction = False
|
|
108
101
|
return graph
|
|
109
102
|
|
|
103
|
+
# turn on second moment correction flag
|
|
104
|
+
def set_second_moment_correction(qc):
|
|
105
|
+
qc.weights_quantization_cfg.weights_second_moment_correction = True
|
|
106
|
+
source_node.quantization_cfg.update_all(set_second_moment_correction)
|
|
107
|
+
|
|
110
108
|
eps = self.epsilon_val
|
|
111
109
|
|
|
112
110
|
original_gamma = source_node.prior_info.std_output
|
|
@@ -298,7 +298,7 @@ def shift_negative_function(graph: Graph,
|
|
|
298
298
|
|
|
299
299
|
negative_rate = np.abs(min_to_correct) / activation_threshold
|
|
300
300
|
|
|
301
|
-
enable_sub = negative_rate <=
|
|
301
|
+
enable_sub = negative_rate <= core_config.quantization_config.shift_negative_ratio
|
|
302
302
|
if min_to_correct >= 0 or not enable_sub:
|
|
303
303
|
return graph
|
|
304
304
|
|
|
@@ -316,7 +316,7 @@ def shift_negative_function(graph: Graph,
|
|
|
316
316
|
if core_config.quantization_config.shift_negative_params_search:
|
|
317
317
|
|
|
318
318
|
hist_bins, hist_count = graph.get_out_stats_collector(non_linear_node).hc.get_histogram()
|
|
319
|
-
hist_count = z_score_filter(
|
|
319
|
+
hist_count = z_score_filter(core_config.quantization_config.z_threshold,
|
|
320
320
|
hist_bins, hist_count)
|
|
321
321
|
|
|
322
322
|
min_mse, _th, _shift = np.inf, None, None
|
|
@@ -471,10 +471,12 @@ def shift_negative_function(graph: Graph,
|
|
|
471
471
|
pad_node=pad_node,
|
|
472
472
|
op2d_node=op2d_node)
|
|
473
473
|
|
|
474
|
-
if
|
|
475
|
-
activation_param = compute_activation_qparams(
|
|
474
|
+
if core_config.quantization_config.shift_negative_threshold_recalculation:
|
|
475
|
+
activation_param = compute_activation_qparams(quant_cfg=core_config.quantization_config,
|
|
476
|
+
node_activation_quant_cfg=non_linear_node_cfg_candidate,
|
|
476
477
|
node_prior_info=non_linear_node.prior_info,
|
|
477
|
-
out_stats_container=graph.get_out_stats_collector(
|
|
478
|
+
out_stats_container=graph.get_out_stats_collector(
|
|
479
|
+
non_linear_node))
|
|
478
480
|
|
|
479
481
|
assert activation_param.get(SIGNED) is False
|
|
480
482
|
for candidate_qc in non_linear_node.candidates_quantization_cfg:
|
|
@@ -153,20 +153,10 @@ def get_finalized_graph(initial_graph: Graph,
|
|
|
153
153
|
if bit_width_config:
|
|
154
154
|
set_manual_bitwidth_config(graph, bit_width_config)
|
|
155
155
|
|
|
156
|
-
# TODO irena:
|
|
157
|
-
# As a first stage we keep the attributes in internal configs and fill them manually from quant_config
|
|
158
|
-
# not to break all the code at once. Eventually we need to handle quant_config directly, without injecting into candidates.
|
|
159
|
-
# TODO 2: Also we adjust candidates for single precision, which we shouldn't do here.
|
|
160
|
-
def update(qc):
|
|
161
|
-
qc.activation_quantization_cfg.set_qc(quant_config)
|
|
162
|
-
qc.weights_quantization_cfg.set_qc(quant_config)
|
|
163
|
-
for attr_cfg in qc.weights_quantization_cfg.get_all_weight_attrs_configs().values():
|
|
164
|
-
attr_cfg.weights_error_method = quant_config.weights_error_method
|
|
165
|
-
attr_cfg.l_p_value = quant_config.l_p_value
|
|
156
|
+
# TODO irena: remove after base config is used
|
|
166
157
|
for n in transformed_graph.nodes:
|
|
167
158
|
if not mixed_precision_enable:
|
|
168
159
|
n.quantization_cfg.candidates_quantization_cfg = [n.quantization_cfg.base_quantization_cfg]
|
|
169
|
-
n.quantization_cfg.update_all(update)
|
|
170
160
|
|
|
171
161
|
######################################
|
|
172
162
|
# Channel equalization
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
from tensorflow.keras.layers import InputLayer, Dense, DepthwiseConv2D, Conv2D, Conv2DTranspose, ZeroPadding2D
|
|
18
18
|
from typing import List
|
|
19
19
|
|
|
20
|
-
from model_compression_toolkit.core import common
|
|
20
|
+
from model_compression_toolkit.core import common, QuantizationConfig
|
|
21
21
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
22
22
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, WalkMatcher
|
|
23
23
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
@@ -47,7 +47,8 @@ class BaseInputScaling(common.BaseSubstitution):
|
|
|
47
47
|
"""
|
|
48
48
|
|
|
49
49
|
def __init__(self,
|
|
50
|
-
matcher_instance
|
|
50
|
+
matcher_instance,
|
|
51
|
+
quant_cfg: QuantizationConfig):
|
|
51
52
|
"""
|
|
52
53
|
Matches: InputLayer -> (optional nodes) -> (Dense,Conv2D,DepthwiseConv2D,Conv2DTranspose)
|
|
53
54
|
note: the optional nodes are nodes that don't affect the scaling (such as ZeroPadding)
|
|
@@ -55,10 +56,11 @@ class BaseInputScaling(common.BaseSubstitution):
|
|
|
55
56
|
Create a substitution using different params which may affect the way this substitution is made.
|
|
56
57
|
The substitution is looking for edges in the graph which are input layers connected to linear layers.
|
|
57
58
|
Args:
|
|
58
|
-
matcher_instance: matcher instance of type WalkMatcher
|
|
59
|
-
|
|
59
|
+
matcher_instance: matcher instance of type WalkMatcher.
|
|
60
|
+
quant_cfg: quantization config.
|
|
60
61
|
"""
|
|
61
62
|
super().__init__(matcher_instance=matcher_instance)
|
|
63
|
+
self.quant_cfg = quant_cfg
|
|
62
64
|
|
|
63
65
|
def substitute(self,
|
|
64
66
|
graph: Graph,
|
|
@@ -105,9 +107,11 @@ class BaseInputScaling(common.BaseSubstitution):
|
|
|
105
107
|
for nqc in linear_layer.candidates_quantization_cfg:
|
|
106
108
|
attr_cfg = nqc.weights_quantization_cfg.get_attr_config(linear_layer.kernel_attr)
|
|
107
109
|
assert attr_cfg.enable_weights_quantization
|
|
108
|
-
w_params, _ = compute_weights_qparams(w1_fixed,
|
|
109
|
-
|
|
110
|
-
|
|
110
|
+
w_params, _ = compute_weights_qparams(w1_fixed,
|
|
111
|
+
attr_quant_config=attr_cfg,
|
|
112
|
+
weights_error_method=self.quant_cfg.weights_error_method,
|
|
113
|
+
l_p_value=self.quant_cfg.l_p_value,
|
|
114
|
+
output_channels_axis=attr_cfg.weights_channels_axis.output)
|
|
111
115
|
attr_cfg.set_weights_quantization_param(w_params)
|
|
112
116
|
|
|
113
117
|
return graph
|
|
@@ -118,12 +122,15 @@ class InputScaling(BaseInputScaling):
|
|
|
118
122
|
Substitution extends BaseInputScaling to the case of Input-->Linear
|
|
119
123
|
"""
|
|
120
124
|
|
|
121
|
-
def __init__(self):
|
|
125
|
+
def __init__(self, quant_cfg: QuantizationConfig):
|
|
122
126
|
"""
|
|
123
127
|
Initialize a ScaleEqualization object.
|
|
128
|
+
|
|
129
|
+
Args:
|
|
130
|
+
quant_cfg: quantization config.
|
|
124
131
|
"""
|
|
125
132
|
|
|
126
|
-
super().__init__(matcher_instance=INPUT_MATCHER)
|
|
133
|
+
super().__init__(matcher_instance=INPUT_MATCHER, quant_cfg=quant_cfg)
|
|
127
134
|
|
|
128
135
|
|
|
129
136
|
class InputScalingWithPad(BaseInputScaling):
|
|
@@ -131,9 +138,12 @@ class InputScalingWithPad(BaseInputScaling):
|
|
|
131
138
|
Substitution extends BaseInputScaling to the case of Input-->ZeroPadding-->Linear
|
|
132
139
|
"""
|
|
133
140
|
|
|
134
|
-
def __init__(self):
|
|
141
|
+
def __init__(self, quant_cfg: QuantizationConfig):
|
|
135
142
|
"""
|
|
136
143
|
Initialize a ScaleEqualization object.
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
quant_cfg: quantization config.
|
|
137
147
|
"""
|
|
138
148
|
|
|
139
|
-
super().__init__(matcher_instance=INPUT_MATCHER_WITH_PAD)
|
|
149
|
+
super().__init__(matcher_instance=INPUT_MATCHER_WITH_PAD, quant_cfg=quant_cfg)
|
|
@@ -357,8 +357,8 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
357
357
|
if quant_config.softmax_shift:
|
|
358
358
|
substitutions_list.append(keras_softmax_shift())
|
|
359
359
|
if quant_config.input_scaling:
|
|
360
|
-
substitutions_list.append(InputScaling())
|
|
361
|
-
substitutions_list.append(InputScalingWithPad())
|
|
360
|
+
substitutions_list.append(InputScaling(quant_config))
|
|
361
|
+
substitutions_list.append(InputScalingWithPad(quant_config))
|
|
362
362
|
if quant_config.concat_threshold_update:
|
|
363
363
|
substitutions_list.append(ConcatThresholdUpdate())
|
|
364
364
|
return substitutions_list
|
|
@@ -28,6 +28,10 @@ import numpy as np
|
|
|
28
28
|
from model_compression_toolkit.logger import Logger
|
|
29
29
|
|
|
30
30
|
|
|
31
|
+
# default output channel axis to use when it's not defined in node's fw_info.
|
|
32
|
+
_default_output_channel_axis = -1
|
|
33
|
+
|
|
34
|
+
|
|
31
35
|
class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementation):
|
|
32
36
|
"""
|
|
33
37
|
Implementation of the PruningFramework for the Keras framework. This class provides
|
|
@@ -172,6 +176,10 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
|
|
|
172
176
|
|
|
173
177
|
return attributes_with_axis
|
|
174
178
|
|
|
179
|
+
@property
|
|
180
|
+
def default_output_channel_axis(self):
|
|
181
|
+
return _default_output_channel_axis
|
|
182
|
+
|
|
175
183
|
|
|
176
184
|
def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
|
|
177
185
|
"""
|