mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py
CHANGED
@@ -34,7 +34,6 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
     NodeFrameworkAttrMatcher
 from model_compression_toolkit.core.common.substitutions.shift_negative_activation import \
     apply_shift_negative_correction
-from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.core.keras.constants import KERNEL_SIZE, STRIDES, ACTIVATION, SWISH, \
     SELU, GELU, FUNCTION, ADD, PAD
 from model_compression_toolkit.core.keras.constants import NEGATIVE_SLOPE, PADDING, PAD_SAME, PAD_VALID, BIAS, USE_BIAS
@@ -228,13 +227,15 @@ def is_padding_node_and_node_has_padding(pad_node_to_consider: BaseNode,


 def keras_apply_shift_negative_correction(graph: Graph,
-                                          core_config: CoreConfig
+                                          core_config: CoreConfig,
+                                          fw_info: FrameworkInfo) -> Graph:
     """
     Apply shift negative correction (SNC) on a graph built from a Keras model.

     Args:
         graph: Graph to apply SNC on.
         core_config: Quantization configuration.
+        fw_info: FrameworkInfo object with information about the specific framework's module.

     Returns:
         Graph after SNC.
@@ -243,6 +244,7 @@ def keras_apply_shift_negative_correction(graph: Graph,

     return apply_shift_negative_correction(graph,
                                            core_config,
+                                           fw_info,
                                            snc_node,
                                            linear_node,
                                            bypass_node,
@@ -253,6 +255,5 @@ def keras_apply_shift_negative_correction(graph: Graph,
                                            is_padding_node_and_node_has_padding,
                                            PADDING,
                                            BIAS,
-                                           USE_BIAS
-                                           get_activation_quantization_fn_factory
+                                           USE_BIAS
                                            )
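Reviewer's note on the hunks above: SNC adds a positive constant to an activation whose range dips below zero so that an unsigned quantizer can be used, and compensates for that constant downstream; the change itself only threads a FrameworkInfo argument through the correction. A minimal numpy sketch of the underlying idea, assuming a uniform unsigned quantizer; fake_quant_unsigned and the sample values are illustrative, not MCT code:

import numpy as np

def fake_quant_unsigned(x, n_bits=8, x_max=None):
    # Uniform unsigned fake-quantization of x onto [0, x_max].
    x_max = np.max(x) if x_max is None else x_max
    scale = x_max / (2 ** n_bits - 1)
    return np.clip(np.round(x / scale), 0, 2 ** n_bits - 1) * scale

# A swish/gelu-like activation output with a small negative tail.
act = np.array([-0.27, -0.10, 0.40, 1.30, 2.90])

shift = -act.min()                         # positive constant, here 0.27
q_shifted = fake_quant_unsigned(act + shift)
recovered = q_shifted - shift              # the shift is folded into the next layer's bias
print(np.round(recovered, 3))              # close to the original values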
model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py
CHANGED
@@ -22,6 +22,7 @@ from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESS
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
 from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.core.keras.hessian.hessian_scores_calculator_keras import HessianScoresCalculatorKeras
 from model_compression_toolkit.logger import Logger

@@ -94,11 +95,20 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
         for i, ipt_node in enumerate(self.hessian_request.target_nodes):  # Per Interest point weights tensor

             # Check if the target node's layer type is supported.
-            if not ipt_node.
+            if not DEFAULT_KERAS_INFO.is_kernel_op(ipt_node.type):
                 Logger.critical(f"Hessian information with respect to weights is not supported for "
                                 f"{ipt_node.type} layers.")  # pragma: no cover

-
+            # Get the weight attributes for the target node type
+            weight_attributes = DEFAULT_KERAS_INFO.get_kernel_op_attributes(ipt_node.type)
+
+            # Get the weight tensor for the target node
+            if len(weight_attributes) != 1:  # pragma: no cover
+                Logger.critical(
+                    f"Hessian-based scoring with respect to weights is currently supported only for nodes with "
+                    f"a single weight attribute. Found {len(weight_attributes)} attributes.")
+
+            weight_tensor = getattr(model.get_layer(ipt_node.name), weight_attributes[0])

             if j == 0:
                 # On the first iteration we store the weight_tensor shape for later reshaping the results
@@ -106,7 +116,7 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
                 tensors_original_shape.append(weight_tensor.shape)

             # Get the output channel index (needed for HessianInfoGranularity.PER_OUTPUT_CHANNEL case)
-            output_channel_axis = ipt_node.
+            output_channel_axis, _ = DEFAULT_KERAS_INFO.kernel_channels_mapping.get(ipt_node.type)

             # Get number of scores that should be calculated by the granularity.
             num_of_scores = self._get_num_scores_by_granularity(weight_tensor,
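The two DEFAULT_KERAS_INFO lookups introduced above, get_kernel_op_attributes and kernel_channels_mapping, resolve a layer type to its kernel attribute name and to the (output, input) channel axes of that kernel. A rough sketch of the shape of this data, using plain dicts keyed by strings rather than MCT's actual structures:

# Hypothetical stand-ins for the DEFAULT_KERAS_INFO lookups used above; the real
# FrameworkInfo is assembled by MCT, these dicts only illustrate the data it holds.
kernel_op_attributes = {'Conv2D': ['kernel'], 'Dense': ['kernel']}
# (output channel axis, input channel axis) of the kernel tensor per layer type:
# a Keras Conv2D kernel is (kh, kw, C_in, C_out), a Dense kernel is (C_in, C_out).
kernel_channels_mapping = {'Conv2D': (3, 2), 'Dense': (1, 0)}

def is_kernel_op(layer_type):
    return layer_type in kernel_op_attributes

layer_type = 'Conv2D'
if is_kernel_op(layer_type):
    weight_attributes = kernel_op_attributes[layer_type]      # ['kernel']
    output_channel_axis, _ = kernel_channels_mapping[layer_type]
    print(weight_attributes, output_channel_axis)             # ['kernel'] 3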
model_compression_toolkit/core/keras/keras_implementation.py
CHANGED
@@ -65,6 +65,7 @@ from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
 from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.core.keras.graph_substitutions.substitutions.activation_decomposition import \
     ActivationDecomposition
 from model_compression_toolkit.core.keras.graph_substitutions.substitutions.matmul_substitution import \
@@ -174,16 +175,18 @@ class KerasImplementation(FrameworkImplementation):
                     graph: Graph,
                     mode: ModelBuilderMode,
                     append2output: List[Any] = None,
+                    fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
                     return_float_outputs: bool = False) -> Tuple:
         """
         Build a Keras model from a graph.
-        The mode determines how the model should be
+        The mode determines how the model should be build. append2output is a list of Nodes
         to set as the model outputs.

         Args:
             graph: Graph to build the model from it.
             mode: Mode for how to build the model.
             append2output: List of Nodes to set as the model's outputs.
+            fw_info: FrameworkInfo object with information about the specific framework's model
             return_float_outputs (bool): whether to return outputs before or after quantization nodes (default)
         Returns:
             A tuple with the model and additional relevant supporting objects.
@@ -192,6 +195,7 @@ class KerasImplementation(FrameworkImplementation):
         keras_model_builder = get_keras_model_builder(mode)
         return keras_model_builder(graph=graph,
                                    append2output=append2output,
+                                   fw_info=fw_info,
                                    return_float_outputs=return_float_outputs).build_model()

     def run_model_inference(self,
@@ -223,57 +227,65 @@ class KerasImplementation(FrameworkImplementation):

     def shift_negative_correction(self,
                                   graph: Graph,
-                                  core_config: CoreConfig
+                                  core_config: CoreConfig,
+                                  fw_info: FrameworkInfo) -> Graph:
         """
         Apply shift negative correction (SNC) on a graph.

         Args:
             graph: Graph to apply SNC on.
             core_config: Quantization configuration.
+            fw_info: FrameworkInfo object with information about the specific framework's model.

         Returns:
             Graph after SNC.
         """
         return keras_apply_shift_negative_correction(graph,
-                                                     core_config
+                                                     core_config,
+                                                     fw_info)

     def compute_activation_bias_correction(self,
                                            graph: Graph,
-                                           quant_config: QuantizationConfig
+                                           quant_config: QuantizationConfig,
+                                           fw_info: FrameworkInfo):
         """
         Compute activation bias correction on a graph.

         Args:
             graph: Graph to apply activation bias correction on.
             quant_config: QuantizationConfig of how the model should be quantized.
+            fw_info: FrameworkInfo object with information about the specific framework's model.

         Returns:
             Graph after activation bias correction computing.
         """
         return keras_compute_activation_bias_correction_of_graph(graph=graph,
                                                                  quant_config=quant_config,
+                                                                 fw_info=fw_info,
                                                                  fw_impl=self)

     def get_substitutions_channel_equalization(self,
-                                               quant_config: QuantizationConfig
+                                               quant_config: QuantizationConfig,
+                                               fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
         """
         Return a list of the framework substitutions used for channel equalization.

         Args:
             quant_config: QuantizationConfig to determine which substitutions to return.
+            fw_info: FrameworkInfo object with information about the specific framework's model.

         Returns:
             A list of the framework substitutions used after we collect statistics.
         """
         substitutions_list = []
         if quant_config.activation_channel_equalization:
-            substitutions_list.extend([ScaleEqualization(quant_config),
-                                       ScaleEqualizationWithPad(quant_config),
-                                       ScaleEqualizationMidActivation(quant_config),
-                                       ScaleEqualizationMidActivationWithPad(quant_config)])
+            substitutions_list.extend([ScaleEqualization(quant_config, fw_info),
+                                       ScaleEqualizationWithPad(quant_config, fw_info),
+                                       ScaleEqualizationMidActivation(quant_config, fw_info),
+                                       ScaleEqualizationMidActivationWithPad(quant_config, fw_info)])
         return substitutions_list

-    def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
+    def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List[common.BaseSubstitution]:
         """

         Returns: A list of the framework substitutions used to prepare the graph.
@@ -357,8 +369,8 @@ class KerasImplementation(FrameworkImplementation):
         if quant_config.softmax_shift:
             substitutions_list.append(keras_softmax_shift())
         if quant_config.input_scaling:
-            substitutions_list.append(InputScaling(
-            substitutions_list.append(InputScalingWithPad(
+            substitutions_list.append(InputScaling())
+            substitutions_list.append(InputScalingWithPad())
         if quant_config.concat_threshold_update:
             substitutions_list.append(ConcatThresholdUpdate())
         return substitutions_list
@@ -390,19 +402,22 @@ class KerasImplementation(FrameworkImplementation):

     def get_node_prior_info(self,
                             node: BaseNode,
+                            fw_info: FrameworkInfo,
                             graph: Graph) -> NodePriorInfo:
         """
         Get a NodePriorInfo object for a node that represents a Keras layer.

         Args:
             node: Node to get its prior info.
+            fw_info: Framework specific information needed to create the prior info of the node.
             graph: Graph to check the next node type.

         Returns:
             NodePriorInfo with information about the node.
         """

-        return create_node_prior_info(node=node,
+        return create_node_prior_info(node=node,
+                                      fw_info=fw_info, graph=graph)

     def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
         """
@@ -515,19 +530,23 @@ class KerasImplementation(FrameworkImplementation):
         return True

     def get_node_mac_operations(self,
-                                node: BaseNode
+                                node: BaseNode,
+                                fw_info: FrameworkInfo) -> float:
         """
         Gets the MAC operation count for a given operation.

         Args:
             node: A graph node that wraps the operation for which the MAC count is computed.
+            fw_info: FrameworkInfo object with information about the Keras model.

         Returns: The MAC count og the operation
         """
-
+        kernels = fw_info.get_kernel_op_attributes(node.type)
+        if not kernels or kernels[0] is None:
             return 0

-
+        assert len(kernels) == 1
+        kernel_shape = node.get_weights_by_keys(kernels[0]).shape

         if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose) or node.is_match_type(DepthwiseConv2D):
             h, w = node.get_output_shapes_list()[0][-3:-1]
@@ -535,7 +554,8 @@ class KerasImplementation(FrameworkImplementation):

         if node.is_match_type(Dense):
             # IN * OUT * (all previous dims[:-1])
-
+            _, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
+            return node.get_total_output_params() * kernel_shape[input_channel_axis]

         return 0
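The restored body of get_node_mac_operations resolves the kernel tensor through fw_info and then applies a per-layer-type counting rule; the Dense rule is fully visible above, and the Conv2D branch reads the output spatial size (the usual rule, assumed here, multiplies the kernel volume by it). A worked example of both rules with made-up shapes:

import numpy as np

# Conv2D: every output pixel performs one dot product over the full kernel volume.
kernel_shape = (3, 3, 64, 128)            # (kh, kw, C_in, C_out)
h_out, w_out = 56, 56                     # output spatial size
conv_macs = int(np.prod(kernel_shape)) * h_out * w_out

# Dense: each output element costs C_in multiply-accumulates, i.e.
# get_total_output_params() * kernel_shape[input_channel_axis].
dense_kernel_shape = (512, 10)            # (C_in, C_out); input channel axis is 0
total_output_params = 10
dense_macs = total_output_params * dense_kernel_shape[0]

print(f"{conv_macs:,}")                   # 231,211,008
print(f"{dense_macs:,}")                  # 5,120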
model_compression_toolkit/core/keras/keras_model_validation.py
ADDED
@@ -0,0 +1,38 @@
+from tensorflow.keras.models import Model
+
+from model_compression_toolkit.core import FrameworkInfo
+from model_compression_toolkit.core.common.framework_info import ChannelAxis
+from model_compression_toolkit.core.common.model_validation import ModelValidation
+from model_compression_toolkit.core.keras.constants import CHANNELS_FORMAT, CHANNELS_FORMAT_LAST, CHANNELS_FORMAT_FIRST
+
+
+class KerasModelValidation(ModelValidation):
+    """
+    Class to define validation methods in order to validate the received Keras model to quantize.
+    """
+
+    def __init__(self, model: Model, fw_info: FrameworkInfo):
+        """
+        Initialize a KerasModelValidation object.
+
+        Args:
+            model: Keras model to check its validity.
+            fw_info: Information about the framework of the model (Keras).
+        """
+
+        super(KerasModelValidation, self).__init__(model=model,
+                                                   fw_info=fw_info)
+
+    def validate_output_channel_consistency(self):
+        """
+
+        Validate that output channels index in all layers of the model are the same.
+        If the model has layers with different output channels index, an exception is thrown.
+
+        """
+        for layer in self.model.layers:
+            data_format = layer.get_config().get(CHANNELS_FORMAT)
+            if data_format is not None:
+                assert (data_format == CHANNELS_FORMAT_LAST and self.fw_info.out_channel_axis_mapping.get(layer) == ChannelAxis.NHWC.value
+                        or data_format == CHANNELS_FORMAT_FIRST and self.fw_info.out_channel_axis_mapping.get(layer) == ChannelAxis.NCHW.value), \
+                    f'Model can not have layers with different data formats.'
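As a standalone illustration of the consistency rule validate_output_channel_consistency enforces, with plain dicts standing in for Keras layer configs (illustrative only):

# Layers that carry a data_format must all agree on it.
layer_configs = [{'name': 'conv1', 'data_format': 'channels_last'},
                 {'name': 'relu1'},                      # no data_format key
                 {'name': 'conv2', 'data_format': 'channels_last'}]

formats = {cfg['data_format'] for cfg in layer_configs if 'data_format' in cfg}
assert len(formats) <= 1, 'Model can not have layers with different data formats.'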
model_compression_toolkit/core/keras/keras_node_prior_info.py
CHANGED
@@ -17,19 +17,22 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph


 def create_node_prior_info(node: BaseNode,
+                           fw_info: FrameworkInfo,
                            graph: Graph):
     """
     Create a NodePriorInfo object for a given node.

     Args:
         node: Node to create its prior info.
+        fw_info: Information about a specific framework the node was generated from.
         graph: Graph to check the next node type.

     Returns:
         NodePriorInfo object with info about the node.
     """

-    min_output, max_output = _get_min_max_outputs(node=node
+    min_output, max_output = _get_min_max_outputs(node=node,
+                                                  fw_info=fw_info)

     mean_output, std_output = _get_mean_std_outputs(node=node,
                                                     graph=graph)
@@ -39,12 +42,14 @@ def create_node_prior_info(node: BaseNode,
                          std_output=std_output)


-def _get_min_max_outputs(node: BaseNode
+def _get_min_max_outputs(node: BaseNode,
+                         fw_info: FrameworkInfo) -> Tuple[Any, Any]:
     """
     Return the min/max output values of a node if known.
     If one of them (or both of them) is unknown - return None instead of a value.
     Args:
         node: Node to create its prior info.
+        fw_info: Information about a specific framework the node was generated from.

     Returns:
         Min/max output values if known.
@@ -53,8 +58,12 @@ def _get_min_max_outputs(node: BaseNode) -> Tuple[Any, Any]:

     if node.is_match_type(ReLU):
         min_output = node.framework_attr[THRESHOLD] if node.framework_attr[NEGATIVE_SLOPE] == 0 else None
-
-
+
+    elif fw_info.layers_has_min_max(node.type):
+        min_output, max_output = fw_info.layer_min_max_mapping[node.type]
+
+    elif node.is_match_type(Activation) and fw_info.activation_has_min_max(node.framework_attr[ACTIVATION]):
+        min_output, max_output = fw_info.activation_min_max_mapping[node.framework_attr[ACTIVATION]]

     return min_output, max_output
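A compact sketch of the lookup pattern _get_min_max_outputs now follows, with hypothetical stand-ins for the FrameworkInfo mappings (the real ones are built by MCT per framework):

layer_min_max_mapping = {'Softmax': (0.0, 1.0)}
activation_min_max_mapping = {'sigmoid': (0.0, 1.0), 'softmax': (0.0, 1.0)}

def get_min_max(layer_type, activation=None, threshold=None, negative_slope=None):
    min_output, max_output = None, None
    if layer_type == 'ReLU':
        # The threshold is a hard lower bound only when there is no leak.
        min_output = threshold if negative_slope == 0 else None
    elif layer_type in layer_min_max_mapping:
        min_output, max_output = layer_min_max_mapping[layer_type]
    elif layer_type == 'Activation' and activation in activation_min_max_mapping:
        min_output, max_output = activation_min_max_mapping[activation]
    return min_output, max_output

print(get_min_max('ReLU', threshold=0.0, negative_slope=0))   # (0.0, None)
print(get_min_max('Activation', activation='sigmoid'))        # (0.0, 1.0)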
model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py
CHANGED
@@ -23,7 +23,6 @@ from model_compression_toolkit.core.common.mixed_precision.configurable_quantize
     verify_candidates_descending_order, init_activation_quantizers
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
-from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.logger import Logger

 import tensorflow as tf
@@ -68,7 +67,7 @@ class ConfigurableActivationQuantizer(BaseKerasInferableQuantizer):
         if qc.activation_quantization_cfg.quant_mode != node_q_cfg[0].activation_quantization_cfg.quant_mode:
             Logger.critical("Unsupported configuration: Mixing candidates with differing activation quantization states (enabled/disabled).")  # pragma: no cover

-        self.activation_quantizers = init_activation_quantizers(self.node_q_cfg
+        self.activation_quantizers = init_activation_quantizers(self.node_q_cfg)
         self.active_quantization_config_index = max_candidate_idx  # initialize with first config as default

     def set_active_activation_quantizer(self, index: Optional[int]):
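For context, a minimal sketch of the configurable-quantizer pattern this class implements: one quantizer per mixed-precision candidate, switched by index. The class and grids below are illustrative, not the BaseKerasInferableQuantizer API:

class ConfigurableQuantizer:
    def __init__(self, quantizers):
        self.quantizers = quantizers      # one per bit-width candidate
        self.active_idx = 0               # default to the maximal-bit-width candidate

    def set_active(self, idx):
        assert 0 <= idx < len(self.quantizers)
        self.active_idx = idx

    def __call__(self, x):
        return self.quantizers[self.active_idx](x)

q = ConfigurableQuantizer([lambda x: round(x * 255) / 255,    # 8-bit grid
                           lambda x: round(x * 15) / 15])     # 4-bit grid
q.set_active(1)                           # the search picks the 4-bit candidate
print(q(0.5))                             # 0.5333...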
model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py
CHANGED
@@ -19,6 +19,7 @@ from model_compression_toolkit.core.common.pruning.pruning_framework_implementat
     PruningFrameworkImplementation
 from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
 from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.keras.constants import BIAS, GROUPS, FILTERS, UNITS, USE_BIAS
 import keras
@@ -28,10 +29,6 @@ import numpy as np
 from model_compression_toolkit.logger import Logger


-# default output channel axis to use when it's not defined in node's fw_info.
-_default_output_channel_axis = -1
-
-
 class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementation):
     """
     Implementation of the PruningFramework for the Keras framework. This class provides
@@ -41,23 +38,27 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa

     def prune_entry_node(self,
                          node: BaseNode,
-                         output_mask: np.ndarray
+                         output_mask: np.ndarray,
+                         fw_info: FrameworkInfo):
         """
         Prunes the entry node of a model in Keras.

         Args:
             node (BaseNode): The entry node to be pruned.
             output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
+            fw_info (FrameworkInfo): Framework-specific information object.

         """
         return _prune_keras_edge_node(node=node,
                                       mask=output_mask,
+                                      fw_info=fw_info,
                                       is_exit_node=False)

     def prune_intermediate_node(self,
                                 node: BaseNode,
                                 input_mask: np.ndarray,
-                                output_mask: np.ndarray
+                                output_mask: np.ndarray,
+                                fw_info: FrameworkInfo):
         """
         Prunes an intermediate node in a Keras model.

@@ -65,6 +66,7 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
             node (BaseNode): The intermediate node to be pruned.
             input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
             output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
+            fw_info (FrameworkInfo): Framework-specific information object.

         """
         _edit_node_input_shape(input_mask, node)
@@ -77,17 +79,20 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa

     def prune_exit_node(self,
                         node: BaseNode,
-                        input_mask: np.ndarray
+                        input_mask: np.ndarray,
+                        fw_info: FrameworkInfo):
         """
         Prunes the exit node of a model in Keras.

         Args:
             node (BaseNode): The exit node to be pruned.
             input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
+            fw_info (FrameworkInfo): Framework-specific information object.

         """
         return _prune_keras_edge_node(node=node,
                                       mask=input_mask,
+                                      fw_info=fw_info,
                                       is_exit_node=True)

     def is_node_entry_node(self, node: BaseNode) -> bool:
@@ -104,19 +109,22 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa

     def is_node_exit_node(self,
                           node: BaseNode,
-                          corresponding_entry_node: BaseNode
+                          corresponding_entry_node: BaseNode,
+                          fw_info: FrameworkInfo) -> bool:
         """
         Determines whether a node is an exit node in a Keras model.

         Args:
             node (BaseNode): The node to be checked.
             corresponding_entry_node (BaseNode): The entry node of the pruning section that is checked.
+            fw_info (FrameworkInfo): Framework-specific information object.

         Returns:
             bool: Boolean indicating if the node is an exit node.
         """
         return _is_keras_node_pruning_section_edge(node) and PruningSection.has_matching_channel_count(node,
-                                                                                                       corresponding_entry_node
+                                                                                                       corresponding_entry_node,
+                                                                                                       fw_info)

     def is_node_intermediate_pruning_section(self, node: BaseNode) -> bool:
         """
@@ -135,7 +143,8 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
                 keras.layers.Dense]

     def attrs_oi_channels_info_for_pruning(self,
-                                           node: BaseNode
+                                           node: BaseNode,
+                                           fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
         """
         Retrieves the attributes of a given node along with the output/input (OI) channel axis
         for each attribute used to prune these attributes.
@@ -152,6 +161,7 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa

         Args:
             node (BaseNode): The node from the computational graph.
+            fw_info (FrameworkInfo): Contains framework-specific information and utilities.

         Returns:
             Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'kernel' or 'bias')
@@ -159,8 +169,13 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
         """

         attributes_with_axis = {}
-        if node.
-
+        if fw_info.is_kernel_op(node.type):
+            kernel_attributes = fw_info.get_kernel_op_attributes(node.type)
+            if kernel_attributes is None or len(kernel_attributes)==0:
+                Logger.critical(f"Expected kernel attributes for operation for node type {node.type}, found None or empty.")
+
+            for attr in kernel_attributes:
+                attributes_with_axis[attr] = fw_info.kernel_channels_mapping.get(node.type)

         # Bias is a vector at the length of the number of output channels.
         # For this reason, input channel axis is irrelevant to the bias attribute.
@@ -176,10 +191,6 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa

         return attributes_with_axis

-    @property
-    def default_output_channel_axis(self):
-        return _default_output_channel_axis
-


 def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
     """
@@ -205,6 +216,7 @@ def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:

 def _prune_keras_edge_node(node: BaseNode,
                            mask: np.ndarray,
+                           fw_info: FrameworkInfo,
                            is_exit_node: bool):
     """
     Prunes the given Keras node by applying the mask to the node's weights (kernels and biases).
@@ -213,18 +225,21 @@ def _prune_keras_edge_node(node: BaseNode,
     Args:
         node: The node to be pruned.
         mask: The pruning mask to be applied.
+        fw_info: Framework-specific information object.
         is_exit_node: A boolean indicating whether the node is an exit node.

     """

     # Retrieve the kernel attribute and the axes to prune.
-
-
+    kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
+    io_axis = fw_info.kernel_channels_mapping.get(node.type)
+    axis_to_prune = io_axis[int(is_exit_node)]
+    kernel = node.get_weights_by_keys(kernel_attr)
     # Convert mask to boolean.
     mask_bool = mask.astype(bool)

     pruned_kernel = kernel.compress(mask_bool, axis=axis_to_prune)
-    node.set_weights_by_keys(name=
+    node.set_weights_by_keys(name=kernel_attr, tensor=pruned_kernel)

     if not is_exit_node and node.framework_attr[USE_BIAS]:
         # Prune the bias if applicable and it's an entry node.
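The actual pruning step in _prune_keras_edge_node reduces to numpy's compress along the channel axis picked by io_axis[int(is_exit_node)]. In isolation, with illustrative shapes:

import numpy as np

kernel = np.random.randn(3, 3, 8, 16)     # Keras Conv2D kernel: (kh, kw, C_in, C_out)
output_mask = np.array([1, 0] * 8)        # keep 8 of the 16 output channels

# Entry node -> prune along the output-channel axis (3 for Conv2D);
# an exit node would apply its mask along the input-channel axis (2) instead.
pruned_kernel = kernel.compress(output_mask.astype(bool), axis=3)
print(pruned_kernel.shape)                # (3, 3, 8, 8)

bias = np.zeros(16)                       # bias has one entry per output channel
print(bias.compress(output_mask.astype(bool)).shape)   # (8,)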
model_compression_toolkit/core/keras/resource_utilization_data_facade.py
CHANGED
@@ -27,15 +27,14 @@ if FOUND_TF:
     from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
         AttachTpcToKeras
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
+    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
-    from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
     from tensorflow.keras.models import Model

     from model_compression_toolkit import get_target_platform_capabilities

     KERAS_DEFAULT_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)

-    @set_keras_info
     def keras_resource_utilization_data(in_model: Model,
                                         representative_data_gen: Callable,
                                         core_config: CoreConfig = CoreConfig(
@@ -94,6 +93,7 @@ if FOUND_TF:
                                                     representative_data_gen,
                                                     core_config,
                                                     target_platform_capabilities,
+                                                    DEFAULT_KERAS_INFO,
                                                     fw_impl)

 else:
model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py
CHANGED
@@ -25,7 +25,7 @@ else:

 from model_compression_toolkit.core import QuantizationConfig
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
-from model_compression_toolkit.core.
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
 from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
@@ -43,6 +43,7 @@ def activation_bias_correction_node_matchers():

 def keras_compute_activation_bias_correction_of_graph(graph: Graph,
                                                       quant_config: QuantizationConfig,
+                                                      fw_info: FrameworkInfo,
                                                       fw_impl: FrameworkImplementation) -> Graph:
     """
     Compute the activation bias correction term for graph based on a Keras model.
@@ -50,6 +51,7 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph,
     Args:
         graph: Graph with nodes to compute the activation bias correction.
         quant_config: QuantizationConfig of how the model should be quantized.
+        fw_info: Framework info like lists of nodes their kernel should quantized.
        fw_impl: FrameworkImplementation object with a specific framework methods implementation.

     Returns:
@@ -57,9 +59,9 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph,
     """
     graph = compute_activation_bias_correction_of_graph(graph=graph,
                                                         quant_config=quant_config,
+                                                        fw_info=fw_info,
                                                         fw_impl=fw_impl,
                                                         activation_bias_correction_node_matchers=
                                                         activation_bias_correction_node_matchers,
-                                                        kernel_size=KERNEL_SIZE
-                                                        get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
+                                                        kernel_size=KERNEL_SIZE)
     return graph
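To close, a conceptual numpy sketch of what activation bias correction computes; MCT's actual computation lives in compute_activation_bias_correction_of_graph and differs in detail. The mean error introduced by quantizing an activation, propagated through the following kernel, is folded into that layer's bias:

import numpy as np

rng = np.random.default_rng(0)
x_float = np.abs(rng.normal(size=(1024, 32)))   # calibration activations
x_quant = np.round(x_float * 15) / 15           # coarse, 4-bit-style grid

w = rng.normal(size=(32, 8)) * 0.1              # next linear layer's kernel
bias = np.zeros(8)

err = (x_quant - x_float).mean(axis=0)          # E[x_q - x] per input channel
bias_corrected = bias - err @ w                 # cancel the propagated mean shift

y_float = x_float @ w + bias
y_quant = x_quant @ w + bias_corrected
print(np.abs((y_quant - y_float).mean(axis=0)).max())   # ~0 mean error per channel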