mct-nightly 2.2.0.20241106.458__py3-none-any.whl → 2.2.0.20241108.459__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/RECORD +17 -29
- {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/top_level.txt +0 -1
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +46 -27
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -0
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -0
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +81 -0
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +190 -0
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +14 -2
- model_compression_toolkit/core/keras/keras_implementation.py +23 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +67 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +21 -0
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +57 -0
- model_compression_toolkit/core/runner.py +8 -0
- tests_pytest/__init__.py +0 -14
- tests_pytest/keras/__init__.py +0 -14
- tests_pytest/keras/core/__init__.py +0 -14
- tests_pytest/keras/core/test_data_util.py +0 -91
- tests_pytest/keras/gptq/__init__.py +0 -14
- tests_pytest/keras/gptq/test_gradual_act_quantization.py +0 -102
- tests_pytest/keras/trainable_infrastructure/__init__.py +0 -16
- tests_pytest/keras/trainable_infrastructure/test_linear_annealing.py +0 -49
- tests_pytest/pytorch/__init__.py +0 -14
- tests_pytest/pytorch/core/__init__.py +0 -14
- tests_pytest/pytorch/core/test_data_util.py +0 -125
- tests_pytest/pytorch/gptq/__init__.py +0 -14
- tests_pytest/pytorch/gptq/test_annealing_cfg.py +0 -40
- tests_pytest/pytorch/gptq/test_gradual_act_quantization.py +0 -100
- tests_pytest/pytorch/trainable_infrastructure/__init__.py +0 -14
- tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py +0 -49
- {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/WHEEL +0 -0
@@ -0,0 +1,190 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
import numpy as np
|
16
|
+
from typing import Any, Callable
|
17
|
+
|
18
|
+
from model_compression_toolkit.core import QuantizationConfig
|
19
|
+
from model_compression_toolkit.core.common import BaseNode, Graph
|
20
|
+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
21
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
22
|
+
|
23
|
+
|
24
|
+
def get_previous_node_with_activation_quantization(linear_node: BaseNode,
                                                   graph: Graph) -> Any:
    """
    Search backwards from a node for the nearest preceding node that quantizes its own activation.

    Starting at ``linear_node``, the search repeatedly steps to the single predecessor of the current node.
    It stops at the first predecessor whose final activation quantization configuration has activation
    quantization enabled and is not quantization-preserving. Predecessors that do not quantize their own
    output are walked through.

    The search is iterative rather than recursive, so long chains of non-quantizing nodes cannot hit the
    interpreter's recursion limit. It also stops (returning None) if the path revisits a node.

    Args:
        linear_node: Node to search for its previous node.
        graph: Graph the node is in.

    Returns:
        The previous node with activation quantization, or None if it was not found. This happens when a node
        on the path has zero or multiple incoming edges (the substitution can not be applied), when a node on
        the path has no final activation quantization configuration, or when the path loops back on itself.
    """
    visited = set()
    current_node = linear_node

    while True:
        prev_nodes = graph.get_prev_nodes(current_node)
        if len(prev_nodes) != 1:
            return None  # pragma: no cover

        prev_node = prev_nodes[0]
        activation_quantization_config = prev_node.final_activation_quantization_cfg

        # Without a final activation configuration we cannot tell whether this node quantizes its output,
        # so no correction source can be determined.
        if activation_quantization_config is None:
            return None

        # Search for node with activation quantization
        if (activation_quantization_config.enable_activation_quantization and
                not activation_quantization_config.quantization_preserving):
            return prev_node

        # The predecessor does not quantize its own output: keep walking back through it, but never
        # around a loop (the recursive version would have overflowed the stack here).
        visited.add(current_node)
        if prev_node in visited:
            return None
        current_node = prev_node
|
53
|
+
|
54
|
+
|
55
|
+
def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray:
    """
    Return the midpoint of every histogram bin.

    Bin ``k`` spans ``[bin_edges[k], bin_edges[k + 1]]``, so N edges describe N - 1 bins.

    Args:
        bin_edges: Array of bin edges.

    Returns:
        np.ndarray: Array of N - 1 bin centers (empty when fewer than two edges are given).
    """
    left_edges = bin_edges[:-1]
    right_edges = bin_edges[1:]
    # Midpoint of each consecutive pair of edges.
    return (left_edges + right_edges) / 2.0
|
68
|
+
|
69
|
+
|
70
|
+
def compute_activation_bias_correction(graph: Graph,
                                       quant_config: QuantizationConfig,
                                       fw_info: FrameworkInfo,
                                       fw_impl: FrameworkImplementation,
                                       linear_node: BaseNode,
                                       prev_node: BaseNode,
                                       kernel_size: str) -> Graph:
    """
    Compute the activation bias correction term of a linear node, and store it in the node's final activation
    quantization configuration.

    The mean shift that the activation quantization of ``prev_node`` introduces is estimated from the output
    histogram collected for ``prev_node``. The histogram bin edges are quantized, and the count-weighted mean
    of the quantized bin centers is compared with the count-weighted mean of the float bin centers. That
    difference is then propagated through the kernel of ``linear_node``.

    Args:
        graph: Graph with nodes to compute the activation bias correction for each node's final activation
            quantization configuration.
        quant_config: QuantizationConfig of how the model should be quantized.
        fw_info: Framework info, e.g. the kernel attribute names and kernel channel axes of each layer type.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
        linear_node: Node to compute the activation bias correction for.
        prev_node: Node whose activation quantization error is being corrected.
        kernel_size: The framework specific attribute name of the convolution layer's kernel size.

    Returns:
        The same graph. ``activation_bias_correction_term`` is set on ``linear_node``'s final activation
        quantization configuration when a correction was computed; otherwise the graph is left unmodified.
    """

    # This feature supports only Dense/Linear layers and convolution layers with a kernel size of 1 or (1, 1).
    # Dense/Linear layers lack a 'kernel_size' attribute, so the lookup gives None and no restriction applies.
    # A list-valued kernel size (e.g. [1, 1]) is normalized to a tuple so it compares equal to (1, 1).
    kernel_size_value = linear_node.framework_attr.get(kernel_size)
    if isinstance(kernel_size_value, list):
        kernel_size_value = tuple(kernel_size_value)
    if kernel_size_value not in [None, 1, (1, 1)]:
        # If the kernel size is not 1 or (1, 1), return the current graph unmodified
        return graph

    prev_node_act_quant_cfg = prev_node.final_activation_quantization_cfg

    # The previous node must have an activation quantization configuration, and its output statistics
    # collector must have a histogram collector.
    if prev_node_act_quant_cfg is None or not hasattr(graph.get_out_stats_collector(prev_node), 'hc'):
        return graph  # pragma: no cover

    float_bins, float_count = graph.get_out_stats_collector(prev_node).hc.get_histogram()

    # An empty histogram has no mass to average. The weighted means below would be 0/0 = NaN, and a NaN
    # correction term would be stored on the node, so skip the correction instead.
    total_count = np.sum(float_count)
    if total_count == 0:
        return graph

    # Calculate the centers of the float bins
    float_centers = calculate_bin_centers(float_bins)

    # Quantize the bin edges and calculate the centers of the quantized bins
    quant_bins = prev_node_act_quant_cfg.quantize_node_output(fw_impl.to_tensor(float_bins))
    quant_bins = fw_impl.to_numpy(quant_bins)
    quant_centers = calculate_bin_centers(quant_bins)

    # Calculate the mean of both the float and the quantized bin centers, weighted by the bin counts
    mean_float_centers = np.sum(float_centers * float_count) / total_count
    mean_quant_centers = np.sum(quant_centers * float_count) / total_count

    # Compute the difference between the mean quantized center and the mean float center
    mean_diff = mean_quant_centers - mean_float_centers

    # Calculate the normalized bias as a percentage of the absolute float center mean. A zero float mean
    # yields inf (or NaN when mean_diff is also 0). Neither compares below the threshold, so the correction
    # proceeds; the numpy division warnings for that case are suppressed.
    with np.errstate(divide='ignore', invalid='ignore'):
        normalized_bias = 100 * np.abs(mean_diff) / np.abs(mean_float_centers)

    # If the normalized bias is below the activation bias correction threshold, return the graph unmodified.
    # By default, the threshold is set to 0, allowing all nodes to proceed in this case.
    if normalized_bias < quant_config.activation_bias_correction_threshold:
        return graph

    kernel = linear_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(linear_node.type)[0])

    # Compute the activation bias correction by applying the quantization error to the kernel, resulting in an
    # output size matching the number of output channels.
    if kernel is not None:

        # Get the axes that are not the output channel.
        output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type)
        axis_not_output_channel = list(range(len(kernel.shape)))
        axis_not_output_channel.remove(output_channel_index)

        # Special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters.
        if output_channel_index == input_channel_index:
            axis_not_output_channel.remove(3)  # 3 is the depth multiplier index.

        activation_bias_correction_term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
        linear_node.final_activation_quantization_cfg.activation_bias_correction_term = (
            activation_bias_correction_term.flatten())
    return graph
