mct-nightly 2.2.0.20241106.458__py3-none-any.whl → 2.2.0.20241108.459__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/RECORD +17 -29
  3. {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/top_level.txt +0 -1
  4. model_compression_toolkit/__init__.py +1 -1
  5. model_compression_toolkit/core/common/framework_implementation.py +46 -27
  6. model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -0
  7. model_compression_toolkit/core/common/quantization/quantization_config.py +2 -0
  8. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +81 -0
  9. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +190 -0
  10. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +14 -2
  11. model_compression_toolkit/core/keras/keras_implementation.py +23 -2
  12. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +67 -0
  13. model_compression_toolkit/core/pytorch/pytorch_implementation.py +21 -0
  14. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +57 -0
  15. model_compression_toolkit/core/runner.py +8 -0
  16. tests_pytest/__init__.py +0 -14
  17. tests_pytest/keras/__init__.py +0 -14
  18. tests_pytest/keras/core/__init__.py +0 -14
  19. tests_pytest/keras/core/test_data_util.py +0 -91
  20. tests_pytest/keras/gptq/__init__.py +0 -14
  21. tests_pytest/keras/gptq/test_gradual_act_quantization.py +0 -102
  22. tests_pytest/keras/trainable_infrastructure/__init__.py +0 -16
  23. tests_pytest/keras/trainable_infrastructure/test_linear_annealing.py +0 -49
  24. tests_pytest/pytorch/__init__.py +0 -14
  25. tests_pytest/pytorch/core/__init__.py +0 -14
  26. tests_pytest/pytorch/core/test_data_util.py +0 -125
  27. tests_pytest/pytorch/gptq/__init__.py +0 -14
  28. tests_pytest/pytorch/gptq/test_annealing_cfg.py +0 -40
  29. tests_pytest/pytorch/gptq/test_gradual_act_quantization.py +0 -100
  30. tests_pytest/pytorch/trainable_infrastructure/__init__.py +0 -14
  31. tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py +0 -49
  32. {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/LICENSE.md +0 -0
  33. {mct_nightly-2.2.0.20241106.458.dist-info → mct_nightly-2.2.0.20241108.459.dist-info}/WHEEL +0 -0
model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py ADDED
@@ -0,0 +1,190 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import numpy as np
+ from typing import Any, Callable
+
+ from model_compression_toolkit.core import QuantizationConfig
+ from model_compression_toolkit.core.common import BaseNode, Graph
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+
+
+ def get_previous_node_with_activation_quantization(linear_node: BaseNode,
+                                                    graph: Graph) -> Any:
+     """
+     Search recursively for the previous node with activation quantization.
+
+     Args:
+         linear_node: Node to search for its previous node.
+         graph: Graph the node is in.
+
+     Returns:
+         The previous node if found, or None if no such node exists or one of the nodes visited
+         during the search has multiple incoming edges (in which case the substitution cannot be applied).
+     """
+
+     prev_nodes = graph.get_prev_nodes(linear_node)
+
+     if len(prev_nodes) != 1:
+         return None  # pragma: no cover
+
+     prev_node = prev_nodes[0]
+
+     activation_quantization_config = prev_node.final_activation_quantization_cfg
+
+     # Search for a node with activation quantization
+     if (activation_quantization_config.enable_activation_quantization and
+             not activation_quantization_config.quantization_preserving):
+         return prev_node
+     else:
+         return get_previous_node_with_activation_quantization(prev_node, graph)
+
+
+ def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray:
+     """
+     Calculate the centers of bins given their edges.
+
+     Args:
+         bin_edges: Array of bin edges.
+
+     Returns:
+         np.ndarray: Array of bin centers.
+     """
+     # Calculate the centers by averaging consecutive bin edges
+     bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
+     return bin_centers
+
+
+ def compute_activation_bias_correction(graph: Graph,
+                                        quant_config: QuantizationConfig,
+                                        fw_info: FrameworkInfo,
+                                        fw_impl: FrameworkImplementation,
+                                        linear_node: BaseNode,
+                                        prev_node: BaseNode,
+                                        kernel_size: str) -> Graph:
+     """
+     Compute the activation bias correction term and store it in the final activation
+     quantization configuration.
+
+     Args:
+         graph: Graph with nodes to compute the activation bias correction for each node's final activation quantization configuration.
+         quant_config: QuantizationConfig with the quantization settings for the model.
+         fw_info: Framework info, such as lists of nodes whose kernels should be quantized.
+         fw_impl: FrameworkImplementation object with framework-specific method implementations.
+         linear_node: Node to compute the activation bias correction for.
+         prev_node: Node whose activation quantization causes the activation error.
+         kernel_size: The framework-specific attribute name of the convolution layer's kernel size.
+
+     Returns:
+         Graph with an activation bias correction term for each node.
+     """
+
+     # Retrieve the 'kernel_size' value if it exists and ensure it is None, 1, or (1, 1). This feature supports only
+     # Dense/Linear layers and convolution layers with a kernel size of 1 or (1, 1).
+     # For Dense/Linear layers, which lack a 'kernel_size' attribute, the result will be None, and no restriction
+     # applies in that case.
+     if linear_node.framework_attr.get(kernel_size) not in [None, 1, (1, 1)]:
+         # If the kernel size is not 1 or (1, 1), return the current graph unmodified
+         return graph
+
+     prev_node_act_quant_cfg = prev_node.final_activation_quantization_cfg
+
+     # Check that the previous node has an activation quantization configuration and a histogram collector.
+     if prev_node_act_quant_cfg is None or not hasattr(graph.get_out_stats_collector(prev_node), 'hc'):
+         return graph  # pragma: no cover
+
+     float_bins, float_count = graph.get_out_stats_collector(prev_node).hc.get_histogram()
+
+     # Calculate the centers of the float bins
+     float_centers = calculate_bin_centers(float_bins)
+
+     # Quantize the bin edges and calculate the centers of the quantized bins
+     quant_bins = prev_node_act_quant_cfg.quantize_node_output(fw_impl.to_tensor(float_bins))
+     quant_bins = fw_impl.to_numpy(quant_bins)
+     quant_centers = calculate_bin_centers(quant_bins)
+
+     # Calculate the means of both the float and the quantized bin centers, weighted by the bin counts
+     mean_float_centers = np.sum(float_centers * float_count) / np.sum(float_count)
+     mean_quant_centers = np.sum(quant_centers * float_count) / np.sum(float_count)
+
+     # Compute the difference between the mean quantized center and the mean float center
+     mean_diff = mean_quant_centers - mean_float_centers
+
+     # Calculate the normalized bias as a percentage of the float center norm
+     float_centers_norm1 = np.abs(mean_float_centers)
+     normalized_bias = 100 * np.abs(mean_diff) / float_centers_norm1
+
+     # If the normalized bias is below the activation bias correction threshold, return the graph unmodified.
+     # By default, the threshold is set to 0, allowing all nodes to proceed in this case.
+     if normalized_bias < quant_config.activation_bias_correction_threshold:
+         return graph
+
+     kernel = linear_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(linear_node.type)[0])
+
+     # Compute the activation bias correction by applying the quantization error to the kernel, resulting in an output
+     # size matching the number of output channels.
+     if kernel is not None:
+
+         # Get the axes that are not the output channel.
+         output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type)
+         axis_not_output_channel = list(range(len(kernel.shape)))
+         axis_not_output_channel.remove(output_channel_index)
+
+         # Special case of depthwise_conv2d in TensorFlow, where we have a depth multiplier for the filters.
+         if output_channel_index == input_channel_index:
+             axis_not_output_channel.remove(3)  # 3 is the depth multiplier index.
+
+         activation_bias_correction_term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
+         linear_node.final_activation_quantization_cfg.activation_bias_correction_term = (
+             activation_bias_correction_term.flatten())
+     return graph
+
+
+ def compute_activation_bias_correction_of_graph(graph: Graph,
+                                                 quant_config: QuantizationConfig,
+                                                 fw_info: FrameworkInfo,
+                                                 fw_impl: FrameworkImplementation,
+                                                 activation_bias_correction_node_matchers: Callable,
+                                                 kernel_size: str) -> Graph:
+     """
+     Compute the activation bias correction term for the graph.
+
+     Args:
+         graph: Graph with nodes to compute the activation bias correction.
+         quant_config: QuantizationConfig with the quantization settings for the model.
+         fw_info: Framework info, such as lists of nodes whose kernels should be quantized.
+         fw_impl: FrameworkImplementation object with framework-specific method implementations.
+         activation_bias_correction_node_matchers: Function to match the layers for activation bias correction.
+         kernel_size: The framework-specific attribute name of the convolution layer's kernel size.
+
+     Returns:
+         Graph with an activation bias correction term for each relevant node.
+     """
+     linear_node_types = activation_bias_correction_node_matchers()
+
+     for n in graph.nodes:
+         if linear_node_types.apply(n):
+             prev_node = get_previous_node_with_activation_quantization(n, graph)
+             if prev_node is not None:
+                 graph = compute_activation_bias_correction(graph=graph,
+                                                            quant_config=quant_config,
+                                                            fw_info=fw_info,
+                                                            fw_impl=fw_impl,
+                                                            linear_node=n,
+                                                            prev_node=prev_node,
+                                                            kernel_size=kernel_size)
+     return graph
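
The arithmetic above is easy to check in isolation. Below is a minimal standalone sketch of the same computation in plain NumPy: the histogram, the uniform `quantize` stand-in for `quantize_node_output`, and the toy kernel shape are all hypothetical, but the bin-center, weighted-mean, and per-channel correction steps mirror the code added in this file.

    import numpy as np

    def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray:
        # Average consecutive edges to get the center of each bin.
        return (bin_edges[:-1] + bin_edges[1:]) / 2.0

    # Hypothetical stand-in for prev_node_act_quant_cfg.quantize_node_output:
    # an 8-bit uniform quantizer over [0, 4).
    def quantize(x: np.ndarray, n_bits: int = 8, range_max: float = 4.0) -> np.ndarray:
        scale = range_max / (2 ** n_bits)
        return np.clip(np.round(x / scale) * scale, 0, range_max - scale)

    # A toy activation histogram (edges and per-bin counts).
    float_bins = np.linspace(0, 4, 17)
    float_count = np.random.default_rng(0).integers(1, 100, size=16)

    float_centers = calculate_bin_centers(float_bins)
    quant_centers = calculate_bin_centers(quantize(float_bins))

    # Count-weighted means of the float and quantized centers.
    mean_float = np.sum(float_centers * float_count) / np.sum(float_count)
    mean_quant = np.sum(quant_centers * float_count) / np.sum(float_count)
    mean_diff = mean_quant - mean_float

    # Normalized bias (in percent) that is compared against the threshold.
    normalized_bias = 100 * np.abs(mean_diff) / np.abs(mean_float)

    # For a kernel W summed over all non-output-channel axes, the stored
    # correction term is mean_diff * sum(W): one value per output channel.
    kernel = np.ones((1, 1, 3, 8))  # e.g. a 1x1 conv with 8 output channels
    correction = mean_diff * kernel.sum(axis=(0, 1, 2))
    print(normalized_bias, correction.shape)  # -> scalar percentage, (8,)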
model_compression_toolkit/core/common/statistics_correction/statistics_correction.py CHANGED
@@ -18,6 +18,8 @@ from model_compression_toolkit.core.common import FrameworkInfo
  from model_compression_toolkit.core.common import Graph
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
+ from model_compression_toolkit.core.common.statistics_correction.apply_activation_bias_correction_to_graph import \
+     apply_activation_bias_correction_to_graph
  from model_compression_toolkit.core.common.statistics_correction.apply_bias_correction_to_graph import \
      apply_bias_correction_to_graph
  from model_compression_toolkit.core.common.statistics_correction.apply_second_moment_correction_to_graph import \
@@ -73,7 +75,7 @@ def apply_statistics_correction(transformed_graph: Graph,
                                  fw_impl: FrameworkImplementation,
                                  tb_w: TensorboardWriter = None, ) -> Graph:
      """
-     Apply statistics moment correction on graph.
+     Apply statistics correction on graph.
      Args:
          transformed_graph: Graph to apply statistics correction.
          representative_data_gen (Callable): Dataset used for calibration.
@@ -84,7 +86,7 @@ def apply_statistics_correction(transformed_graph: Graph,
          tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.

      Returns:
-         Graph after statistics correction correction.
+         Graph after statistics correction.
      """

      #############################################
@@ -104,4 +106,14 @@ def apply_statistics_correction(transformed_graph: Graph,
      if tb_w is not None:
          tb_w.add_graph(transformed_graph, 'after_statistics_correction')

+     #############################################
+     # Apply Activation Bias Correction
+     #############################################
+     if core_config.quantization_config.activation_bias_correction:
+         transformed_graph = apply_activation_bias_correction_to_graph(graph=transformed_graph,
+                                                                       core_config=core_config,
+                                                                       fw_impl=fw_impl)
+         if tb_w is not None:
+             tb_w.add_graph(transformed_graph, 'after_activation_bias_correction')
+
      return transformed_graph
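
Note the two-phase design: the correction terms are computed during `core_runner` (see the runner.py hunk further down) via `fw_impl.compute_activation_bias_correction`, while `apply_activation_bias_correction_to_graph` folds them into the graph here, after the other statistics corrections. The whole feature is gated on the new `activation_bias_correction` flag. A minimal sketch of opting in, assuming the two attributes added to QuantizationConfig in this release are exposed as constructor arguments of the same name:

    import model_compression_toolkit as mct

    # Hypothetical: the attribute names come from this diff; exposing them as
    # constructor arguments with these names is an assumption.
    quant_config = mct.core.QuantizationConfig(activation_bias_correction=True,
                                               activation_bias_correction_threshold=0.0)
    core_config = mct.core.CoreConfig(quantization_config=quant_config)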
model_compression_toolkit/core/keras/keras_implementation.py CHANGED
@@ -28,6 +28,8 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.remo
  from model_compression_toolkit.core.keras.hessian.activation_hessian_scores_calculator_keras import \
      ActivationHessianScoresCalculatorKeras
  from model_compression_toolkit.core.keras.hessian.weights_hessian_scores_calculator_keras import WeightsHessianScoresCalculatorKeras
+ from model_compression_toolkit.core.keras.statistics_correction.keras_compute_activation_bias_correction_of_graph import \
+     keras_compute_activation_bias_correction_of_graph
  from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_quantizers import \
      get_inferable_quantizers
  from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
@@ -84,7 +86,7 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.line
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \
      keras_residual_collapsing
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \
-     InputScalingWithPad
+     InputScalingWithPad
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.concat_threshold_update import ConcatThresholdUpdate
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
      ReLUBoundToPowerOfTwo
@@ -218,6 +220,25 @@ class KerasImplementation(FrameworkImplementation):
                                                core_config,
                                                fw_info)

+     def compute_activation_bias_correction(self,
+                                            graph: Graph,
+                                            quant_config: QuantizationConfig,
+                                            fw_info: FrameworkInfo):
+         """
+         Compute activation bias correction on a graph.
+
+         Args:
+             graph: Graph to apply activation bias correction on.
+             quant_config: QuantizationConfig with the quantization settings for the model.
+             fw_info: FrameworkInfo object with information about the specific framework's model.
+
+         Returns:
+             Graph after computing the activation bias correction.
+         """
+         return keras_compute_activation_bias_correction_of_graph(graph=graph,
+                                                                  quant_config=quant_config,
+                                                                  fw_info=fw_info,
+                                                                  fw_impl=self)

      def get_substitutions_channel_equalization(self,
                                                 quant_config: QuantizationConfig,
@@ -309,7 +330,7 @@ class KerasImplementation(FrameworkImplementation):
          """
          return keras_op2d_add_const_collapsing()

-     def get_substitutions_post_statistics_collection(self,
+     def get_substitutions_post_statistics_collection(self,
                                                       quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
          """
          Return a list of the framework substitutions used after we collect statistics.
model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py ADDED
@@ -0,0 +1,67 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import tensorflow as tf
+ from packaging import version
+
+ from model_compression_toolkit.core.keras.constants import KERNEL_SIZE
+
+ if version.parse(tf.__version__) >= version.parse("2.13"):
+     from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose
+ else:
+     from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose
+
+ from model_compression_toolkit.core import QuantizationConfig
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+ from model_compression_toolkit.core.common import Graph
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+ from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
+     compute_activation_bias_correction_of_graph
+
+
+ def activation_bias_correction_node_matchers():
+     # Match linear layers to which a correction can be added.
+     linear_node = NodeOperationMatcher(Conv2D) | \
+                   NodeOperationMatcher(Dense) | \
+                   NodeOperationMatcher(DepthwiseConv2D) | \
+                   NodeOperationMatcher(Conv2DTranspose)
+     return linear_node
+
+
+ def keras_compute_activation_bias_correction_of_graph(graph: Graph,
+                                                       quant_config: QuantizationConfig,
+                                                       fw_info: FrameworkInfo,
+                                                       fw_impl: FrameworkImplementation) -> Graph:
+     """
+     Compute the activation bias correction term for a graph based on a Keras model.
+
+     Args:
+         graph: Graph with nodes to compute the activation bias correction.
+         quant_config: QuantizationConfig with the quantization settings for the model.
+         fw_info: Framework info, such as lists of nodes whose kernels should be quantized.
+         fw_impl: FrameworkImplementation object with framework-specific method implementations.
+
+     Returns:
+         Graph with an activation bias correction term for each relevant node.
+     """
+     graph = compute_activation_bias_correction_of_graph(graph=graph,
+                                                         quant_config=quant_config,
+                                                         fw_info=fw_info,
+                                                         fw_impl=fw_impl,
+                                                         activation_bias_correction_node_matchers=
+                                                         activation_bias_correction_node_matchers,
+                                                         kernel_size=KERNEL_SIZE)
+     return graph
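
For context, a sketch of exercising the Keras path end to end. The toy model uses layers the matcher above covers (a 1x1 Conv2D and a Dense); the model, the data generator, and the QuantizationConfig constructor arguments are assumptions, while `mct.ptq.keras_post_training_quantization` is the existing MCT entry point.

    import numpy as np
    import tensorflow as tf
    import model_compression_toolkit as mct

    # Placeholder model with layers the matcher above covers.
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(8, kernel_size=1, input_shape=(16, 16, 3)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ])

    def representative_data_gen():
        for _ in range(4):
            yield [np.random.rand(1, 16, 16, 3).astype(np.float32)]

    # Assumed constructor argument; the attribute is introduced in this diff.
    core_config = mct.core.CoreConfig(
        quantization_config=mct.core.QuantizationConfig(activation_bias_correction=True))

    quantized_model, _ = mct.ptq.keras_post_training_quantization(
        model, representative_data_gen, core_config=core_config)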
model_compression_toolkit/core/pytorch/pytorch_implementation.py CHANGED
@@ -92,6 +92,8 @@ from model_compression_toolkit.core.pytorch.pytorch_node_prior_info import creat
  from model_compression_toolkit.core.pytorch.reader.reader import model_reader
  from model_compression_toolkit.core.pytorch.statistics_correction.apply_second_moment_correction import \
      pytorch_apply_second_moment_correction
+ from model_compression_toolkit.core.pytorch.statistics_correction.pytorch_compute_activation_bias_correction_of_graph import \
+     pytorch_compute_activation_bias_correction_of_graph
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model
  from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_quantizers import \
      get_inferable_quantizers
@@ -212,6 +214,25 @@ class PytorchImplementation(FrameworkImplementation):
                                                 core_config,
                                                 fw_info)

+     def compute_activation_bias_correction(self,
+                                            graph: Graph,
+                                            quant_config: QuantizationConfig,
+                                            fw_info: FrameworkInfo):
+         """
+         Compute activation bias correction on a graph.
+
+         Args:
+             graph: Graph to apply activation bias correction on.
+             quant_config: QuantizationConfig with the quantization settings for the model.
+             fw_info: FrameworkInfo object with information about the specific framework's model.
+
+         Returns:
+             Graph after computing the activation bias correction.
+         """
+         return pytorch_compute_activation_bias_correction_of_graph(graph=graph,
+                                                                    quant_config=quant_config,
+                                                                    fw_info=fw_info,
+                                                                    fw_impl=self)

      def get_substitutions_channel_equalization(self,
                                                 quant_config: QuantizationConfig,
model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py ADDED
@@ -0,0 +1,57 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from torch.nn import Conv2d, Linear, ConvTranspose2d
+
+ from model_compression_toolkit.core import QuantizationConfig
+ from model_compression_toolkit.core.common import Graph
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+ from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
+     compute_activation_bias_correction_of_graph
+ from model_compression_toolkit.core.pytorch.constants import KERNEL_SIZE
+
+
+ def activation_bias_correction_node_matchers():
+     # Match linear layers to which a correction can be added.
+     linear_node = NodeOperationMatcher(Linear) | NodeOperationMatcher(Conv2d) | NodeOperationMatcher(ConvTranspose2d)
+     return linear_node
+
+
+ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
+                                                         quant_config: QuantizationConfig,
+                                                         fw_info: FrameworkInfo,
+                                                         fw_impl: FrameworkImplementation) -> Graph:
+     """
+     Compute the activation bias correction term for a graph based on a PyTorch model.
+
+     Args:
+         graph: Graph with nodes to compute the activation bias correction.
+         quant_config: QuantizationConfig with the quantization settings for the model.
+         fw_info: Framework info, such as lists of nodes whose kernels should be quantized.
+         fw_impl: FrameworkImplementation object with framework-specific method implementations.
+
+     Returns:
+         Graph with an activation bias correction term for each relevant node.
+     """
+     graph = compute_activation_bias_correction_of_graph(graph=graph,
+                                                         quant_config=quant_config,
+                                                         fw_info=fw_info,
+                                                         fw_impl=fw_impl,
+                                                         activation_bias_correction_node_matchers=
+                                                         activation_bias_correction_node_matchers,
+                                                         kernel_size=KERNEL_SIZE)
+     return graph
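
The PyTorch path mirrors the Keras one. A corresponding sketch under the same assumptions about the QuantizationConfig arguments; `mct.ptq.pytorch_post_training_quantization` is the existing entry point, and the model and data generator are placeholders.

    import numpy as np
    import torch
    import model_compression_toolkit as mct

    # Placeholder model with layers the PyTorch matcher covers.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=1),
        torch.nn.Flatten(),
        torch.nn.Linear(8 * 16 * 16, 10),
    )

    def representative_data_gen():
        for _ in range(4):
            yield [np.random.rand(1, 3, 16, 16).astype(np.float32)]

    core_config = mct.core.CoreConfig(
        quantization_config=mct.core.QuantizationConfig(activation_bias_correction=True))

    quantized_model, _ = mct.ptq.pytorch_post_training_quantization(
        model, representative_data_gen, core_config=core_config)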
model_compression_toolkit/core/runner.py CHANGED
@@ -164,6 +164,14 @@ def core_runner(in_model: Any,
                                       tg,
                                       bit_widths_config)

+     ######################################
+     # Compute Activation Bias Correction
+     ######################################
+     if core_config.quantization_config.activation_bias_correction:
+         tg = fw_impl.compute_activation_bias_correction(graph=tg,
+                                                         quant_config=core_config.quantization_config,
+                                                         fw_info=fw_info)
+
      # Edit the graph again after finalizing the configurations.
      # This is since some actions regard the final configuration and should be edited.
      edit_network_graph(tg, fw_info, core_config.debug_config.network_editor)
tests_pytest/__init__.py DELETED
@@ -1,14 +0,0 @@
- # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
tests_pytest/keras/__init__.py DELETED
@@ -1,14 +0,0 @@
- # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
tests_pytest/keras/core/__init__.py DELETED
@@ -1,14 +0,0 @@
- # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
tests_pytest/keras/core/test_data_util.py DELETED
@@ -1,91 +0,0 @@
- # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- import numpy as np
- import pytest
-
- from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader, TFDatasetFromGenerator
-
-
- @pytest.fixture(scope='session')
- def fixed_dataset():
-     # generate 320 images with data1[i] = i and data2[i] = i+10
-     data1 = np.stack([np.full((3, 30, 20), v) for v in range(320)], axis=0)
-     data2 = np.stack([np.full((10,), v + 10) for v in range(320)], axis=0)
-     return data1, data2
-
-
- @pytest.fixture
- def fixed_gen(fixed_dataset):
-     def f():
-         for i in range(10):
-             yield [fixed_dataset[0][32 * i: 32 * (i + 1)], fixed_dataset[1][32 * i: 32 * (i + 1)]]
-
-     return f
-
-
- def get_random_data_gen_fn(seed=42):
-     """ get gen factory for reproducible gen yielding different samples in each epoch """
-     rng = np.random.default_rng(seed)
-
-     def f():
-         for i in range(10):
-             yield [rng.random((32, 3, 20, 30)).astype(np.float32), rng.random((32, 10)).astype(np.float32)]
-     return f
-
-
- class TestTFDataUtil:
-     create_dataloader_fn = data_gen_to_dataloader
-
-     def test_iterable_dataset_from_fixed_gen(self, fixed_gen):
-         """ tests iterable dataset from fixed gen - same samples are generated in each epoch in the same order """
-         ds = TFDatasetFromGenerator(fixed_gen, batch_size=1)
-         self._validate_ds_from_fixed_gen(ds, 320)
-
-     def test_iterable_dataset_from_random_gen(self):
-         """ test that dataset samples over epochs are identical to the original data generator """
-         ds = TFDatasetFromGenerator(get_random_data_gen_fn(), batch_size=1)
-         pass1 = np.concatenate([t[0] for t in ds], axis=0)
-         pass2 = np.concatenate([t[0] for t in ds], axis=0)
-
-         gen_fn = get_random_data_gen_fn()
-         # one invocation is used for validation and batch size in dataset, so promote the reference gen for comparison
-         next(gen_fn())
-         gen_pass1 = np.concatenate([t[0] for t in gen_fn()], axis=0)
-         gen_pass2 = np.concatenate([t[0] for t in gen_fn()], axis=0)
-         # check that each pass is identical to corresponding pass in the original gen
-         assert np.array_equal(pass1, gen_pass1)
-         assert np.array_equal(pass2, gen_pass2)
-         assert not np.allclose(pass1, pass2)
-
-     def test_dataloader(self, fixed_gen):
-         ds = TFDatasetFromGenerator(fixed_gen, batch_size=25)
-         ds_iter = iter(ds)
-         batch1 = next(ds_iter)
-         assert batch1[0].shape[0] == batch1[1].shape[0] == 25
-         assert np.array_equal(batch1[0][0], np.full((3, 30, 20), 0))
-         assert np.array_equal(batch1[1][0], np.full((10,), 10))
-         assert np.array_equal(batch1[0][-1], np.full((3, 30, 20), 24))
-         assert np.array_equal(batch1[1][-1], np.full((10,), 34))
-         assert len(ds) == 13
-         assert ds.orig_batch_size == 32
-
-     def _validate_ds_from_fixed_gen(self, ds, exp_len):
-         for _ in range(2):
-             for i, sample in enumerate(ds):
-                 assert np.array_equal(sample[0].cpu().numpy(), np.full((1, 3, 30, 20), i))
-                 assert np.array_equal(sample[1].cpu().numpy(), np.full((1, 10,), i + 10))
-             assert i == exp_len - 1
-         assert ds.orig_batch_size == 32
-         assert len(ds) == exp_len
tests_pytest/keras/gptq/__init__.py DELETED
@@ -1,14 +0,0 @@
- # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================