mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250927.534__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.4.
|
30
|
+
__version__ = "2.4.2.20250927.000534"
|
@@ -32,7 +32,8 @@ def analyzer_model_quantization(representative_data_gen: Callable,
|
|
32
32
|
tb_w: TensorboardWriter,
|
33
33
|
float_graph: Graph,
|
34
34
|
quantized_graph: Graph,
|
35
|
-
fw_impl: FrameworkImplementation
|
35
|
+
fw_impl: FrameworkImplementation,
|
36
|
+
fw_info: FrameworkInfo):
|
36
37
|
"""
|
37
38
|
Plot the cosine similarity of different points on the graph between the float and quantized
|
38
39
|
graphs. Add them to the passed TensorboardWriter object and close all tensorboard writer open
|
@@ -44,12 +45,14 @@ def analyzer_model_quantization(representative_data_gen: Callable,
|
|
44
45
|
float_graph: Graph of float model.
|
45
46
|
quantized_graph: Graph of quantized model.
|
46
47
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
48
|
+
fw_info: Information needed for quantization about the specific framework.
|
47
49
|
|
48
50
|
"""
|
49
51
|
if tb_w is not None:
|
50
52
|
visual = NNVisualizer(float_graph,
|
51
53
|
quantized_graph,
|
52
|
-
fw_impl=fw_impl
|
54
|
+
fw_impl=fw_impl,
|
55
|
+
fw_info=fw_info)
|
53
56
|
if not visual.has_compare_points():
|
54
57
|
Logger.error(f'No comparing points were found to plot analyze similarity.')
|
55
58
|
else:
|
@@ -15,6 +15,7 @@
|
|
15
15
|
from abc import ABC, abstractmethod
|
16
16
|
from typing import Any, Tuple
|
17
17
|
|
18
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
18
19
|
from model_compression_toolkit.core import common
|
19
20
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
20
21
|
|
@@ -27,17 +28,20 @@ class BaseModelBuilder(ABC):
|
|
27
28
|
def __init__(self,
|
28
29
|
graph: common.Graph,
|
29
30
|
append2output=None,
|
31
|
+
fw_info: FrameworkInfo = None,
|
30
32
|
return_float_outputs: bool = False):
|
31
33
|
"""
|
32
34
|
|
33
35
|
Args:
|
34
36
|
graph: Graph to build the model from.
|
35
37
|
append2output: Nodes of graph to append to model's output.
|
38
|
+
fw_info: Information about the specific framework of the model that is built.
|
36
39
|
return_float_outputs: Whether the model returns float tensors or not.
|
37
40
|
"""
|
38
41
|
|
39
42
|
self.graph = graph
|
40
43
|
self.append2output = append2output
|
44
|
+
self.fw_info = fw_info
|
41
45
|
self.return_float_outputs = return_float_outputs
|
42
46
|
|
43
47
|
@abstractmethod
|
@@ -13,12 +13,11 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from abc import ABC, abstractmethod
|
17
16
|
import numpy as np
|
18
17
|
from model_compression_toolkit.logger import Logger
|
19
18
|
|
20
19
|
|
21
|
-
class BaseCollector(
|
20
|
+
class BaseCollector(object):
|
22
21
|
"""
|
23
22
|
Base class for statistics collection object.
|
24
23
|
"""
|
@@ -27,7 +26,6 @@ class BaseCollector(ABC):
|
|
27
26
|
# When manipulation statistics in a granularity they were not collected by, the data is invalid.
|
28
27
|
self.is_legal = True
|
29
28
|
|
30
|
-
@abstractmethod
|
31
29
|
def scale(self, scale_factor: np.ndarray):
|
32
30
|
"""
|
33
31
|
Scale all statistics in collector by some factor.
|
@@ -39,7 +37,6 @@ class BaseCollector(ABC):
|
|
39
37
|
raise NotImplemented(
|
40
38
|
f'{self.__class__.__name__} needs to implement scale operation for its state.') # pragma: no cover
|
41
39
|
|
42
|
-
@abstractmethod
|
43
40
|
def shift(self, shift_value: np.ndarray):
|
44
41
|
"""
|
45
42
|
Shift all statistics in collector by some value.
|
@@ -87,13 +87,10 @@ class MeanCollector(BaseCollector):
|
|
87
87
|
x: Tensor that goes through the mean collector and needs to be considered in the mean computation.
|
88
88
|
"""
|
89
89
|
self.i += 1 # Update the iteration index
|
90
|
-
if self.axis
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
n = x.shape[axis]
|
95
|
-
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
|
96
|
-
mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch
|
90
|
+
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
|
91
|
+
n = x.shape[axis]
|
92
|
+
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
|
93
|
+
mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch
|
97
94
|
self.current_sum += mu # sum of all batches
|
98
95
|
self.current_mean = self.current_sum / self.i # mean of all batches
|
99
96
|
|
@@ -130,13 +130,10 @@ class MinMaxPerChannelCollector(BaseCollector):
|
|
130
130
|
x: Tensor that goes through the collector and needs to be considered in the min/max computation.
|
131
131
|
"""
|
132
132
|
|
133
|
-
if self.axis
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
n = x.shape[axis]
|
138
|
-
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
|
139
|
-
x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
|
133
|
+
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
|
134
|
+
n = x.shape[axis]
|
135
|
+
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
|
136
|
+
x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
|
140
137
|
if self.state is None:
|
141
138
|
x_max = np.max(x_reshape, axis=-1)
|
142
139
|
x_min = np.min(x_reshape, axis=-1)
|
@@ -125,16 +125,18 @@ class FrameworkImplementation(ABC):
|
|
125
125
|
graph: Graph,
|
126
126
|
mode: ModelBuilderMode,
|
127
127
|
append2output: List[Any],
|
128
|
+
fw_info: FrameworkInfo,
|
128
129
|
return_float_outputs: bool = False) -> Tuple:
|
129
130
|
"""
|
130
131
|
Build a framework model from a graph.
|
131
|
-
The mode determines how the model should be
|
132
|
+
The mode determines how the model should be build. append2output is a list of Nodes
|
132
133
|
to set as the model outputs.
|
133
134
|
|
134
135
|
Args:
|
135
136
|
graph: Graph to build the model from it.
|
136
137
|
mode: Mode for how to build the model.
|
137
138
|
append2output: List of Nodes to set as the model's outputs.
|
139
|
+
fw_info: FrameworkInfo object with information about the specific framework's model
|
138
140
|
return_float_outputs (bool): whether to return outputs before or after quantization nodes (default)
|
139
141
|
|
140
142
|
Returns:
|
@@ -168,13 +170,15 @@ class FrameworkImplementation(ABC):
|
|
168
170
|
@abstractmethod
|
169
171
|
def shift_negative_correction(self,
|
170
172
|
graph: Graph,
|
171
|
-
core_config: CoreConfig
|
173
|
+
core_config: CoreConfig,
|
174
|
+
fw_info: FrameworkInfo) -> Graph:
|
172
175
|
"""
|
173
176
|
Apply shift negative correction (SNC) on a graph.
|
174
177
|
|
175
178
|
Args:
|
176
179
|
graph: Graph to apply SNC on.
|
177
180
|
core_config: Quantization configuration.
|
181
|
+
fw_info: FrameworkInfo object with information about the specific framework's model.
|
178
182
|
|
179
183
|
Returns:
|
180
184
|
Graph after SNC.
|
@@ -185,13 +189,15 @@ class FrameworkImplementation(ABC):
|
|
185
189
|
@abstractmethod
|
186
190
|
def compute_activation_bias_correction(self,
|
187
191
|
graph: Graph,
|
188
|
-
quant_config: QuantizationConfig
|
192
|
+
quant_config: QuantizationConfig,
|
193
|
+
fw_info: FrameworkInfo) -> Graph:
|
189
194
|
"""
|
190
195
|
Compute activation bias correction on a graph.
|
191
196
|
|
192
197
|
Args:
|
193
198
|
graph: Graph to apply activation bias correction on.
|
194
199
|
quant_config: QuantizationConfig of how the model should be quantized.
|
200
|
+
fw_info: FrameworkInfo object with information about the specific framework's model.
|
195
201
|
|
196
202
|
Returns:
|
197
203
|
Graph after activation bias correction computing.
|
@@ -201,28 +207,30 @@ class FrameworkImplementation(ABC):
|
|
201
207
|
|
202
208
|
@abstractmethod
|
203
209
|
def get_substitutions_channel_equalization(self,
|
204
|
-
quant_config: QuantizationConfig
|
210
|
+
quant_config: QuantizationConfig,
|
211
|
+
fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
|
205
212
|
"""
|
206
213
|
Return a list of the framework substitutions used for channel equalization.
|
207
214
|
|
208
215
|
Args:
|
209
216
|
quant_config: QuantizationConfig to determine which substitutions to return.
|
217
|
+
fw_info: FrameworkInfo object with information about the specific framework's model.
|
210
218
|
|
211
219
|
Returns:
|
212
220
|
A list of the framework substitutions used after we collect statistics.
|
213
221
|
"""
|
214
222
|
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
215
|
-
|
223
|
+
f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover
|
216
224
|
|
217
225
|
@abstractmethod
|
218
|
-
def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
|
226
|
+
def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List[common.BaseSubstitution]:
|
219
227
|
"""
|
220
228
|
|
221
229
|
Returns: A list of the framework substitutions used to prepare the graph.
|
222
230
|
|
223
231
|
"""
|
224
232
|
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
225
|
-
|
233
|
+
f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover
|
226
234
|
|
227
235
|
@abstractmethod
|
228
236
|
def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig) -> \
|
@@ -320,12 +328,14 @@ class FrameworkImplementation(ABC):
|
|
320
328
|
f'method.') # pragma: no cover
|
321
329
|
|
322
330
|
def get_node_prior_info(self, node: BaseNode,
|
331
|
+
fw_info: FrameworkInfo,
|
323
332
|
graph: Graph) -> NodePriorInfo:
|
324
333
|
"""
|
325
334
|
Get a NodePriorInfo object for a node.
|
326
335
|
|
327
336
|
Args:
|
328
337
|
node: Node to get its prior info.
|
338
|
+
fw_info: Framework specific information needed to create the prior info of the node.
|
329
339
|
graph: Graph to check the next node type.
|
330
340
|
|
331
341
|
Returns:
|
@@ -333,7 +343,7 @@ class FrameworkImplementation(ABC):
|
|
333
343
|
"""
|
334
344
|
|
335
345
|
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
336
|
-
|
346
|
+
f'framework\'s get_node_prior_info method.') # pragma: no cover
|
337
347
|
|
338
348
|
def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
|
339
349
|
"""
|
@@ -384,18 +394,20 @@ class FrameworkImplementation(ABC):
|
|
384
394
|
|
385
395
|
@abstractmethod
|
386
396
|
def get_node_mac_operations(self,
|
387
|
-
node: BaseNode
|
397
|
+
node: BaseNode,
|
398
|
+
fw_info: FrameworkInfo) -> float:
|
388
399
|
"""
|
389
400
|
Gets the MAC operation count for a given operation.
|
390
401
|
|
391
402
|
Args:
|
392
403
|
node: A graph node that wraps the operation for which the MAC count is computed.
|
404
|
+
fw_info: FrameworkInfo object with information about the specific framework's model.
|
393
405
|
|
394
406
|
Returns: The MAC count of the operation
|
395
407
|
"""
|
396
408
|
|
397
409
|
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
398
|
-
|
410
|
+
f'framework\'s get_node_mac_operations method.') # pragma: no cover
|
399
411
|
|
400
412
|
@abstractmethod
|
401
413
|
def apply_second_moment_correction(self,
|
@@ -13,9 +13,19 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
|
17
|
+
from collections.abc import Callable
|
16
18
|
from enum import Enum
|
17
|
-
from typing import Dict, Any,
|
18
|
-
|
19
|
+
from typing import Dict, Any, List
|
20
|
+
|
21
|
+
from mct_quantizers import QuantizationMethod
|
22
|
+
from model_compression_toolkit.defaultdict import DefaultDict
|
23
|
+
|
24
|
+
|
25
|
+
# Default value to use for ops without kernel.
|
26
|
+
# This is a weird default, but it's used all over the place, so for now only extract it to const so that it can be
|
27
|
+
# referenced by variable instead of hard-coded.
|
28
|
+
DEFAULT_KERNEL_ATTRIBUTES = [None]
|
19
29
|
|
20
30
|
|
21
31
|
class ChannelAxis(Enum):
|
@@ -32,67 +42,89 @@ class ChannelAxis(Enum):
|
|
32
42
|
NCHW = 1
|
33
43
|
|
34
44
|
|
35
|
-
class
|
36
|
-
|
37
|
-
|
45
|
+
class FrameworkInfo:
|
46
|
+
|
47
|
+
def __init__(self,
|
48
|
+
activation_quantizer_mapping: Dict[QuantizationMethod, Callable],
|
49
|
+
kernel_channels_mapping: DefaultDict,
|
50
|
+
activation_min_max_mapping: Dict[str, tuple],
|
51
|
+
layer_min_max_mapping: Dict[Any, tuple],
|
52
|
+
kernel_ops_attributes_mapping: DefaultDict,
|
53
|
+
out_channel_axis_mapping: DefaultDict):
|
54
|
+
"""
|
55
|
+
A class to wrap all information about a specific framework the library needs to quantize a model.
|
56
|
+
Specifically, FrameworkInfo holds lists of layers by how they should be quantized, and multiple mappings such as
|
57
|
+
layer to it kernel channels indices, and a layer to its min/max values, etc.
|
58
|
+
The layers lists are divided into three groups:
|
59
|
+
kernel_ops: Layers that have coefficients and need to get quantized (e.g., Conv2D, Dense, etc.)
|
60
|
+
activation_ops: Layers that their outputs should get quantized (e.g., Add, ReLU, etc.)
|
61
|
+
no_quantization_ops:Layers that should not get quantized (e.g., Reshape, Transpose, etc.)
|
38
62
|
|
63
|
+
Args:
|
64
|
+
activation_quantizer_mapping (Dict[QuantizationMethod, Callable]): A dictionary mapping from QuantizationMethod to a quantization function.
|
65
|
+
kernel_channels_mapping (DefaultDict): Dictionary from a layer to a tuple of its kernel in/out channels indices.
|
66
|
+
activation_min_max_mapping (Dict[str, tuple]): Dictionary from an activation function to its min/max output values.
|
67
|
+
layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
|
68
|
+
kernel_ops_attributes_mapping (DefaultDict): Dictionary from a framework operator to a list of its weights attirbutes to quantize.
|
69
|
+
out_channel_axis_mapping (DefaultDict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
|
39
70
|
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
activation_ops: Layers that their outputs should get quantized (e.g., Add, ReLU, etc.)
|
48
|
-
no_quantization_ops:Layers that should not get quantized (e.g., Reshape, Transpose, etc.)
|
49
|
-
|
50
|
-
Fields:
|
51
|
-
kernel_channels_mapping (Dict): Dictionary from a layer to a tuple of its kernel in/out channels indices.
|
52
|
-
kernel_ops_attribute_mapping (Dict): Dictionary from a framework operator to its weight attribute to quantize.
|
53
|
-
out_channel_axis_mapping (Dict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
|
54
|
-
_layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
|
55
|
-
"""
|
71
|
+
Examples:
|
72
|
+
When quantizing a Keras model, if we want to quantize the kernels of Conv2D layers only, we can
|
73
|
+
set, and we know it's kernel out/in channel indices are (3, 2) respectivly:
|
74
|
+
|
75
|
+
>>> import tensorflow as tf
|
76
|
+
>>> kernel_ops = [tf.keras.layers.Conv2D]
|
77
|
+
>>> kernel_channels_mapping = DefaultDict({tf.keras.layers.Conv2D: (3,2)})
|
56
78
|
|
57
|
-
|
58
|
-
kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
|
59
|
-
out_channel_axis_mapping: Dict[Any, int]
|
79
|
+
Then, we can create a FrameworkInfo object:
|
60
80
|
|
61
|
-
|
62
|
-
|
81
|
+
>>> FrameworkInfo(kernel_channels_mapping, {}, {})
|
82
|
+
|
83
|
+
If an activation layer (tf.keras.layers.Activation) should be quantized and we know it's min/max outputs range in advanced, we can add it to activation_min_max_mapping for saving the statistics collection time. For example:
|
84
|
+
|
85
|
+
>>> activation_min_max_mapping = {'softmax': (0, 1)}
|
86
|
+
>>> FrameworkInfo(kernel_channels_mapping, activation_min_max_mapping, {})
|
87
|
+
|
88
|
+
If a layer's activations should be quantized and we know it's min/max outputs range in advanced, we can add it to layer_min_max_mapping for saving the statistics collection time. For example:
|
89
|
+
|
90
|
+
>>> layer_min_max_mapping = {tf.keras.layers.Softmax: (0, 1)}
|
91
|
+
>>> FrameworkInfo(kernel_channels_mapping, activation_min_max_mapping, layer_min_max_mapping)
|
63
92
|
|
64
|
-
@classmethod
|
65
|
-
def get_kernel_op_attribute(cls, node_type: Any) -> Optional[str]:
|
66
93
|
"""
|
67
|
-
|
94
|
+
|
95
|
+
self.activation_quantizer_mapping = activation_quantizer_mapping
|
96
|
+
self.kernel_channels_mapping = kernel_channels_mapping
|
97
|
+
self.activation_min_max_mapping = activation_min_max_mapping
|
98
|
+
self.layer_min_max_mapping = layer_min_max_mapping
|
99
|
+
self.kernel_ops_attributes_mapping = kernel_ops_attributes_mapping
|
100
|
+
self.out_channel_axis_mapping = out_channel_axis_mapping
|
101
|
+
|
102
|
+
def get_kernel_op_attributes(self, node_type: Any) -> List[str]:
|
103
|
+
"""
|
104
|
+
Get a list of attributes of a layer's weights to quantize.
|
68
105
|
|
69
106
|
Args:
|
70
|
-
node_type: Layer to get its
|
107
|
+
node_type: Layer to get its attributes.
|
71
108
|
|
72
109
|
Returns:
|
73
|
-
|
110
|
+
A list of attributes the layer has and should be quantized.
|
74
111
|
"""
|
75
|
-
|
112
|
+
attr_list = self.kernel_ops_attributes_mapping.get(node_type)
|
113
|
+
return attr_list
|
76
114
|
|
77
|
-
|
78
|
-
def get_layer_min_max(cls, layer: Any, fw_attrs: Dict) -> Tuple[float, float]:
|
115
|
+
def is_kernel_op(self, node_type: Any) -> bool:
|
79
116
|
"""
|
80
|
-
|
117
|
+
Check is the node is a kernel operation.
|
118
|
+
|
81
119
|
Args:
|
82
|
-
|
83
|
-
fw_attrs: framework attributes from framework layer.
|
120
|
+
node_type: Layer to get its attributes.
|
84
121
|
|
85
122
|
Returns:
|
86
|
-
|
123
|
+
True if node type is a kernel operation, else False.
|
87
124
|
"""
|
125
|
+
return node_type in self.kernel_ops_attributes_mapping.keys()
|
88
126
|
|
89
|
-
|
90
|
-
return cls._layer_min_max_mapping[layer]
|
91
|
-
else:
|
92
|
-
return None, None
|
93
|
-
|
94
|
-
@classmethod
|
95
|
-
def layers_has_min_max(cls, layer: Any) -> bool:
|
127
|
+
def layers_has_min_max(self, layer: Any) -> bool:
|
96
128
|
"""
|
97
129
|
Check if a layer is in a layer to min/max mapping the FrameworkInfo holds.
|
98
130
|
Args:
|
@@ -102,59 +134,17 @@ class FrameworkInfo(ABC):
|
|
102
134
|
Whether a layer has a min/max known values or not.
|
103
135
|
"""
|
104
136
|
|
105
|
-
return layer in
|
137
|
+
return layer in self.layer_min_max_mapping
|
106
138
|
|
107
|
-
|
108
|
-
@abstractmethod
|
109
|
-
def get_kernel_channels(cls, node_type: Any) -> ChannelAxisMapping:
|
110
|
-
"""
|
111
|
-
Returns node's channels mapping from kernel_channels_mapping or framework specific default value.
|
112
|
-
Args:
|
113
|
-
node_type: A node type
|
114
|
-
|
115
|
-
Returns:
|
116
|
-
Node's channels mapping.
|
139
|
+
def activation_has_min_max(self, activation_name: str) -> bool:
|
117
140
|
"""
|
118
|
-
|
141
|
+
Check if an activation layer has a min/max mapping.
|
119
142
|
|
120
|
-
@classmethod
|
121
|
-
@abstractmethod
|
122
|
-
def get_out_channel_axis(cls, node_type: Any):
|
123
|
-
"""
|
124
|
-
Returns node's output channel mapping from out_channel_axis_mapping or framework specific default value.
|
125
143
|
Args:
|
126
|
-
|
144
|
+
activation_name: String of the activation function to check for its min/max values.
|
127
145
|
|
128
146
|
Returns:
|
129
|
-
|
130
|
-
|
147
|
+
Whether an activation layer has a min/max known values or not.
|
131
148
|
"""
|
132
|
-
pass
|
133
|
-
|
134
|
-
|
135
|
-
# Pointer to current FrameworkInfo class.
|
136
|
-
_current_framework_info: type[FrameworkInfo] = None
|
137
|
-
|
138
|
-
|
139
|
-
def get_fw_info():
|
140
|
-
"""
|
141
|
-
A common function to get the current FrameworkInfo class. Raises an error if the pointer wasn't initialized.
|
142
|
-
|
143
|
-
Returns: FrameworkInfo class.
|
144
|
-
"""
|
145
|
-
assert _current_framework_info is not None, "fw_info isn't initialized."
|
146
|
-
return _current_framework_info
|
147
|
-
|
148
|
-
|
149
|
-
def set_fw_info(fw_info: type[FrameworkInfo]):
|
150
|
-
"""
|
151
|
-
A common function to set the current FrameworkInfo class. Raises an error if fw_info doesn't inherit from FrameworkInfo.
|
152
|
-
|
153
|
-
Args:
|
154
|
-
fw_info: Framework specific object implementing the FrameworkInfo.
|
155
|
-
"""
|
156
|
-
global _current_framework_info
|
157
|
-
assert _current_framework_info in [None, _current_framework_info], "FrameworkInfo already initialized."
|
158
|
-
assert issubclass(fw_info, FrameworkInfo), "fw_info must inherit from FrameworkInfo."
|
159
149
|
|
160
|
-
|
150
|
+
return activation_name in self.activation_min_max_mapping
|
@@ -14,12 +14,12 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
import copy
|
17
|
-
from typing import Tuple
|
17
|
+
from typing import List, Tuple
|
18
18
|
|
19
19
|
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
|
20
20
|
from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
|
21
|
-
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import
|
22
|
-
|
21
|
+
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
|
22
|
+
from itertools import product
|
23
23
|
|
24
24
|
|
25
25
|
class FusedLayerType:
|
@@ -30,7 +30,6 @@ class FusedLayerType:
|
|
30
30
|
def __init__(self):
|
31
31
|
self.__name__ = 'FusedLayer'
|
32
32
|
|
33
|
-
|
34
33
|
class GraphFuser:
|
35
34
|
def apply_node_fusion(self, graph: Graph) -> Graph:
|
36
35
|
"""
|
@@ -65,6 +64,7 @@ class GraphFuser:
|
|
65
64
|
|
66
65
|
return graph_copy
|
67
66
|
|
67
|
+
|
68
68
|
@staticmethod
|
69
69
|
def _create_fused_node(fused_node_id: str, nodes: Tuple[BaseNode]) -> BaseNode:
|
70
70
|
"""
|
@@ -86,15 +86,10 @@ class GraphFuser:
|
|
86
86
|
weights={},
|
87
87
|
layer_class=FusedLayerType)
|
88
88
|
|
89
|
-
base_cfg = CandidateNodeQuantizationConfig(
|
90
|
-
activation_quantization_cfg=nodes[-1].quantization_cfg.base_quantization_cfg.activation_quantization_cfg,
|
91
|
-
weights_quantization_cfg=None
|
92
|
-
)
|
93
89
|
activation_cfgs = [c.activation_quantization_cfg for c in nodes[-1].candidates_quantization_cfg]
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
candidates_quantization_cfg=candidates)
|
90
|
+
fused_node.candidates_quantization_cfg = [
|
91
|
+
CandidateNodeQuantizationConfig(weights_quantization_cfg=None, activation_quantization_cfg=a) for a in
|
92
|
+
activation_cfgs]
|
98
93
|
|
99
94
|
# Keep the final configurations if they were set already.
|
100
95
|
fused_node.final_weights_quantization_cfg = nodes[0].final_weights_quantization_cfg
|
@@ -163,3 +158,5 @@ class GraphFuser:
|
|
163
158
|
|
164
159
|
# Finally, add the new fused node to the graph
|
165
160
|
graph.add_node(fused_node)
|
161
|
+
|
162
|
+
|