|
154
|
+
|
155
|
+
|
156
|
+
def compute_activation_bias_correction_of_graph(graph: Graph,
                                                quant_config: QuantizationConfig,
                                                fw_info: FrameworkInfo,
                                                fw_impl: FrameworkImplementation,
                                                activation_bias_correction_node_matchers: Callable,
                                                kernel_size: str) -> Graph:
    """
    Compute the activation bias correction term of every eligible linear node in a graph.

    A node is eligible when it is accepted by the matcher built from
    ``activation_bias_correction_node_matchers``, and a preceding node with activation quantization can be
    found for it.

    Args:
        graph: Graph with nodes to compute the activation bias correction.
        quant_config: QuantizationConfig of how the model should be quantized.
        fw_info: Framework info, e.g. the kernel attribute names and kernel channel axes of each layer type.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
        activation_bias_correction_node_matchers: Function that builds the matcher of the layers eligible for
            activation bias correction.
        kernel_size: The framework specific attribute name of the convolution layer's kernel size.

    Returns:
        Graph with activation bias correction term for each relevant node.
    """
    linear_node_matcher = activation_bias_correction_node_matchers()

    for node in graph.nodes:
        # Skip layers the correction cannot be folded into.
        if not linear_node_matcher.apply(node):
            continue

        # Skip nodes without a unique upstream node that quantizes its activation.
        source_node = get_previous_node_with_activation_quantization(node, graph)
        if source_node is None:
            continue

        graph = compute_activation_bias_correction(graph=graph,
                                                   quant_config=quant_config,
                                                   fw_info=fw_info,
                                                   fw_impl=fw_impl,
                                                   linear_node=node,
                                                   prev_node=source_node,
                                                   kernel_size=kernel_size)
    return graph
