mct-nightly 2.4.0.20250630.629__py3-none-any.whl → 2.4.0.20250701.185106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/METADATA +16 -16
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/RECORD +75 -72
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -1
- model_compression_toolkit/core/common/framework_info.py +5 -32
- model_compression_toolkit/core/common/fusion/graph_fuser.py +12 -9
- model_compression_toolkit/core/common/graph/base_graph.py +20 -37
- model_compression_toolkit/core/common/graph/base_node.py +13 -106
- model_compression_toolkit/core/common/graph/functional_node.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +12 -10
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +14 -9
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +9 -15
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +2 -3
- model_compression_toolkit/core/common/network_editors/__init__.py +8 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -96
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +116 -56
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +55 -179
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +21 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +8 -5
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -70
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +10 -12
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +54 -30
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +93 -398
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +2 -5
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -6
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +12 -6
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -2
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +33 -33
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +2 -4
- model_compression_toolkit/core/graph_prep_runner.py +31 -20
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +5 -2
- model_compression_toolkit/core/keras/default_framework_info.py +0 -11
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +9 -6
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +3 -1
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +2 -1
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +47 -0
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +5 -2
- model_compression_toolkit/core/pytorch/default_framework_info.py +0 -12
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +2 -0
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +2 -1
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +45 -0
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/runner.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +7 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +12 -3
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -9
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -5
- model_compression_toolkit/ptq/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
- model_compression_toolkit/quantization_preparation/__init__.py +14 -0
- model_compression_toolkit/quantization_preparation/load_fqc.py +223 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -78
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250630.629.dist-info → mct_nightly-2.4.0.20250701.185106.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantizer → quantization}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantizer → quantization}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py
CHANGED

@@ -18,7 +18,7 @@ from typing import Any, Callable
 from model_compression_toolkit.core import QuantizationConfig
 from model_compression_toolkit.core.common import BaseNode, Graph
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
-from model_compression_toolkit.core.common.
+from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn


 def get_previous_node_with_activation_quantization(linear_node: BaseNode,
@@ -67,7 +67,8 @@ def compute_activation_bias_correction(graph: Graph,
                                        fw_impl: FrameworkImplementation,
                                        linear_node: BaseNode,
                                        prev_node: BaseNode,
-                                       kernel_size: str
+                                       kernel_size: str,
+                                       get_activation_quantization_fn_factory: Callable) -> Graph:
     """
     Compute the activation bias correction term, and store it in the final activation
     quantization configuration.
@@ -79,6 +80,7 @@ def compute_activation_bias_correction(graph: Graph,
         linear_node: Node to compute the activation bias correction for.
         prev_node: Node to compute the activation error caused by his activation quantization.
         kernel_size: The framework specific attribute name of the convolution layer's kernel size.
+        get_activation_quantization_fn_factory: activation quantization functions factory.

     Returns:
         Graph with activation bias correction term for each node.
@@ -105,7 +107,9 @@ def compute_activation_bias_correction(graph: Graph,
     float_centers = calculate_bin_centers(float_bins)

     # Quantize the bin edges and calculate the centers of the quantized bins
-
+    activation_quantizer = get_activation_quantization_fn(prev_node_act_quant_cfg,
+                                                          get_activation_quantization_fn_factory)
+    quant_bins = activation_quantizer(fw_impl.to_tensor(float_bins))
     quant_bins = fw_impl.to_numpy(quant_bins)
     quant_centers = calculate_bin_centers(quant_bins)

@@ -149,7 +153,8 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
                                                 quant_config: QuantizationConfig,
                                                 fw_impl: FrameworkImplementation,
                                                 activation_bias_correction_node_matchers: Callable,
-                                                kernel_size: str
+                                                kernel_size: str,
+                                                get_activation_quantization_fn_factory: Callable) -> Graph:
     """
     Compute the activation bias correction term for the graph.

@@ -159,7 +164,7 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         activation_bias_correction_node_matchers: Function to match the layers for activation bias correction.
         kernel_size: The framework specific attribute name of the convolution layer's kernel size.
-
+        get_activation_quantization_fn_factory: activation quantization functions factory.

     Returns:
         Graph with activation bias correction term for each relevant node.
@@ -175,5 +180,6 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
                                                   fw_impl=fw_impl,
                                                   linear_node=n,
                                                   prev_node=prev_node,
-                                                  kernel_size=kernel_size
+                                                  kernel_size=kernel_size,
+                                                  get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
     return graph
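The hunks above replace the old framework-info quantizer lookup with an explicit `get_activation_quantization_fn_factory` argument that is threaded from the framework layer down into the common bias-correction code. A hedged sketch of the indirection, using simplified stand-in types rather than MCT's real config classes:

```python
from dataclasses import dataclass, field
from typing import Callable, Dict
import numpy as np

@dataclass
class ActQuantCfg:  # hypothetical stand-in for a node's activation quantization config
    activation_quantization_method: str
    activation_n_bits: int
    activation_quantization_params: Dict = field(default_factory=dict)

def symmetric_quantization(n_bits: int, params: Dict) -> Callable[[np.ndarray], np.ndarray]:
    # Builder: binds bit-width and params, returns a fake-quant function.
    t = params["threshold"]
    scale = t / (2 ** (n_bits - 1))
    return lambda x: np.clip(np.round(x / scale), -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1) * scale

def example_factory(method: str) -> Callable:
    # Plays the role of get_activation_quantization_fn_factory: method -> builder.
    return {"symmetric": symmetric_quantization}[method]

def get_quantizer(cfg: ActQuantCfg, factory: Callable) -> Callable:
    # Mirrors get_activation_quantization_fn: resolve the builder for the
    # configured method, then bind the candidate's bit-width and searched params.
    builder = factory(cfg.activation_quantization_method)
    return builder(cfg.activation_n_bits, cfg.activation_quantization_params)

cfg = ActQuantCfg("symmetric", 8, {"threshold": 4.0})
quant_bins = get_quantizer(cfg, example_factory)(np.linspace(-5.0, 5.0, 11))
```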
model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py
CHANGED

@@ -43,7 +43,7 @@ def compute_bias_correction_of_graph(graph: Graph,
     for n in graph.nodes:
         # Bias correction is computed based on the quantized kernel, so we need to get the specific kernel attribute
         # name out of all the weights attributes of the node.
-        if n.
+        if n.kernel_attr:
             if n.is_weights_quantization_enabled(n.kernel_attr):
                 # Bias correction is not applied to layers with constant inputs.
                 if n.has_positional_weights:
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py
CHANGED

@@ -124,7 +124,7 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):

         bn_node.prior_info = copy.deepcopy(source_node.prior_info)

-        bn_node.
+        bn_node.quantization_cfg = copy.deepcopy(source_node.quantization_cfg)

         for qc in bn_node.candidates_quantization_cfg:
             qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
@@ -139,7 +139,6 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
             # reconstructed node BN attributes need to be quantized and how.
             qc.weights_quantization_cfg.set_attr_config(attr,
                                                         WeightsAttrQuantizationConfig(
-                                                            QuantizationConfig(),
                                                             AttributeQuantizationConfig(
                                                                 enable_weights_quantization=False)))

model_compression_toolkit/core/common/substitutions/shift_negative_activation.py
CHANGED

@@ -16,21 +16,20 @@ import copy
 import numpy as np
 from typing import List, Tuple, Any, Callable

-from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
     ActivationQuantizationMode
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.common import
+from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
-from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
-    set_quantization_configs_to_node
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
-    import
+    import compute_activation_qparams
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
     _mse_error_histogram
 from model_compression_toolkit.core.common.quantization.quantization_params_generation import z_score_filter
+from model_compression_toolkit.quantization_preparation.load_fqc import set_quantization_configs_to_node, \
+    fetch_qc_options_for_node
 from model_compression_toolkit.target_platform_capabilities import QuantizationMethod, AttributeQuantizationConfig

 """
@@ -67,8 +66,7 @@ def op2d_bias_correction(op2d_node: BaseNode,
     # Add an attribute quantization configuration to the newly added bias attribute, with disabled quantization
     for qc in op2d_node.candidates_quantization_cfg:
         qc.weights_quantization_cfg.set_attr_config(bias_flag_str,
-                                                    WeightsAttrQuantizationConfig(
-                                                        AttributeQuantizationConfig(
+                                                    WeightsAttrQuantizationConfig(AttributeQuantizationConfig(
                                                         enable_weights_quantization=False)))

     # Each node adds a different noise due to the shifting. It depends on the
@@ -253,6 +251,7 @@ def shift_negative_function(graph: Graph,
                             padding_str: str,
                             bias_str: str,
                             bias_flag_str: str,
+                            get_activation_quantization_fn_factory: Callable,
                             zero_padding_node: BaseNode = None,
                             bypass_nodes: List = None,
                             params_search_quantization_fn: Callable = None
@@ -278,6 +277,7 @@ def shift_negative_function(graph: Graph,
         padding_str: The framework specific attribute name of the padding.
         bias_str: The framework specific attribute name of the bias.
         bias_flag_str: The framework specific attribute name of the bias flag.
+        get_activation_quantization_fn_factory: activation quantization functions factory.
         zero_padding_node: ZeroPadding2D node that may be in the graph before the linear layer.
         params_search_quantization_fn: Function to quantize np tensor using a framework (tf/torch) quantization method. Needed for better param_search estimating the expected loss.

@@ -327,13 +327,15 @@ def shift_negative_function(graph: Graph,
                            'float32')  # Change to type float32 to support tensorflow dtypes
     for _shift_value in _q_points:
         _hist_bins = hist_bins.astype(np.float32) + _shift_value
-
+        quantizer_factory = get_activation_quantization_fn_factory(
+            non_linear_node_cfg_candidate.activation_quantization_method)
+        fw_quant_fn = quantizer_factory(non_linear_node_cfg_candidate.activation_n_bits, qparams)
         """
         In SNC, when better shifting values are tested for better choice,
         the histogram (which is a numpy object) is quantized using the non-linear node activation
         quantization function (to estimate the expected mse comparing to the original histogram).
         The quantization function is a framework function, which makes it fail since it
-        expects a fw tensor. The
+        expects a fw tensor. The common part of SNC receives an argument which is a callable
         that receives two argument and returns one: it gets the fw activation quantization function
         and the bins to quantize. The function (of each fw) responsible for doing (if needed) a preprocessing and postprocessing
         to the bins which is a numpy object.
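The comment rewritten in this hunk describes a wrapper contract: the common SNC code receives a callable that takes the framework quantization function plus numpy bins and handles the tensor conversion. A hedged sketch of that contract, with `to_tensor`/`to_numpy` standing in for the fw_impl conversion helpers:

```python
import numpy as np

def quantize_hist_bins(fw_quant_fn, bins: np.ndarray, to_tensor, to_numpy) -> np.ndarray:
    # Per-framework pre/post-processing: numpy bins -> fw tensor -> fake-quant
    # -> back to numpy, so the common SNC code stays framework-agnostic.
    return to_numpy(fw_quant_fn(to_tensor(bins)))

def shift_mse(fw_quant_fn, hist_bins, counts, shift, to_tensor, to_numpy) -> float:
    # Estimate the expected quantization MSE for one candidate shift value
    # (a simplified stand-in for _mse_error_histogram); counts has len(bins) - 1.
    shifted = hist_bins.astype(np.float32) + shift
    q = quantize_hist_bins(fw_quant_fn, shifted, to_tensor, to_numpy)
    centers = (shifted[:-1] + shifted[1:]) / 2
    q_centers = (q[:-1] + q[1:]) / 2
    return float(np.average((centers - q_centers) ** 2, weights=counts))
```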
@@ -395,9 +397,7 @@ def shift_negative_function(graph: Graph,

     set_quantization_configs_to_node(node=add_node,
                                      graph=graph,
-
-                                     fqc=graph.fqc,
-                                     mixed_precision_enable=core_config.is_mixed_precision_enabled)
+                                     fqc=graph.fqc)

     update_fused_op_with_add(graph=graph,
                              non_linear_node=non_linear_node,
@@ -421,9 +421,7 @@ def shift_negative_function(graph: Graph,
     # Set quantization configuration to node, even though we do not quantize it:
     set_quantization_configs_to_node(node=pad_node,
                                      graph=graph,
-
-                                     fqc=graph.fqc,
-                                     mixed_precision_enable=core_config.is_mixed_precision_enabled)
+                                     fqc=graph.fqc)

     for candidate_qc in pad_node.candidates_quantization_cfg:
         candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
@@ -448,7 +446,7 @@ def shift_negative_function(graph: Graph,
     bypass_candidate_qc.activation_quantization_cfg.activation_quantization_params[SIGNED] = False
     graph.shift_stats_collector(bypass_node, np.array(shift_value))

-    add_node_qco = add_node
+    add_node_qco = fetch_qc_options_for_node(add_node, graph.fqc).quantization_configurations
     add_supported_bitwidths = [c.activation_n_bits for c in add_node_qco]
     if original_non_linear_activation_nbits not in add_supported_bitwidths:
         raise ValueError(
@@ -456,18 +454,16 @@ def shift_negative_function(graph: Graph,
             f"bitwidth is {original_non_linear_activation_nbits}. Consider adapting the TPC so 'Add' will support the "
             f"same bitwidth as {non_linear_node.type} or disable shift negative correction.")

-
-
-
-    candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
+    set_quantization_configs_to_node(add_node, graph, graph.fqc)
+    # TODO: do we not quantize the weights of this 'add' on purpose?
+    add_node.quantization_cfg.disable_weights_quantization()

-
-
+    def update(c):
+        c.activation_quantization_cfg.activation_n_bits = original_non_linear_activation_nbits
+        c.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
+                                                                         SIGNED: False})

-
-        SIGNED: False})
-
-    candidate_qc.activation_quantization_cfg.activation_n_bits = original_non_linear_activation_nbits
+    add_node.quantization_cfg.update_all(update, remove_duplicates=True)

     # Add the new padding node to a fused op with the op2d.
     if pad_node:
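A hedged sketch of the `update_all(update, remove_duplicates=True)` pattern introduced above, assuming a simplified node config that holds a base candidate plus mixed-precision candidates (the real `quantization_cfg` API may differ):

```python
from dataclasses import dataclass, field
from typing import Callable, List

@dataclass
class Candidate:  # hypothetical stand-in for a candidate quantization config
    n_bits: int
    params: dict = field(default_factory=dict)

@dataclass
class NodeQCfg:  # hypothetical stand-in for node.quantization_cfg
    base: Candidate
    candidates: List[Candidate] = field(default_factory=list)

    def update_all(self, fn: Callable[[Candidate], None], remove_duplicates: bool = False):
        # Apply the mutation to the base config and to every candidate, then
        # optionally drop candidates that became identical after the update.
        fn(self.base)
        for c in self.candidates:
            fn(c)
        if remove_duplicates:
            seen, unique = set(), []
            for c in self.candidates:
                key = (c.n_bits, tuple(sorted(c.params.items())))
                if key not in seen:
                    seen.add(key)
                    unique.append(c)
            self.candidates = unique
```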
@@ -476,11 +472,11 @@ def shift_negative_function(graph: Graph,
                             op2d_node=op2d_node)

     if non_linear_node_cfg_candidate.shift_negative_threshold_recalculation:
-        activation_param =
-
-
+        activation_param = compute_activation_qparams(activation_quant_cfg=non_linear_node_cfg_candidate,
+                                                      node_prior_info=non_linear_node.prior_info,
+                                                      out_stats_container=graph.get_out_stats_collector(non_linear_node))

-        assert activation_param.get(SIGNED)
+        assert activation_param.get(SIGNED) is False
         for candidate_qc in non_linear_node.candidates_quantization_cfg:
             candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_param)

@@ -573,6 +569,7 @@ def apply_shift_negative_correction(graph: Graph,
                                     padding_str: str,
                                     bias_str: str,
                                     bias_flag_str: str,
+                                    get_activation_quantization_fn_factory: Callable,
                                     params_search_quantization_fn: Callable=None) -> Graph:
     """
     Apply the substitution even if the linear node is not immediately after
@@ -594,6 +591,9 @@ def apply_shift_negative_correction(graph: Graph,
         padding_str: The framework specific attribute name of the padding.
         bias_str: The framework specific attribute name of the bias.
         bias_flag_str: The framework specific attribute name of the bias flag.
+        get_activation_quantization_fn_factory: activation quantization functions factory.
+        params_search_quantization_fn: Function to quantize np tensor using a framework (tf/torch) quantization method. Needed for better param_search estimating the expected loss.
+
     Returns:
         Graph after applying shift negative on selected activations.
     """
@@ -601,9 +601,8 @@ def apply_shift_negative_correction(graph: Graph,
     nodes = list(graph.nodes())
     for n in nodes:
         # Skip substitution if QuantizationMethod is uniform.
-
-
-        for op_qc in node_qco.quantization_configurations]):
+        if any(aqc.activation_quantization_cfg.activation_quantization_method == QuantizationMethod.UNIFORM
+               for aqc in n.candidates_quantization_cfg):
             continue

         if snc_node_types.apply(n):
@@ -625,6 +624,7 @@ def apply_shift_negative_correction(graph: Graph,
                                  padding_str,
                                  bias_str,
                                  bias_flag_str,
+                                 get_activation_quantization_fn_factory,
                                  zero_padding_node=pad_node,
                                  bypass_nodes=bypass_nodes,
                                  params_search_quantization_fn=params_search_quantization_fn)
model_compression_toolkit/core/common/visualization/tensorboard_writer.py
CHANGED

@@ -207,7 +207,7 @@ class TensorboardWriter(object):
         attr = dict()
         if n.final_activation_quantization_cfg is not None:
             attr.update(n.final_activation_quantization_cfg.__dict__)
-        elif n.
+        elif n.quantization_cfg is not None:
             attr.update(n.get_unified_activation_candidates_dict())
         return attr

@@ -229,7 +229,7 @@ class TensorboardWriter(object):
         attr = dict()
         if n.final_weights_quantization_cfg is not None:
             attr.update(n.final_weights_quantization_cfg.__dict__)
-        elif n.
+        elif n.quantization_cfg is not None:
             attr.update(n.get_unified_weights_candidates_dict())
         return attr

@@ -530,8 +530,6 @@ def init_tensorboard_writer() -> TensorboardWriter:
     Create a TensorBoardWriter object initialized with the logger dir path if it was set,
     or None otherwise.

-    Args:
-
     Returns:
         A TensorBoardWriter object.
     """
model_compression_toolkit/core/graph_prep_runner.py
CHANGED

@@ -16,22 +16,22 @@

 from typing import Callable, Any

-from model_compression_toolkit.core.common import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
-from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
 from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates
-from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
+from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG, \
+    QuantizationErrorMethod
 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
-from model_compression_toolkit.core.common.quantization.set_node_quantization_config import
-    set_quantization_configuration_to_graph
+from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_manual_bitwidth_config
 from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
 from model_compression_toolkit.core.common.substitutions.linear_collapsing_substitution import \
     linear_collapsing_substitute
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
+from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
     FrameworkQuantizationCapabilities
+from model_compression_toolkit.logger import Logger


 def graph_preparation_runner(in_model: Any,
@@ -112,6 +112,12 @@ def get_finalized_graph(initial_graph: Graph,

     Returns: Graph object that represents the model, after applying all required modifications to it.
     """
+    if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
+        if not running_gptq:
+            raise ValueError(f"The HMSE error method for parameters selection is only supported when running GPTQ "
+                             f"optimization due to long execution time that is not suitable for basic PTQ.")
+        Logger.warning("Using the HMSE error method for weights quantization parameters search. "
+                       "Note: This method may significantly increase runtime during the parameter search process.")

     ######################################
     # Graph substitution (prepare graph)
@@ -141,21 +147,26 @@ def get_finalized_graph(initial_graph: Graph,
     if tb_w is not None:
         tb_w.add_graph(transformed_graph, 'pre_statistics_collection_substitutions')

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    transformed_graph = load_fqc_configuration(transformed_graph, fqc)
+
+    # filter candidates per manual config
+    if bit_width_config:
+        set_manual_bitwidth_config(graph, bit_width_config)
+
+    # TODO irena: load_fqc_configuration only loads config from tpc. Previously quant_config was read as well.
+    # As a first stage we keep the attributes in internal configs and fill them manually from quant_config
+    # not to break all the code at once. Eventually we need to handle quant_config directly, without injecting into candidates.
+    # TODO 2: Also we adjust candidates for single precision, which we shouldn't do here.
+    def update(qc):
+        qc.activation_quantization_cfg.set_qc(quant_config)
+        qc.weights_quantization_cfg.set_qc(quant_config)
+        for attr_cfg in qc.weights_quantization_cfg.get_all_weight_attrs_configs().values():
+            attr_cfg.weights_error_method = quant_config.weights_error_method
+            attr_cfg.l_p_value = quant_config.l_p_value
+    for n in transformed_graph.nodes:
+        if not mixed_precision_enable:
+            n.quantization_cfg.candidates_quantization_cfg = [n.quantization_cfg.base_quantization_cfg]
+        n.quantization_cfg.update_all(update)

     ######################################
     # Channel equalization
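A hedged sketch of the new candidate-preparation step above: collapse to the base candidate when mixed precision is off, then inject the user `QuantizationConfig` fields into whatever candidates remain. All types here are illustrative stand-ins for MCT's internal config classes:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class WeightsAttrCfg:  # stand-in for a per-attribute weights config
    weights_error_method: str = "MSE"
    l_p_value: int = 2

@dataclass
class Candidate:  # stand-in for CandidateNodeQuantizationConfig
    activation_n_bits: int
    weight_attrs: List[WeightsAttrCfg] = field(default_factory=list)

@dataclass
class NodeQCfg:  # stand-in for node.quantization_cfg
    base_quantization_cfg: Candidate
    candidates_quantization_cfg: List[Candidate] = field(default_factory=list)

def prepare_node_candidates(node_cfg: NodeQCfg, weights_error_method: str,
                            l_p_value: int, mixed_precision_enable: bool) -> None:
    if not mixed_precision_enable:
        # Single precision: keep only the base candidate.
        node_cfg.candidates_quantization_cfg = [node_cfg.base_quantization_cfg]
    # Inject the user config into the base config and every candidate.
    for c in [node_cfg.base_quantization_cfg] + node_cfg.candidates_quantization_cfg:
        for attr_cfg in c.weight_attrs:
            attr_cfg.weights_error_method = weights_error_method
            attr_cfg.l_p_value = l_p_value
```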
model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py
CHANGED

@@ -14,9 +14,10 @@
 # ==============================================================================
 from typing import List

-from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn
+from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
 from tensorflow.python.util.object_identity import Reference as TFReference

@@ -56,4 +57,6 @@ class QuantizedKerasModelBuilder(KerasModelBuilder):
             Output of the node.

         """
-
+        activation_quantizer = get_activation_quantization_fn(node.final_activation_quantization_cfg,
+                                                              get_activation_quantization_fn_factory)
+        return activation_quantizer(input_tensors)
model_compression_toolkit/core/keras/default_framework_info.py
CHANGED

@@ -18,7 +18,6 @@ import tensorflow as tf
 from typing import Tuple, Any, Dict
 from functools import wraps

-from model_compression_toolkit.core.keras.quantizer.lut_fake_quant import activation_lut_kmean_quantizer
 from packaging import version

 if version.parse(tf.__version__) >= version.parse("2.13"):
@@ -26,11 +25,9 @@ if version.parse(tf.__version__) >= version.parse("2.13"):
 else:
     from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU, Activation  # pragma: no cover
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo, set_fw_info, ChannelAxisMapping
-from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.constants import SOFTMAX_THRESHOLD, ACTIVATION
 from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
     KERNEL, DEPTHWISE_KERNEL, GELU
-from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization


 class KerasInfo(FrameworkInfo):
@@ -103,14 +100,6 @@ class KerasInfo(FrameworkInfo):
         tf.nn.softmax: (0, SOFTMAX_THRESHOLD),
     }

-    """
-    Mapping from a QuantizationMethod to an activation quantizer function.
-    """
-    activation_quantizer_mapping = {QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
-                                    QuantizationMethod.SYMMETRIC: symmetric_quantization,
-                                    QuantizationMethod.UNIFORM: uniform_quantization,
-                                    QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer}
-
     @classmethod
     def get_layer_min_max(cls, layer: Any, fw_attrs: Dict) -> Tuple[float, float]:
         """
model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py
CHANGED

@@ -18,13 +18,12 @@ from tensorflow.keras.layers import InputLayer, Dense, DepthwiseConv2D, Conv2D,
 from typing import List

 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.graph.base_graph import Graph
-from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher,
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, WalkMatcher
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
-from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.constants import THRESHOLD
-from model_compression_toolkit.core.
+from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
+    compute_weights_qparams
 from model_compression_toolkit.logger import Logger

 input_node = NodeOperationMatcher(InputLayer)
@@ -104,8 +103,12 @@ class BaseInputScaling(common.BaseSubstitution):

         # After scaling weights may have different thresholds so it needs to be recalculated
         for nqc in linear_layer.candidates_quantization_cfg:
-            nqc.weights_quantization_cfg.get_attr_config(linear_layer.kernel_attr)
-
+            attr_cfg = nqc.weights_quantization_cfg.get_attr_config(linear_layer.kernel_attr)
+            assert attr_cfg.enable_weights_quantization
+            w_params, _ = compute_weights_qparams(w1_fixed, attr_quant_config=attr_cfg,
+                                                  output_channels_axis=attr_cfg.weights_channels_axis.output,
+                                                  min_threshold=nqc.weights_quantization_cfg.min_threshold)
+            attr_cfg.set_weights_quantization_param(w_params)

         return graph
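Recomputing kernel thresholds after input scaling boils down to a per-channel max-abs search. A hedged sketch with a hypothetical power-of-two threshold rule (the real compute_weights_qparams supports several quantization methods and error metrics):

```python
import numpy as np

def compute_po2_weight_thresholds(kernel: np.ndarray, output_channels_axis: int,
                                  min_threshold: float = 2 ** -16) -> dict:
    # Max absolute value per output channel, with the threshold rounded up
    # to the nearest power of two (a common symmetric-quantization choice).
    reduce_axes = tuple(a for a in range(kernel.ndim) if a != output_channels_axis)
    max_abs = np.max(np.abs(kernel), axis=reduce_axes)
    thresholds = 2.0 ** np.ceil(np.log2(np.maximum(max_abs, min_threshold)))
    return {"threshold": thresholds}

# Example: a rescaled Conv2D kernel (H, W, Cin, Cout) with Cout as the output axis.
scaled_kernel = np.random.randn(3, 3, 8, 16).astype(np.float32) * 0.7
params = compute_po2_weight_thresholds(scaled_kernel, output_channels_axis=3)
```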
model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py
CHANGED

@@ -34,6 +34,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
     NodeFrameworkAttrMatcher
 from model_compression_toolkit.core.common.substitutions.shift_negative_activation import \
     apply_shift_negative_correction
+from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.core.keras.constants import KERNEL_SIZE, STRIDES, ACTIVATION, SWISH, \
     SELU, GELU, FUNCTION, ADD, PAD
 from model_compression_toolkit.core.keras.constants import NEGATIVE_SLOPE, PADDING, PAD_SAME, PAD_VALID, BIAS, USE_BIAS
@@ -252,5 +253,6 @@ def keras_apply_shift_negative_correction(graph: Graph,
                                           is_padding_node_and_node_has_padding,
                                           PADDING,
                                           BIAS,
-                                          USE_BIAS
+                                          USE_BIAS,
+                                          get_activation_quantization_fn_factory
                                           )
model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py
CHANGED

@@ -94,7 +94,7 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
         for i, ipt_node in enumerate(self.hessian_request.target_nodes):  # Per Interest point weights tensor

             # Check if the target node's layer type is supported.
-            if not ipt_node.
+            if not ipt_node.kernel_attr:
                 Logger.critical(f"Hessian information with respect to weights is not supported for "
                                 f"{ipt_node.type} layers.")  # pragma: no cover

model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py
CHANGED

@@ -23,6 +23,7 @@ from model_compression_toolkit.core.common.mixed_precision.configurable_quantize
     verify_candidates_descending_order, init_activation_quantizers
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
+from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.logger import Logger

 import tensorflow as tf
@@ -67,7 +68,7 @@ class ConfigurableActivationQuantizer(BaseKerasInferableQuantizer):
         if qc.activation_quantization_cfg.quant_mode != node_q_cfg[0].activation_quantization_cfg.quant_mode:
             Logger.critical("Unsupported configuration: Mixing candidates with differing activation quantization states (enabled/disabled).")  # pragma: no cover

-        self.activation_quantizers = init_activation_quantizers(self.node_q_cfg)
+        self.activation_quantizers = init_activation_quantizers(self.node_q_cfg, get_activation_quantization_fn_factory)
         self.active_quantization_config_index = max_candidate_idx  # initialize with first config as default

     def set_active_activation_quantizer(self, index: Optional[int]):
model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py
CHANGED

@@ -155,7 +155,7 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
         """

         attributes_with_axis = {}
-        if node.
+        if node.kernel_attr:
             attributes_with_axis[node.kernel_attr] = (node.channel_axis.output, node.channel_axis.input)

         # Bias is a vector at the length of the number of output channels.
model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py
ADDED

@@ -0,0 +1,47 @@
+# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from collections.abc import Callable
+
+from mct_quantizers import QuantizationMethod
+
+
+from model_compression_toolkit.core.keras.quantization.fake_quant_builder import power_of_two_quantization, \
+    symmetric_quantization, uniform_quantization
+from model_compression_toolkit.core.keras.quantization.lut_fake_quant import activation_lut_kmean_quantizer
+
+
+"""
+Mapping from a QuantizationMethod to an activation quantizer function.
+"""
+_activation_quantizer_factory_mapping = {
+    QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
+    QuantizationMethod.SYMMETRIC: symmetric_quantization,
+    QuantizationMethod.UNIFORM: uniform_quantization,
+    QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer
+}
+
+
+def get_activation_quantization_fn_factory(quantization_method: QuantizationMethod) -> Callable[[int, dict], Callable]:
+    """
+    Get factory for activation quantizer.
+
+    Args:
+        quantization_method: quantization method for activation.
+
+    Returns:
+        Factory that accepts activation bitwidth and a dict of quantization params, and returns the quantizer.
+    """
+    return _activation_quantizer_factory_mapping[quantization_method]
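A hedged usage sketch of this new Keras factory. The THRESHOLD/SIGNED param keys are borrowed from the SNC code earlier in this diff; exactly which params each builder expects is an assumption here:

```python
import numpy as np
import tensorflow as tf
from mct_quantizers import QuantizationMethod
from model_compression_toolkit.constants import THRESHOLD, SIGNED
from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import \
    get_activation_quantization_fn_factory

# Resolve a builder for the method, then bind bit-width and searched params.
builder = get_activation_quantization_fn_factory(QuantizationMethod.POWER_OF_TWO)
quantize = builder(8, {THRESHOLD: 8.0, SIGNED: True})  # assumed param keys

x = tf.constant(np.linspace(-10.0, 10.0, 9), dtype=tf.float32)
x_q = quantize(x)  # fake-quantized activations
```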
model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py
CHANGED

@@ -25,7 +25,7 @@ else:

 from model_compression_toolkit.core import QuantizationConfig
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
-from model_compression_toolkit.core.
+from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
 from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
@@ -60,5 +60,6 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph,
                                                       fw_impl=fw_impl,
                                                       activation_bias_correction_node_matchers=
                                                       activation_bias_correction_node_matchers,
-                                                      kernel_size=KERNEL_SIZE
+                                                      kernel_size=KERNEL_SIZE,
+                                                      get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
     return graph
model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py
CHANGED

@@ -17,9 +17,10 @@ from typing import List, Tuple

 import torch

-from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn
+from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
     PytorchModel
@@ -60,7 +61,9 @@ class QuantizedPyTorchModel(PytorchModel):
         if node.is_activation_quantization_enabled():
             if isinstance(input_tensors, list):
                 input_tensors = torch.cat(input_tensors, dim=0)
-
+            activation_quantizer = get_activation_quantization_fn(node.final_activation_quantization_cfg,
+                                                                  get_activation_quantization_fn_factory)
+            return activation_quantizer(input_tensors)
         return input_tensors

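Both model builders now share one pattern: resolve the node's final activation quantizer through the per-framework factory and apply it to the node output. A hedged PyTorch-flavored sketch (only the torch calls are real APIs; the rest are stand-ins):

```python
import torch

def fake_quant_po2(n_bits: int, threshold: float, signed: bool = True):
    # Stand-in for a power-of-two fake-quant builder on the PyTorch side.
    def quantize(x: torch.Tensor) -> torch.Tensor:
        scale = threshold / (2 ** (n_bits - 1) if signed else 2 ** n_bits)
        low = -2 ** (n_bits - 1) if signed else 0
        high = 2 ** (n_bits - 1) - 1 if signed else 2 ** n_bits - 1
        return torch.clamp(torch.round(x / scale), low, high) * scale
    return quantize

def quantize_node_output(output: torch.Tensor, quantize_enabled: bool, quantizer) -> torch.Tensor:
    # Mirrors QuantizedPyTorchModel: quantize only when the node's activation
    # quantization is enabled, otherwise pass the tensor through unchanged.
    return quantizer(output) if quantize_enabled else output

y = torch.randn(2, 8)
y_q = quantize_node_output(y, True, fake_quant_po2(8, threshold=4.0))
```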