mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff compares the contents of publicly available package versions that were released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/substitutions/shift_negative_activation.py
CHANGED
@@ -16,20 +16,21 @@ import copy
 import numpy as np
 from typing import List, Tuple, Any, Callable
 
+from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
     ActivationQuantizationMode
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.common import Graph, BaseNode
+from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
 from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
+    set_quantization_configs_to_node
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
-    import
+    import get_activations_qparams
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
     _mse_error_histogram
 from model_compression_toolkit.core.common.quantization.quantization_params_generation import z_score_filter
-from model_compression_toolkit.quantization_preparation.load_fqc import set_quantization_configs_to_node, \
-    fetch_qc_options_for_node
 from model_compression_toolkit.target_platform_capabilities import QuantizationMethod, AttributeQuantizationConfig
 
 """
@@ -45,6 +46,7 @@ If the linear node pads the input tensor with zeros, we modify the padded value
 
 def op2d_bias_correction(op2d_node: BaseNode,
                          shift_to_correct: float,
+                         fw_info: FrameworkInfo,
                          bias_str: str,
                          bias_flag_str: str):
     """
@@ -55,6 +57,7 @@ def op2d_bias_correction(op2d_node: BaseNode,
         op2d_node: Node to compute its bias correction term.
         shift_to_correct: Value that was used to shift the output tensor of
             the non-linear node.
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
        bias_str:
        bias_flag_str: The framework specific attribute name of the bias flag.
    """
@@ -66,19 +69,21 @@ def op2d_bias_correction(op2d_node: BaseNode,
     # Add an attribute quantization configuration to the newly added bias attribute, with disabled quantization
     for qc in op2d_node.candidates_quantization_cfg:
         qc.weights_quantization_cfg.set_attr_config(bias_flag_str,
-                                                    WeightsAttrQuantizationConfig(
+                                                    WeightsAttrQuantizationConfig(QuantizationConfig(),
+                                                                                  AttributeQuantizationConfig(
                                                         enable_weights_quantization=False)))
 
     # Each node adds a different noise due to the shifting. It depends on the
     # dimensions of the kernel, thus the correction term is a function of
     # the layer type.
-    kernel = op2d_node.get_weights_by_keys(op2d_node.
+    kernel = op2d_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(op2d_node.type)[0])
     if kernel is not None:
+        output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(op2d_node.type)
         axis_not_output_channel = list(range(len(kernel.shape)))
-        axis_not_output_channel.remove(
+        axis_not_output_channel.remove(output_channel_index)
 
         # special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters
-        if
+        if output_channel_index == input_channel_index:
             axis_not_output_channel.remove(3)  # 3 is the depth multiplier index
 
         bias_correction = shift_to_correct * np.sum(kernel, axis=tuple(axis_not_output_channel))
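Note: the correction term restored above is the shift value multiplied by the kernel summed over every axis except the output-channel axis. A minimal numpy sketch of that computation, assuming a Keras-style (H, W, C_in, C_out) kernel layout and an illustrative channel index (in MCT both come from fw_info, not hard-coded values):

    import numpy as np

    kernel = np.random.randn(3, 3, 8, 16)  # assumed (H, W, C_in, C_out) layout
    shift_to_correct = 0.25
    output_channel_index = 3  # illustrative; MCT reads this from fw_info.kernel_channels_mapping

    # Sum over all axes except the output channel: one correction value per output channel.
    axis_not_output_channel = list(range(kernel.ndim))
    axis_not_output_channel.remove(output_channel_index)
    bias_correction = shift_to_correct * np.sum(kernel, axis=tuple(axis_not_output_channel))
    assert bias_correction.shape == (16,)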
@@ -245,13 +250,13 @@ def shift_negative_function(graph: Graph,
                             core_config: CoreConfig,
                             non_linear_node: BaseNode,
                             op2d_node: BaseNode,
+                            fw_info: FrameworkInfo,
                             create_add_node: Callable,
                             get_padding_values: Callable,
                             create_pad_node: Callable,
                             padding_str: str,
                             bias_str: str,
                             bias_flag_str: str,
-                            get_activation_quantization_fn_factory: Callable,
                             zero_padding_node: BaseNode = None,
                             bypass_nodes: List = None,
                             params_search_quantization_fn: Callable = None
@@ -271,13 +276,14 @@ def shift_negative_function(graph: Graph,
         non_linear_node: Non-linear node with negative values to shift.
         op2d_node: Linear node to correct its bias to overcome the expected error due to
             the shifting.
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
+            groups of layers by how they should be quantized, etc.)
         create_add_node: Function to create an add node.
         get_padding_values: Function to compute the op2d node's padding values
         create_pad_node: Function to create an pad node.
         padding_str: The framework specific attribute name of the padding.
         bias_str: The framework specific attribute name of the bias.
         bias_flag_str: The framework specific attribute name of the bias flag.
-        get_activation_quantization_fn_factory: activation quantization functions factory.
         zero_padding_node: ZeroPadding2D node that may be in the graph before the linear layer.
         params_search_quantization_fn: Function to quantize np tensor using a framework (tf/torch) quantization method. Needed for better param_search estimating the expected loss.
 
@@ -293,12 +299,13 @@ def shift_negative_function(graph: Graph,
     # all candidates have same activation config, so taking the first candidate for calculations
     non_linear_node_cfg_candidate = non_linear_node.candidates_quantization_cfg[0].activation_quantization_cfg
 
+
     # get the non-linear activation threshold
     activation_threshold = non_linear_node_cfg_candidate.activation_quantization_params.get(THRESHOLD)
 
     negative_rate = np.abs(min_to_correct) / activation_threshold
 
-    enable_sub = negative_rate <=
+    enable_sub = negative_rate <= non_linear_node_cfg_candidate.shift_negative_ratio
     if min_to_correct >= 0 or not enable_sub:
         return graph
 
@@ -316,7 +323,7 @@ def shift_negative_function(graph: Graph,
     if core_config.quantization_config.shift_negative_params_search:
 
         hist_bins, hist_count = graph.get_out_stats_collector(non_linear_node).hc.get_histogram()
-        hist_count = z_score_filter(
+        hist_count = z_score_filter(non_linear_node_cfg_candidate.z_threshold,
                                     hist_bins, hist_count)
 
         min_mse, _th, _shift = np.inf, None, None
@@ -327,15 +334,13 @@ def shift_negative_function(graph: Graph,
             'float32')  # Change to type float32 to support tensorflow dtypes
         for _shift_value in _q_points:
             _hist_bins = hist_bins.astype(np.float32) + _shift_value
-
-                non_linear_node_cfg_candidate.activation_quantization_method)
-            fw_quant_fn = quantizer_factory(non_linear_node_cfg_candidate.activation_n_bits, qparams)
+            fw_quant_fn = non_linear_node_cfg_candidate.activation_quantization_fn(non_linear_node_cfg_candidate.activation_n_bits, qparams)
             """
             In SNC, when better shifting values are tested for better choice,
             the histogram (which is a numpy object) is quantized using the non-linear node activation
             quantization function (to estimate the expected mse comparing to the original histogram).
             The quantization function is a framework function, which makes it fail since it
-            expects a fw tensor. The
+            expects a fw tensor. The commmon part of SNC receives an argument which is a callable
             that receives two argument and returns one: it gets the fw activation quantization function
             and the bins to quantize. The function (of each fw) responsible for doing (if needed) a preprocessing and postprocessing
             to the bins which is a numpy object.
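Note: the docstring above specifies the contract of params_search_quantization_fn: a callable that takes the framework activation quantization function and the numpy histogram bins, and returns the quantized bins as numpy. A sketch of what such a callable could look like on the TensorFlow side (the function name is hypothetical; only the contract is taken from the docstring):

    import numpy as np
    import tensorflow as tf

    def tf_params_search_quantization_fn(fw_quant_fn, bins: np.ndarray) -> np.ndarray:
        tf_bins = tf.convert_to_tensor(bins, dtype=tf.float32)  # preprocessing: numpy -> fw tensor
        quantized = fw_quant_fn(tf_bins)                        # framework activation quantization
        return quantized.numpy()                                # postprocessing: fw tensor -> numpy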
@@ -385,6 +390,7 @@ def shift_negative_function(graph: Graph,
                                                first_node=non_linear_node)
         op2d_bias_correction(op2d_node,
                              shift_value,
+                             fw_info,
                              bias_str,
                              bias_flag_str)
 
@@ -395,9 +401,12 @@ def shift_negative_function(graph: Graph,
     graph.set_out_stats_collector_to_node(add_node, add_node_stats_collector)
     graph.shift_stats_collector(add_node, np.array(shift_value))
 
-    set_quantization_configs_to_node(
+    set_quantization_configs_to_node(fw_info=fw_info,
+                                     node=add_node,
                                      graph=graph,
-
+                                     quant_config=core_config.quantization_config,
+                                     fqc=graph.fqc,
+                                     mixed_precision_enable=core_config.is_mixed_precision_enabled)
 
     update_fused_op_with_add(graph=graph,
                              non_linear_node=non_linear_node,
@@ -419,9 +428,12 @@ def shift_negative_function(graph: Graph,
                                                last_node=op2d_node)
 
         # Set quantization configuration to node, even though we do not quantize it:
-        set_quantization_configs_to_node(
+        set_quantization_configs_to_node(fw_info=fw_info,
+                                         node=pad_node,
                                          graph=graph,
-
+                                         quant_config=core_config.quantization_config,
+                                         fqc=graph.fqc,
+                                         mixed_precision_enable=core_config.is_mixed_precision_enabled)
 
         for candidate_qc in pad_node.candidates_quantization_cfg:
             candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
@@ -446,7 +458,7 @@ def shift_negative_function(graph: Graph,
             bypass_candidate_qc.activation_quantization_cfg.activation_quantization_params[SIGNED] = False
             graph.shift_stats_collector(bypass_node, np.array(shift_value))
 
-    add_node_qco =
+    add_node_qco = add_node.get_qco(graph.fqc).quantization_configurations
     add_supported_bitwidths = [c.activation_n_bits for c in add_node_qco]
     if original_non_linear_activation_nbits not in add_supported_bitwidths:
         raise ValueError(
@@ -454,16 +466,19 @@ def shift_negative_function(graph: Graph,
             f"bitwidth is {original_non_linear_activation_nbits}. Consider adapting the TPC so 'Add' will support the "
             f"same bitwidth as {non_linear_node.type} or disable shift negative correction.")
 
-
-
-
+    for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg):
+        for attr in add_node.get_node_weights_attributes():
+            # TODO: do we not quantize the weights of this 'add' on purpose?
+            candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
+
+        candidate_qc.activation_quantization_cfg = create_node_activation_qc(core_config.quantization_config,
+                                                                             fw_info,
+                                                                             add_node_qco[op_qc_idx])
 
-
-
-        c.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
-                                                                         SIGNED: False})
+        candidate_qc.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
+                                                                                    SIGNED: False})
 
-
+        candidate_qc.activation_quantization_cfg.activation_n_bits = original_non_linear_activation_nbits
 
     # Add the new padding node to a fused op with the op2d.
     if pad_node:
@@ -471,14 +486,12 @@ def shift_negative_function(graph: Graph,
                                      pad_node=pad_node,
                                      op2d_node=op2d_node)
 
-    if
-        activation_param =
-
-
-                                                   out_stats_container=graph.get_out_stats_collector(
-                                                       non_linear_node))
+    if non_linear_node_cfg_candidate.shift_negative_threshold_recalculation:
+        activation_param = get_activations_qparams(activation_quant_cfg=non_linear_node_cfg_candidate,
+                                                   nodes_prior_info=non_linear_node.prior_info,
+                                                   out_stats_container=graph.get_out_stats_collector(non_linear_node))
 
-        assert activation_param.get(SIGNED)
+        assert activation_param.get(SIGNED) == False
         for candidate_qc in non_linear_node.candidates_quantization_cfg:
             candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_param)
 
@@ -560,6 +573,7 @@ def get_next_nodes_to_correct(n: BaseNode,
 
 
 def apply_shift_negative_correction(graph: Graph,
                                     core_config: CoreConfig,
+                                    fw_info: FrameworkInfo,
                                     snc_node_types: NodeOperationMatcher,
                                     linear_node_types: NodeOperationMatcher,
                                     bypass_node_types: NodeOperationMatcher,
@@ -571,7 +585,6 @@ def apply_shift_negative_correction(graph: Graph,
                                     padding_str: str,
                                     bias_str: str,
                                     bias_flag_str: str,
-                                    get_activation_quantization_fn_factory: Callable,
                                     params_search_quantization_fn: Callable=None) -> Graph:
     """
     Apply the substitution even if the linear node is not immediately after
@@ -580,6 +593,7 @@ def apply_shift_negative_correction(graph: Graph,
     Args:
         graph: Graph to apply the substitution on.
         core_config: Quantization configuration to build the substitutions list according to.
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
             groups of layers by how they should be quantized, etc.)
         snc_node_types: Types of activation nodes with negative outputs to consider.
         linear_node_types: Types of linear nodes to consider.
@@ -593,9 +607,6 @@ def apply_shift_negative_correction(graph: Graph,
         padding_str: The framework specific attribute name of the padding.
         bias_str: The framework specific attribute name of the bias.
         bias_flag_str: The framework specific attribute name of the bias flag.
-        get_activation_quantization_fn_factory: activation quantization functions factory.
-        params_search_quantization_fn: Function to quantize np tensor using a framework (tf/torch) quantization method. Needed for better param_search estimating the expected loss.
-
     Returns:
         Graph after applying shift negative on selected activations.
     """
@@ -603,8 +614,9 @@ def apply_shift_negative_correction(graph: Graph,
     nodes = list(graph.nodes())
     for n in nodes:
         # Skip substitution if QuantizationMethod is uniform.
-
-
+        node_qco = n.get_qco(graph.fqc)
+        if any([op_qc.activation_quantization_method is QuantizationMethod.UNIFORM
+                for op_qc in node_qco.quantization_configurations]):
             continue
 
         if snc_node_types.apply(n):
@@ -620,13 +632,13 @@ def apply_shift_negative_correction(graph: Graph,
                                            core_config,
                                            n,
                                            linear_node,
+                                           fw_info,
                                            create_add_node,
                                            get_padding_values,
                                            create_pad_node,
                                            padding_str,
                                            bias_str,
                                            bias_flag_str,
-                                           get_activation_quantization_fn_factory,
                                            zero_padding_node=pad_node,
                                            bypass_nodes=bypass_nodes,
                                            params_search_quantization_fn=params_search_quantization_fn)
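Note: the uniform-quantization skip added at the top of the node loop reads in isolation; a condensed sketch using the same attribute names as the hunk above (the helper name is ours, not MCT's):

    from model_compression_toolkit.target_platform_capabilities import QuantizationMethod

    def _uses_uniform_activation_quantization(node, fqc) -> bool:
        # A node is excluded from shift-negative correction if any of its candidate
        # configurations quantizes the activation with the UNIFORM method.
        node_qco = node.get_qco(fqc)
        return any(op_qc.activation_quantization_method is QuantizationMethod.UNIFORM
                   for op_qc in node_qco.quantization_configurations)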
model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py
CHANGED
@@ -50,7 +50,9 @@ class BaseVirtualActivationWeightsComposition(BaseSubstitution):
             return graph
 
         # Virtual composed activation-weights node
-        v_node = VirtualActivationWeightsNode(act_node,
+        v_node = VirtualActivationWeightsNode(act_node,
+                                              weights_node,
+                                              fw_info=graph.fw_info)
 
         # Update graph
         graph.add_node(v_node)
model_compression_toolkit/core/common/substitutions/weights_activation_split.py
CHANGED
@@ -50,7 +50,7 @@ class BaseWeightsActivationSplit(BaseSubstitution):
             Graph after applying the substitution.
         """
         # The decomposition works on linear nodes, that is, nodes with kernel ops
-        kernel_attr = node.
+        kernel_attr = graph.fw_info.get_kernel_op_attributes(node.type)[0]
         if kernel_attr is None:
             Logger.critical(f"Trying to split node weights and activation, but node "
                             f"{node.name} doesn't have a kernel attribute.")
model_compression_toolkit/core/common/visualization/nn_visualizer.py
CHANGED
@@ -59,19 +59,22 @@ class NNVisualizer:
     def __init__(self,
                  graph_float: Graph,
                  graph_quantized: Graph,
-                 fw_impl: FrameworkImplementation
+                 fw_impl: FrameworkImplementation,
+                 fw_info: FrameworkInfo):
         """
         Initialize a NNVisualizer object.
         Args:
             graph_float: Float version of the graph.
             graph_quantized: Quantized version of the graph.
             fw_impl: Framework implementation with framework-specific methods implementation.
+            fw_info: Framework info with framework-specific information.
 
         """
 
         self.graph_float = graph_float
         self.graph_quantized = graph_quantized
         self.fw_impl = fw_impl
+        self.fw_info = fw_info
 
         # Get compare points of two graphs.
         self.compare_points, self.compare_points_name = _get_compare_points(self.graph_quantized)
@@ -89,11 +92,13 @@ class NNVisualizer:
 
         self.quantized_model, _ = self.fw_impl.model_builder(self.graph_quantized,
                                                              mode=ModelBuilderMode.QUANTIZED,
-                                                             append2output=self.compare_points
+                                                             append2output=self.compare_points,
+                                                             fw_info=self.fw_info)
 
         self.float_model, _ = self.fw_impl.model_builder(self.graph_float,
                                                          mode=ModelBuilderMode.FLOAT,
-                                                         append2output=self.compare_points_float
+                                                         append2output=self.compare_points_float,
+                                                         fw_info=self.fw_info)
 
     def has_compare_points(self) -> bool:
         """
model_compression_toolkit/core/common/visualization/tensorboard_writer.py
CHANGED
@@ -89,18 +89,20 @@ class TensorboardWriter(object):
     Class to log events to display using Tensorboard such as graphs, histograms, images, etc.
     """
 
-    def __init__(self, dir_path: str):
+    def __init__(self, dir_path: str, fw_info: FrameworkInfo):
         """
         Initialize a TensorboardWriter object.
 
         Args:
             dir_path: Path to save all events to display on Tensorboard.
+            fw_info: FrameworkInfo object (needed for computing nodes' weights memory).
 
         """
         self.dir_path = dir_path
         # we hold EventWriter per tag name, so events can be gathered by tags (like phases during the quantization
         # process).
         self.tag_name_to_event_writer = {}
+        self.fw_info = fw_info
 
     def close(self):
         """
@@ -207,7 +209,7 @@ class TensorboardWriter(object):
             attr = dict()
             if n.final_activation_quantization_cfg is not None:
                 attr.update(n.final_activation_quantization_cfg.__dict__)
-            elif n.
+            elif n.candidates_quantization_cfg is not None:
                 attr.update(n.get_unified_activation_candidates_dict())
             return attr
 
@@ -229,8 +231,8 @@ class TensorboardWriter(object):
             attr = dict()
             if n.final_weights_quantization_cfg is not None:
                 attr.update(n.final_weights_quantization_cfg.__dict__)
-            elif n.
-                attr.update(n.get_unified_weights_candidates_dict())
+            elif n.candidates_quantization_cfg is not None:
+                attr.update(n.get_unified_weights_candidates_dict(self.fw_info))
             return attr
 
         def __get_node_attr(n: BaseNode) -> Dict[str, Any]:
@@ -294,7 +296,7 @@ class TensorboardWriter(object):
 
             return NodeExecStats(node_name=n.name,
                                  memory=[AllocatorMemoryUsed(
-                                     total_bytes=int(n.get_memory_bytes())
+                                     total_bytes=int(n.get_memory_bytes(self.fw_info))
                                  )])
 
         graph_def = GraphDef()  # GraphDef to add to Tensorboard
@@ -524,12 +526,14 @@ class TensorboardWriter(object):
         er.add_event(event)
         er.flush()
 
-
-def init_tensorboard_writer() -> TensorboardWriter:
+def init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
     """
     Create a TensorBoardWriter object initialized with the logger dir path if it was set,
     or None otherwise.
 
+    Args:
+        fw_info: FrameworkInfo object.
+
     Returns:
         A TensorBoardWriter object.
     """
@@ -537,7 +541,7 @@ def init_tensorboard_writer() -> TensorboardWriter:
     if Logger.LOG_PATH is not None:
         tb_log_dir = os.path.join(os.getcwd(), Logger.LOG_PATH, 'tensorboard_logs')
         Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
-        tb_w = TensorboardWriter(tb_log_dir)
+        tb_w = TensorboardWriter(tb_log_dir, fw_info)
     return tb_w
 
 
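Note: with the restored signature, callers must pass a FrameworkInfo when creating the writer. A hedged usage sketch on the Keras side (both import paths appear elsewhere in this diff, but the snippet itself is illustrative):

    from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO

    # Returns a TensorboardWriter if Logger.LOG_PATH was set, otherwise None.
    tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)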
model_compression_toolkit/core/graph_prep_runner.py
CHANGED
@@ -16,27 +16,28 @@
 
 from typing import Callable, Any
 
+from model_compression_toolkit.core.common import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates
-from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
-    QuantizationErrorMethod
+from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
-from model_compression_toolkit.core.common.quantization.set_node_quantization_config import
+from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \
+    set_quantization_configuration_to_graph
 from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
 from model_compression_toolkit.core.common.substitutions.linear_collapsing_substitution import \
     linear_collapsing_substitute
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
-from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
     FrameworkQuantizationCapabilities
-from model_compression_toolkit.logger import Logger
 
 
 def graph_preparation_runner(in_model: Any,
                              representative_data_gen: Callable,
                              quantization_config: QuantizationConfig,
+                             fw_info: FrameworkInfo,
                              fw_impl: FrameworkImplementation,
                              fqc: FrameworkQuantizationCapabilities,
                              bit_width_config: BitWidthConfig = None,
@@ -55,6 +56,8 @@ def graph_preparation_runner(in_model: Any,
         in_model (Any): Model to quantize.
         representative_data_gen (Callable): Dataset used for calibration.
         quantization_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized.
+        fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices,
+            groups of layers by how they should be quantized, etc.).
         fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
         fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities object that models the inference target platform and
             the attached framework operator's information.
@@ -70,6 +73,7 @@ def graph_preparation_runner(in_model: Any,
     graph = read_model_to_graph(in_model,
                                 representative_data_gen,
                                 fqc,
+                                fw_info,
                                 fw_impl)
 
     if tb_w is not None:
@@ -79,6 +83,7 @@ def graph_preparation_runner(in_model: Any,
                                 fqc,
                                 quantization_config,
                                 bit_width_config,
+                                fw_info,
                                 tb_w,
                                 fw_impl,
                                 mixed_precision_enable=mixed_precision_enable,
@@ -91,6 +96,7 @@ def get_finalized_graph(initial_graph: Graph,
                         fqc: FrameworkQuantizationCapabilities,
                         quant_config: QuantizationConfig = DEFAULTCONFIG,
                         bit_width_config: BitWidthConfig = None,
+                        fw_info: FrameworkInfo = None,
                         tb_w: TensorboardWriter = None,
                         fw_impl: FrameworkImplementation = None,
                         mixed_precision_enable: bool = False,
@@ -105,6 +111,8 @@ def get_finalized_graph(initial_graph: Graph,
         quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
             quantized.
         bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
+        fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g.,
+            kernel channels indices, groups of layers by how they should be quantized, etc.)
         tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
         fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
         mixed_precision_enable: is mixed precision enabled.
@@ -112,17 +120,11 @@ def get_finalized_graph(initial_graph: Graph,
 
     Returns: Graph object that represents the model, after applying all required modifications to it.
     """
-    if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
-        if not running_gptq:
-            raise ValueError(f"The HMSE error method for parameters selection is only supported when running GPTQ "
-                             f"optimization due to long execution time that is not suitable for basic PTQ.")
-        Logger.warning("Using the HMSE error method for weights quantization parameters search. "
-                       "Note: This method may significantly increase runtime during the parameter search process.")
 
     ######################################
     # Graph substitution (prepare graph)
     ######################################
-    graph = substitute(initial_graph, fw_impl.get_substitutions_prepare_graph())
+    graph = substitute(initial_graph, fw_impl.get_substitutions_prepare_graph(fw_info))
 
     if tb_w is not None:
         tb_w.add_graph(graph, 'after_graph_preparation')
@@ -132,6 +134,7 @@ def get_finalized_graph(initial_graph: Graph,
     ##########################################
     for node in graph.nodes:
         node.prior_info = fw_impl.get_node_prior_info(node=node,
+                                                      fw_info=fw_info,
                                                       graph=graph)
 
     ##################################################
@@ -147,22 +150,28 @@ def get_finalized_graph(initial_graph: Graph,
     if tb_w is not None:
         tb_w.add_graph(transformed_graph, 'pre_statistics_collection_substitutions')
 
-
-
-
-
-
+    ######################################
+    # Add quantization configurations
+    ######################################
+    transformed_graph = set_quantization_configuration_to_graph(graph=transformed_graph,
+                                                                quant_config=quant_config,
+                                                                bit_width_config=bit_width_config,
+                                                                mixed_precision_enable=mixed_precision_enable,
+                                                                running_gptq=running_gptq)
 
-
-
-
-
+    ######################################
+    # Layer fusing
+    ######################################
+    fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(transformed_graph)
+    transformed_graph.fusing_info = fusing_info
+    transformed_graph.disable_fused_nodes_activation_quantization()
 
     ######################################
     # Channel equalization
     ######################################
     transformed_graph = substitute(transformed_graph,
-                                   fw_impl.get_substitutions_channel_equalization(quant_config
+                                   fw_impl.get_substitutions_channel_equalization(quant_config,
+                                                                                  fw_info))
 
     if tb_w is not None:
         tb_w.add_graph(transformed_graph, 'after_graph_marking')
@@ -181,6 +190,7 @@ def get_finalized_graph(initial_graph: Graph,
 def read_model_to_graph(in_model: Any,
                         representative_data_gen: Callable,
                         fqc: FrameworkQuantizationCapabilities,
+                        fw_info: FrameworkInfo = None,
                         fw_impl: FrameworkImplementation = None) -> Graph:
 
     """
@@ -191,6 +201,8 @@ def read_model_to_graph(in_model: Any,
         representative_data_gen: Dataset used for calibration.
         fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
             the attached framework operator's information.
+        fw_info: Information needed for quantization about the specific framework (e.g.,
+            kernel channels indices, groups of layers by how they should be quantized, etc.)
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -198,5 +210,6 @@ def read_model_to_graph(in_model: Any,
     """
     graph = fw_impl.model_reader(in_model,
                                  representative_data_gen)
+    graph.set_fw_info(fw_info)
     graph.set_fqc(fqc)
     return graph
model_compression_toolkit/core/keras/back2framework/float_model_builder.py
CHANGED
@@ -17,6 +17,7 @@ from typing import List
 from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.core import common
 from tensorflow.python.util.object_identity import Reference as TFReference
 
@@ -28,17 +29,20 @@ class FloatKerasModelBuilder(KerasModelBuilder):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
+                 fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
                  return_float_outputs: bool = False):
         """
 
         Args:
             graph: Graph to build the model from.
             append2output: Nodes to append to model's output.
+            fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
         """
 
         super().__init__(graph,
                          append2output,
+                         fw_info,
                          return_float_outputs)
 
     def _quantize_node_activations(self,