|
@@ -18,6 +18,8 @@ from model_compression_toolkit.core.common import FrameworkInfo
|
|
18
18
|
from model_compression_toolkit.core.common import Graph
|
19
19
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
20
20
|
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
|
21
|
+
from model_compression_toolkit.core.common.statistics_correction.apply_activation_bias_correction_to_graph import \
|
22
|
+
apply_activation_bias_correction_to_graph
|
21
23
|
from model_compression_toolkit.core.common.statistics_correction.apply_bias_correction_to_graph import \
|
22
24
|
apply_bias_correction_to_graph
|
23
25
|
from model_compression_toolkit.core.common.statistics_correction.apply_second_moment_correction_to_graph import \
|
@@ -73,7 +75,7 @@ def apply_statistics_correction(transformed_graph: Graph,
|
|
73
75
|
fw_impl: FrameworkImplementation,
|
74
76
|
tb_w: TensorboardWriter = None, ) -> Graph:
|
75
77
|
"""
|
76
|
-
Apply statistics
|
78
|
+
Apply statistics correction on graph.
|
77
79
|
Args:
|
78
80
|
transformed_graph: Graph to apply statistics correction.
|
79
81
|
representative_data_gen (Callable): Dataset used for calibration.
|
@@ -84,7 +86,7 @@ def apply_statistics_correction(transformed_graph: Graph,
|
|
84
86
|
tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
|
85
87
|
|
86
88
|
Returns:
|
87
|
-
Graph after statistics correction
|
89
|
+
Graph after statistics correction.
|
88
90
|
"""
|
89
91
|
|
90
92
|
#############################################
|
@@ -104,4 +106,14 @@ def apply_statistics_correction(transformed_graph: Graph,
|
|
104
106
|
if tb_w is not None:
|
105
107
|
tb_w.add_graph(transformed_graph, 'after_statistics_correction')
|
106
108
|
|
109
|
+
#############################################
|
110
|
+
# Apply Activation Bias Correction
|
111
|
+
#############################################
|
112
|
+
if core_config.quantization_config.activation_bias_correction:
|
113
|
+
transformed_graph = apply_activation_bias_correction_to_graph(graph=transformed_graph,
|
114
|
+
core_config=core_config,
|
115
|
+
fw_impl=fw_impl)
|
116
|
+
if tb_w is not None:
|
117
|
+
tb_w.add_graph(transformed_graph, 'after_activation_bias_correction')
|
118
|
+
|
107
119
|
return transformed_graph
|
@@ -28,6 +28,8 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.remo
|
|
28
28
|
from model_compression_toolkit.core.keras.hessian.activation_hessian_scores_calculator_keras import \
|
29
29
|
ActivationHessianScoresCalculatorKeras
|
30
30
|
from model_compression_toolkit.core.keras.hessian.weights_hessian_scores_calculator_keras import WeightsHessianScoresCalculatorKeras
|
31
|
+
from model_compression_toolkit.core.keras.statistics_correction.keras_compute_activation_bias_correction_of_graph import \
|
32
|
+
keras_compute_activation_bias_correction_of_graph
|
31
33
|
from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_quantizers import \
|
32
34
|
get_inferable_quantizers
|
33
35
|
from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
|
@@ -84,7 +86,7 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.line
|
|
84
86
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \
|
85
87
|
keras_residual_collapsing
|
86
88
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \
|
87
|
-
InputScalingWithPad
|
89
|
+
InputScalingWithPad
|
88
90
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.concat_threshold_update import ConcatThresholdUpdate
|
89
91
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
|
90
92
|
ReLUBoundToPowerOfTwo
|
@@ -218,6 +220,25 @@ class KerasImplementation(FrameworkImplementation):
|
|
218
220
|
core_config,
|
219
221
|
fw_info)
|
220
222
|
|
223
|
+
def compute_activation_bias_correction(self,
                                       graph: Graph,
                                       quant_config: QuantizationConfig,
                                       fw_info: FrameworkInfo) -> Graph:
    """
    Compute the activation bias correction terms of a graph based on a Keras model.

    Only the correction terms are computed here and stored in the nodes' final activation quantization
    configurations. Folding them into the layers' biases is a separate statistics-correction step.

    Args:
        graph: Graph to compute the activation bias correction terms for.
        quant_config: QuantizationConfig of how the model should be quantized.
        fw_info: FrameworkInfo object with information about the specific framework's model.

    Returns:
        Graph after computing the activation bias correction terms.
    """
    return keras_compute_activation_bias_correction_of_graph(graph=graph,
                                                             quant_config=quant_config,
                                                             fw_info=fw_info,
                                                             fw_impl=self)
|
221
242
|
|
222
243
|
def get_substitutions_channel_equalization(self,
|
223
244
|
quant_config: QuantizationConfig,
|
@@ -309,7 +330,7 @@ class KerasImplementation(FrameworkImplementation):
|
|
309
330
|
"""
|
310
331
|
return keras_op2d_add_const_collapsing()
|
311
332
|
|
312
|
-
def get_substitutions_post_statistics_collection(self,
|
333
|
+
def get_substitutions_post_statistics_collection(self,
|
313
334
|
quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
|
314
335
|
"""
|
315
336
|
Return a list of the framework substitutions used after we collect statistics.
|
@@ -0,0 +1,67 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import tensorflow as tf
|
17
|
+
from packaging import version
|
18
|
+
|
19
|
+
from model_compression_toolkit.core.keras.constants import KERNEL_SIZE
|
20
|
+
|
21
|
+
if version.parse(tf.__version__) >= version.parse("2.13"):
|
22
|
+
from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose
|
23
|
+
else:
|
24
|
+
from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose
|
25
|
+
|
26
|
+
from model_compression_toolkit.core import QuantizationConfig
|
27
|
+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
28
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
29
|
+
from model_compression_toolkit.core.common import Graph
|
30
|
+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
31
|
+
from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
|
32
|
+
compute_activation_bias_correction_of_graph
|
33
|
+
|
34
|
+
|
35
|
+
def activation_bias_correction_node_matchers():
    """
    Build the matcher of the Keras linear layers that an activation bias correction can be added to.

    Returns:
        A node matcher accepting Conv2D, Dense, DepthwiseConv2D and Conv2DTranspose nodes.
    """
    matcher = NodeOperationMatcher(Conv2D)
    for layer_type in (Dense, DepthwiseConv2D, Conv2DTranspose):
        matcher = matcher | NodeOperationMatcher(layer_type)
    return matcher
|
42
|
+
|
43
|
+
|
44
|
+
def keras_compute_activation_bias_correction_of_graph(graph: Graph,
                                                      quant_config: QuantizationConfig,
                                                      fw_info: FrameworkInfo,
                                                      fw_impl: FrameworkImplementation) -> Graph:
    """
    Compute the activation bias correction term for graph based on a Keras model.

    This is the framework-agnostic computation specialized with the Keras linear-layer matchers and the
    Keras name of the kernel-size attribute.

    Args:
        graph: Graph with nodes to compute the activation bias correction.
        quant_config: QuantizationConfig of how the model should be quantized.
        fw_info: Framework info, e.g. the kernel attribute names and kernel channel axes of each layer type.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.

    Returns:
        Graph with activation bias correction term for each relevant node.
    """
    return compute_activation_bias_correction_of_graph(
        graph=graph,
        quant_config=quant_config,
        fw_info=fw_info,
        fw_impl=fw_impl,
        activation_bias_correction_node_matchers=activation_bias_correction_node_matchers,
        kernel_size=KERNEL_SIZE)
|
@@ -92,6 +92,8 @@ from model_compression_toolkit.core.pytorch.pytorch_node_prior_info import creat
|
|
92
92
|
from model_compression_toolkit.core.pytorch.reader.reader import model_reader
|
93
93
|
from model_compression_toolkit.core.pytorch.statistics_correction.apply_second_moment_correction import \
|
94
94
|
pytorch_apply_second_moment_correction
|
95
|
+
from model_compression_toolkit.core.pytorch.statistics_correction.pytorch_compute_activation_bias_correction_of_graph import \
|
96
|
+
pytorch_compute_activation_bias_correction_of_graph
|
95
97
|
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model
|
96
98
|
from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_quantizers import \
|
97
99
|
get_inferable_quantizers
|
@@ -212,6 +214,25 @@ class PytorchImplementation(FrameworkImplementation):
|
|
212
214
|
core_config,
|
213
215
|
fw_info)
|
214
216
|
|
217
|
+
def compute_activation_bias_correction(self,
                                       graph: Graph,
                                       quant_config: QuantizationConfig,
                                       fw_info: FrameworkInfo) -> Graph:
    """
    Compute the activation bias correction terms of a graph based on a PyTorch model.

    Only the correction terms are computed here and stored in the nodes' final activation quantization
    configurations. Folding them into the layers' biases is a separate statistics-correction step.

    Args:
        graph: Graph to compute the activation bias correction terms for.
        quant_config: QuantizationConfig of how the model should be quantized.
        fw_info: FrameworkInfo object with information about the specific framework's model.

    Returns:
        Graph after computing the activation bias correction terms.
    """
    return pytorch_compute_activation_bias_correction_of_graph(graph=graph,
                                                               quant_config=quant_config,
                                                               fw_info=fw_info,
                                                               fw_impl=self)
|
215
236
|
|
216
237
|
def get_substitutions_channel_equalization(self,
|
217
238
|
quant_config: QuantizationConfig,
|
@@ -0,0 +1,57 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from torch.nn import Conv2d, Linear, ConvTranspose2d
|
17
|
+
|
18
|
+
from model_compression_toolkit.core import QuantizationConfig
|
19
|
+
from model_compression_toolkit.core.common import Graph
|
20
|
+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
21
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
22
|
+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
23
|
+
from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
|
24
|
+
compute_activation_bias_correction_of_graph
|
25
|
+
from model_compression_toolkit.core.pytorch.constants import KERNEL_SIZE
|
26
|
+
|
27
|
+
|
28
|
+
def activation_bias_correction_node_matchers():
    """
    Build the matcher of the PyTorch linear layers that an activation bias correction can be added to.

    Returns:
        A node matcher accepting Linear, Conv2d and ConvTranspose2d nodes.
    """
    matcher = NodeOperationMatcher(Linear)
    for layer_type in (Conv2d, ConvTranspose2d):
        matcher = matcher | NodeOperationMatcher(layer_type)
    return matcher
|
32
|
+
|
33
|
+
|
34
|
+
def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
                                                        quant_config: QuantizationConfig,
                                                        fw_info: FrameworkInfo,
                                                        fw_impl: FrameworkImplementation) -> Graph:
    """
    Compute the activation bias correction term for graph based on a PyTorch model.

    This is the framework-agnostic computation specialized with the PyTorch linear-layer matchers and the
    PyTorch name of the kernel-size attribute.

    Args:
        graph: Graph with nodes to compute the activation bias correction.
        quant_config: QuantizationConfig of how the model should be quantized.
        fw_info: Framework info, e.g. the kernel attribute names and kernel channel axes of each layer type.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.

    Returns:
        Graph with activation bias correction term for each relevant node.
    """
    return compute_activation_bias_correction_of_graph(
        graph=graph,
        quant_config=quant_config,
        fw_info=fw_info,
        fw_impl=fw_impl,
        activation_bias_correction_node_matchers=activation_bias_correction_node_matchers,
        kernel_size=KERNEL_SIZE)
|
@@ -164,6 +164,14 @@ def core_runner(in_model: Any,
|
|
164
164
|
tg,
|
165
165
|
bit_widths_config)
|
166
166
|
|
167
|
+
######################################
|
168
|
+
# Compute Activation Bias Correction
|
169
|
+
######################################
|
170
|
+
if core_config.quantization_config.activation_bias_correction:
|
171
|
+
tg = fw_impl.compute_activation_bias_correction(graph=tg,
|
172
|
+
quant_config=core_config.quantization_config,
|
173
|
+
fw_info=fw_info)
|
174
|
+
|
167
175
|
# Edit the graph again after finalizing the configurations.
|
168
176
|
# This is since some actions regard the final configuration and should be edited.
|
169
177
|
edit_network_graph(tg, fw_info, core_config.debug_config.network_editor)
|
tests_pytest/__init__.py
DELETED
@@ -1,14 +0,0 @@
|
|
1
|
-
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
tests_pytest/keras/__init__.py
DELETED
@@ -1,14 +0,0 @@
|
|
1
|
-
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
@@ -1,14 +0,0 @@
|
|
1
|
-
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
@@ -1,91 +0,0 @@
|
|
1
|
-
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
import numpy as np
|
16
|
-
import pytest
|
17
|
-
|
18
|
-
from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader, TFDatasetFromGenerator
|
19
|
-
|
20
|
-
|
21
|
-
@pytest.fixture(scope='session')
def fixed_dataset():
    """Session-wide dataset of 320 samples: images[i] is filled with i and labels[i] with i + 10."""
    sample_values = range(320)
    images = np.stack([np.full((3, 30, 20), value) for value in sample_values], axis=0)
    labels = np.stack([np.full((10,), value + 10) for value in sample_values], axis=0)
    return images, labels
|
27
|
-
|
28
|
-
|
29
|
-
@pytest.fixture
def fixed_gen(fixed_dataset):
    """Factory of a generator that yields the fixed dataset as 10 consecutive [images, labels] batches of 32."""
    images, labels = fixed_dataset
    batch_size = 32

    def f():
        for start in range(0, 10 * batch_size, batch_size):
            stop = start + batch_size
            yield [images[start:stop], labels[start:stop]]

    return f
|
36
|
-
|
37
|
-
|
38
|
-
def get_random_data_gen_fn(seed=42):
    """
    Return a reproducible generator factory whose successive invocations yield different samples.

    A single seeded RNG is shared by all invocations of the returned factory. Each invocation yields
    10 batches of [images (32, 3, 20, 30), labels (32, 10)] float32 arrays and advances the RNG, so the
    next invocation (epoch) produces new samples.
    """
    rng = np.random.default_rng(seed)

    def f():
        for _ in range(10):
            images = rng.random((32, 3, 20, 30)).astype(np.float32)
            labels = rng.random((32, 10)).astype(np.float32)
            yield [images, labels]
    return f
|
46
|
-
|
47
|
-
|
48
|
-
class TestTFDataUtil:
    create_dataloader_fn = data_gen_to_dataloader

    def test_iterable_dataset_from_fixed_gen(self, fixed_gen):
        """ Dataset built from the fixed generator yields the same samples, in the same order, every epoch. """
        dataset = TFDatasetFromGenerator(fixed_gen, batch_size=1)
        self._validate_ds_from_fixed_gen(dataset, 320)

    def test_iterable_dataset_from_random_gen(self):
        """ Each epoch of the dataset matches the corresponding pass of the source data generator. """
        dataset = TFDatasetFromGenerator(get_random_data_gen_fn(), batch_size=1)
        epoch1 = np.concatenate([batch[0] for batch in dataset], axis=0)
        epoch2 = np.concatenate([batch[0] for batch in dataset], axis=0)

        reference_fn = get_random_data_gen_fn()
        # The dataset consumes one generator invocation up front (validation and batch size), so advance
        # the reference generator the same way before comparing.
        next(reference_fn())
        reference_epoch1 = np.concatenate([batch[0] for batch in reference_fn()], axis=0)
        reference_epoch2 = np.concatenate([batch[0] for batch in reference_fn()], axis=0)

        # Each dataset epoch must equal the matching reference pass, and the two epochs must differ.
        assert np.array_equal(epoch1, reference_epoch1)
        assert np.array_equal(epoch2, reference_epoch2)
        assert not np.allclose(epoch1, epoch2)

    def test_dataloader(self, fixed_gen):
        dataset = TFDatasetFromGenerator(fixed_gen, batch_size=25)
        first_batch = next(iter(dataset))
        images, labels = first_batch[0], first_batch[1]
        assert images.shape[0] == labels.shape[0] == 25
        assert np.array_equal(images[0], np.full((3, 30, 20), 0))
        assert np.array_equal(labels[0], np.full((10,), 10))
        assert np.array_equal(images[-1], np.full((3, 30, 20), 24))
        assert np.array_equal(labels[-1], np.full((10,), 34))
        assert len(dataset) == 13
        assert dataset.orig_batch_size == 32

    def _validate_ds_from_fixed_gen(self, ds, exp_len):
        # Two epochs must each yield the samples 0..exp_len-1 in order.
        for _epoch in range(2):
            for index, sample in enumerate(ds):
                assert np.array_equal(sample[0].cpu().numpy(), np.full((1, 3, 30, 20), index))
                assert np.array_equal(sample[1].cpu().numpy(), np.full((1, 10,), index + 10))
            assert index == exp_len - 1
            assert ds.orig_batch_size == 32
            assert len(ds) == exp_len
|
@@ -1,14 +0,0 @@
|
|
1
|
-
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|