mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py
CHANGED
@@ -24,6 +24,8 @@ from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
     PytorchModel
 
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+
 
 class FloatPyTorchModel(PytorchModel):
     """
@@ -32,16 +34,19 @@ class FloatPyTorchModel(PytorchModel):
 
     def __init__(self,
                  graph: common.Graph,
-                 append2output=None
+                 append2output=None,
+                 fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO):
         """
 
         Args:
            graph: Graph to build its corresponding Pytorch model.
            append2output: List of nodes or OutTensor objects.
+           fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).
        """
 
        super().__init__(graph,
-                        append2output
+                        append2output,
+                        fw_info)
 
     def _quantize_node_activations(self,
                                    node: BaseNode,
@@ -66,17 +71,20 @@ class FloatPyTorchModelBuilder(PyTorchModelBuilder):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
+                 fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                  return_float_outputs: bool = False):
         """
 
         Args:
            graph: Graph to build the model from.
            append2output: Nodes to append to model's output.
+           fw_info: Information about the specific framework of the model that is built.
            return_float_outputs: Whether the model returns float tensors or not.
        """
 
        super().__init__(graph,
                         append2output,
+                        fw_info,
                         return_float_outputs)
 
     def build_model(self) -> Tuple[PytorchModel, UserInformation]:
@@ -86,4 +94,5 @@ class FloatPyTorchModelBuilder(PyTorchModelBuilder):
 
         """
         return FloatPyTorchModel(self.graph,
-                                 self.append2output
+                                 self.append2output,
+                                 self.fw_info), self.graph.user_info
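
The recurring change here, and in the other back2framework builders below, is the reintroduction of an explicit fw_info argument defaulting to DEFAULT_PYTORCH_INFO. A minimal sketch for verifying the restored signature, assuming the right-hand package version (mct_nightly-2.4.2.20250926.532) is installed; the exact rendering of the printed default may differ:

import inspect
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder

# The restored parameter and its default are visible on the constructor signature,
# roughly: (self, graph, append2output=None, fw_info=<FrameworkInfo ...>, return_float_outputs=False)
print(inspect.signature(FloatPyTorchModelBuilder.__init__))
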
model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py
CHANGED
@@ -23,6 +23,7 @@ from model_compression_toolkit.core import FrameworkInfo, common
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.mixed_precision.configurable_activation_quantizer import \
     ConfigurableActivationQuantizer
 from model_compression_toolkit.core.pytorch.mixed_precision.configurable_weights_quantizer import \
@@ -37,12 +38,14 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
+                 fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                  return_float_outputs: bool = False):
         """
 
         Args:
            graph: Graph to build the model from.
            append2output: Nodes to append to model's output.
+           fw_info: Information about the specific framework of the model that is built.
            return_float_outputs: Whether the model returns float tensors or not.
        """
 
@@ -50,6 +53,7 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
 
        super().__init__(graph,
                         append2output,
+                        fw_info,
                         return_float_outputs,
                         wrapper=self.mixed_precision_wrapper,
                         get_activation_quantizer_holder_fn=self.mixed_precision_activation_holder)
@@ -73,16 +77,17 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
             ValueError: if kernel attribute is quantized but not configurable.
         """
 
-
+        kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
+        if kernel_attr is None or not n.is_weights_quantization_enabled(kernel_attr):
             return layer
-        if not n.is_configurable_weight(
+        if not n.is_configurable_weight(kernel_attr):  # pragma: no cover
             raise ValueError(f'Weight wrapper is not expected to be created for non-configurable weight of node {n}.')
         return PytorchQuantizationWrapper(layer,
                                           weights_quantizers={
-
+                                              kernel_attr: ConfigurableWeightsQuantizer(
                                                   **self._get_weights_configurable_quantizer_kwargs(n,
-
-                                                      kernel_attr=
+                                                                                                    kernel_attr),
+                                                  kernel_attr=kernel_attr)})
 
     def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) -> Dict[str, Any]:
         """
@@ -142,13 +147,14 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
         # activation number of bits (in reversed order).
         # since only kernel attribute is quantized in weights mixed precision,
         # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
-        n.sort_node_candidates()
+        n.sort_node_candidates(self.fw_info)
 
         max_candidate_idx = n.find_max_candidate_index()
 
+        kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
         activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
                                                                     'max_candidate_idx': max_candidate_idx,
-                                                                    'kernel_attr':
+                                                                    'kernel_attr': kernel_attr})] \
                                  * num_of_outputs
 
         # Holder by definition uses a single quantizer for the activation quantization
@@ -171,7 +177,7 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
         # creating a mapping between graph nodes and model's layers for mixed precision configurability
         model_layers = dict(model.named_children())
         conf_node2layers = {n.name: self._find_layers_in_model_by_node(n, model_layers)
-                            for n in self.graph.get_configurable_sorted_nodes()}
+                            for n in self.graph.get_configurable_sorted_nodes(self.fw_info)}
 
         return model, user_info, conf_node2layers
 
@@ -224,7 +230,8 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
 
         """
         # Only layers with kernel op are considered weights configurable
-
+        kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
+        weights_quant = False if kernel_attr is None else n.is_weights_quantization_enabled(kernel_attr)
         act_quant = n.is_activation_quantization_enabled()
 
         if weights_quant and not act_quant:
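
Several hunks above replace node-level lookups with the fw_info-based pattern kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]. A short sketch of what that lookup yields, assuming the right-hand package version is installed and that the KERNEL constant resolves to 'weight' (per the KERNEL_ATTRIBUTES mapping defined in default_framework_info.py later in this diff):

from torch.nn import Conv2d, ReLU
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO

# Conv2d/ConvTranspose2d/Linear map to their kernel attribute; any other layer
# falls back to DEFAULT_KERNEL_ATTRIBUTES, i.e. [None].
print(DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(Conv2d))  # ['weight']
print(DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(ReLU))    # [None] - no kernel to quantize
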
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
CHANGED
@@ -30,6 +30,7 @@ from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
@@ -363,7 +364,7 @@ class PytorchModel(torch.nn.Module):
         """
         node_to_output_tensors_dict = dict()
         node_to_output_tensors_dict_float = dict()
-        configurable_nodes = self.graph.get_configurable_sorted_nodes_names()
+        configurable_nodes = self.graph.get_configurable_sorted_nodes_names(DEFAULT_PYTORCH_INFO)
         for node in self.node_sort:
             op_func = self._get_op_func(node, configurable_nodes)
             input_tensors = _build_input_tensors_list(node,
@@ -439,6 +440,7 @@ class PyTorchModelBuilder(BaseModelBuilder):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
+                 fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                  return_float_outputs: bool = False,
                  wrapper: Callable = None,
                  get_activation_quantizer_holder_fn: Callable = None):
@@ -447,6 +449,7 @@ class PyTorchModelBuilder(BaseModelBuilder):
         Args:
             graph: Graph to build the model from.
             append2output: Nodes to append to model's output.
+            fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
             wrapper: A function wrapper Pytorch Layers.
             get_activation_quantizer_holder_fn: Function to retrieve a quantization holder for a node.
@@ -454,6 +457,7 @@ class PyTorchModelBuilder(BaseModelBuilder):
 
         super().__init__(graph,
                          append2output,
+                         fw_info,
                          return_float_outputs)
 
         self.wrapper = wrapper
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py
CHANGED
@@ -21,6 +21,7 @@ from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.pytorch.back2framework.quantization_wrapper.wrapper_quantize_config import \
     WrapperQuantizeConfig
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.utils import set_model, to_torch_tensor
 
 
@@ -92,7 +93,7 @@ class QuantizedLayerWrapper(torch.nn.Module):
         self.layer = n.type(**framework_attr)
         self.layer.load_state_dict({k: torch.Tensor(v) for k, v in n.weights.items()}, strict=False)
 
-    def _quantize_weights(self, n:
+    def _quantize_weights(self, n:BaseNode):
         """
         Quantize node's weights and load them as the layer's weights.
 
@@ -103,7 +104,7 @@ class QuantizedLayerWrapper(torch.nn.Module):
             None.
         """
 
-        self.weight_attrs =
+        self.weight_attrs = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(n.type)
 
         # float_weights is a list of weights for each attribute that we want to quantize.
         float_weights = [n.get_weights_by_keys(attr) for attr in self.weight_attrs]
model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py
CHANGED
@@ -17,13 +17,13 @@ from typing import List, Tuple
 
 import torch
 
+from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import BaseNode
-from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn
-from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
     PytorchModel
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 
 
 class QuantizedPyTorchModel(PytorchModel):
@@ -61,9 +61,7 @@ class QuantizedPyTorchModel(PytorchModel):
         if node.is_activation_quantization_enabled():
             if isinstance(input_tensors, list):
                 input_tensors = torch.cat(input_tensors, dim=0)
-
-                get_activation_quantization_fn_factory)
-            return activation_quantizer(input_tensors)
+            return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
         return input_tensors
 
 
@@ -72,17 +70,20 @@ class QuantizedPyTorchModelBuilder(PyTorchModelBuilder):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
+                 fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                  return_float_outputs: bool = False):
         """
 
         Args:
            graph: Graph to build the model from.
            append2output: Nodes to append to model's output.
+           fw_info: Information about the specific framework of the model that is built.
            return_float_outputs: Whether the model returns float tensors or not.
        """
 
        super().__init__(graph,
                         append2output,
+                        fw_info,
                         return_float_outputs)
 
     def build_model(self) -> Tuple[PytorchModel, UserInformation]:
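
The second hunk above swaps a quantizer-factory call for a direct call on the node's final activation config: node.final_activation_quantization_cfg.quantize_node_output(input_tensors). A stand-in sketch of that call shape; the class below is not the MCT config object, and its symmetric uniform fake-quant body is illustrative only:

import torch

class FinalActivationCfgStandIn:
    """Mimics the quantize_node_output() call shape on a final activation config."""
    def __init__(self, threshold: float = 1.0, n_bits: int = 8):
        self.threshold = threshold
        self.n_bits = n_bits

    def quantize_node_output(self, t: torch.Tensor) -> torch.Tensor:
        # Symmetric uniform fake-quantization, for illustration only.
        scale = self.threshold / (2 ** (self.n_bits - 1))
        return torch.clamp(torch.round(t / scale) * scale,
                           min=-self.threshold,
                           max=self.threshold - scale)

cfg = FinalActivationCfgStandIn()
print(cfg.quantize_node_output(torch.tensor([0.013, -0.4, 2.0])))
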
model_compression_toolkit/core/pytorch/default_framework_info.py
CHANGED
@@ -12,101 +12,87 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from
-from
-
-from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid, GELU, SELU, SiLU
-from torch.nn.functional import hardsigmoid, relu, relu6, softmax, gelu, selu, silu
+from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid, GELU, SELU
+from torch.nn.functional import hardsigmoid, relu, relu6, softmax, gelu, selu
 from torch.nn import Conv2d, ConvTranspose2d, Linear
 from torch import sigmoid
 
-from model_compression_toolkit.
+from model_compression_toolkit.defaultdict import DefaultDict
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo, DEFAULT_KERNEL_ATTRIBUTES
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
 from model_compression_toolkit.core.pytorch.constants import KERNEL
-[old lines 26-99 not captured in the source diff]
-        Returns:
-            Node's output channel axis.
-
-        """
-        return cls.out_channel_axis_mapping.get(node_type)
-
-
-def set_pytorch_info(func):
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        set_fw_info(PyTorchInfo)
-        return func(*args, **kwargs)
-    return wrapper
+from model_compression_toolkit.core.pytorch.quantizer.fake_quant_builder import power_of_two_quantization, \
+    symmetric_quantization, uniform_quantization
+from model_compression_toolkit.core.pytorch.quantizer.lut_fake_quant import activation_lut_kmean_quantizer
+
+"""
+Map each layer to a list of its' weights attributes that should get quantized.
+If a layer that is not listed here is queried, [None] is returned.
+"""
+KERNEL_ATTRIBUTES = DefaultDict({Conv2d: [KERNEL],
+                                 ConvTranspose2d: [KERNEL],
+                                 Linear: [KERNEL]},
+                                DEFAULT_KERNEL_ATTRIBUTES)
+
+"""
+Map a layer to its kernel's output and input channels indices.
+Map's values are tuples of (output_channel_index, input_channel_index).
+Default value is returned for layers that are not included.
+"""
+DEFAULT_CHANNEL_AXIS_DICT = DefaultDict({Conv2d: (0, 1),
+                                         Linear: (0, 1),
+                                         ConvTranspose2d: (1, 0)},
+                                        (None, None))
+
+"""
+Map a layer to its output channel axis.
+Where axis=-1 is the last axis
+"""
+DEFAULT_OUT_CHANNEL_AXIS_DICT = DefaultDict({Conv2d: 1,
+                                             Linear: -1,
+                                             ConvTranspose2d: 1},
+                                            1)
+
+
+"""
+Map from an activation function to its min/max output values (if known).
+The values are used for tensor min/max values initialization.
+"""
+ACTIVATION2MINMAX = {}  # should be an empty dict in Pytorch
+
+"""
+Map from an Pytorch module to its min/max output values (if known).
+The values are used for tensor min/max values initialization.
+"""
+LAYER2MINMAX = {Softmax: (0, SOFTMAX_THRESHOLD),
+                softmax: (0, SOFTMAX_THRESHOLD),
+                Sigmoid: (0, 1),
+                sigmoid: (0, 1),
+                Hardsigmoid: (0, 1),
+                hardsigmoid: (0, 1),
+                ReLU: (0, None),
+                relu: (0, None),
+                ReLU6: (0, None),
+                relu6: (0, None),
+                GELU: (-0.17, None),
+                gelu: (-0.17, None),
+                SELU: (-1.76, None),
+                selu: (-1.76, None),
+                }
+
+"""
+Mapping from a QuantizationMethod to an activation quantizer function.
+"""
+ACTIVATION_QUANTIZER_MAPPING = {QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
+                                QuantizationMethod.SYMMETRIC: symmetric_quantization,
+                                QuantizationMethod.UNIFORM: uniform_quantization,
+                                QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer}
+
+
+DEFAULT_PYTORCH_INFO = FrameworkInfo(ACTIVATION_QUANTIZER_MAPPING,
+                                     DEFAULT_CHANNEL_AXIS_DICT,
+                                     ACTIVATION2MINMAX,
+                                     LAYER2MINMAX,
+                                     KERNEL_ATTRIBUTES,
+                                     DEFAULT_OUT_CHANNEL_AXIS_DICT)
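
The rewritten module is built around MCT's DefaultDict, which (unlike collections.defaultdict) returns a fixed default for unknown keys instead of inserting them. A self-contained sketch of that lookup behavior with a stand-in class; the names mirror the diff, but the stand-in is not the MCT implementation, and the KERNEL value is an assumption mirroring model_compression_toolkit.core.pytorch.constants.KERNEL:

from torch.nn import Conv2d, Linear, ReLU

class DefaultDictStandIn:
    """Returns known_dict[key] when present, otherwise a fixed default value."""
    def __init__(self, known_dict, default_value=None):
        self.known_dict = known_dict
        self.default_value = default_value

    def get(self, key):
        return self.known_dict.get(key, self.default_value)

KERNEL = 'weight'  # assumption: the PyTorch kernel attribute name used by MCT

KERNEL_ATTRIBUTES = DefaultDictStandIn({Conv2d: [KERNEL], Linear: [KERNEL]}, [None])
DEFAULT_CHANNEL_AXIS_DICT = DefaultDictStandIn({Conv2d: (0, 1), Linear: (0, 1)}, (None, None))

print(KERNEL_ATTRIBUTES.get(Conv2d))          # ['weight'] -> attribute to quantize
print(KERNEL_ATTRIBUTES.get(ReLU))            # [None]     -> no kernel to quantize
print(DEFAULT_CHANNEL_AXIS_DICT.get(Conv2d))  # (0, 1)     -> (output, input) channel axes
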
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py
CHANGED
@@ -21,18 +21,19 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.pytorch.constants import IN_CHANNELS, OUT_CHANNELS, KERNEL_SIZE, KERNEL, BIAS
-from model_compression_toolkit.core.common
+from model_compression_toolkit.core.common import FrameworkInfo
 
 
 class FunctionalConvSubstitution(common.BaseSubstitution):
     """
     Substitute functional convolutions with Layers
     """
-    def __init__(self):
+    def __init__(self, fw_info: FrameworkInfo):
         """
         Matches a functional conv node
         """
         func_node = NodeOperationMatcher(conv2d) | NodeOperationMatcher(conv_transpose2d)
+        self.fw_info = fw_info
         super().__init__(matcher_instance=func_node)
 
     def substitute(self,
@@ -55,7 +56,7 @@ class FunctionalConvSubstitution(common.BaseSubstitution):
         else:
             Logger.critical(f'Substitution filter mismatch. Layer {func_node.type}. Must be {type(Conv2d)} or {type(ConvTranspose2d)}.')  # pragma: no cover
 
-        out_channel_index, in_channel_index =
+        out_channel_index, in_channel_index = self.fw_info.kernel_channels_mapping.get(new_layer)
 
         # Create new node of layer convolution
         if 1 not in func_node.weights:
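
The restored lookup self.fw_info.kernel_channels_mapping.get(new_layer) resolves the (output, input) channel indices per layer class. A sketch of the values it returns, following the DEFAULT_CHANNEL_AXIS_DICT defined earlier in this diff and assuming the right-hand package version is installed:

from torch.nn import Conv2d, ConvTranspose2d
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO

# (output_channel_index, input_channel_index) per layer class:
print(DEFAULT_PYTORCH_INFO.kernel_channels_mapping.get(Conv2d))           # (0, 1)
print(DEFAULT_PYTORCH_INFO.kernel_channels_mapping.get(ConvTranspose2d))  # (1, 0)
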
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py
CHANGED
@@ -95,11 +95,11 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
             else:
                 return graph
         elif non_linear_node.is_match_type(hardtanh):
-
-
-
-            scale_factor =
-            non_linear_node.functional_op.__defaults__ = (0.0, self.threshold,
+            if (non_linear_node.framework_attr[HARDTANH_MIN_VAL] == 0.0) and not \
+                    (np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]).astype(int) -
+                     np.log2(non_linear_node.framework_attr[HARDTANH_MAX_VAL]) == 0):
+                scale_factor = non_linear_node.framework_attr[HARDTANH_MAX_VAL] / self.threshold
+                non_linear_node.functional_op.__defaults__ = (0.0, self.threshold, non_linear_node.framework_attr[INPLACE])
             else:
                 return graph
         else:
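
The restored condition only rescales a hardtanh with min_val == 0 when its max_val is not already a power of two (log2 of a power of two has no fractional part). A small numeric check of that test:

import numpy as np

for max_val in (4.0, 6.0):
    is_power_of_two = (np.log2(max_val).astype(int) - np.log2(max_val)) == 0
    print(max_val, 'power of two, left as-is' if is_power_of_two else 'rescaled to the threshold')
# 4.0 power of two, left as-is
# 6.0 rescaled to the threshold
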
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py
CHANGED
@@ -46,15 +46,17 @@ class ScaleEqualization(BaseScaleEqualization):
     """
 
     def __init__(self,
-                 quant_config: QuantizationConfig
+                 quant_config: QuantizationConfig,
+                 fw_info: FrameworkInfo):
         """
         Initialize a ScaleEqualization object.
         Args:
             quant_config: Quantization configuration.
+            fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
             groups of layers by how they should be quantized, etc.)
         """
 
-        super().__init__(quant_config=quant_config, matcher_instance=MATCHER,
+        super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER,
                          kernel_str=KERNEL, bias_str=BIAS)
 
 
@@ -64,13 +66,15 @@ class ScaleEqualizationWithPad(BaseScaleEqualization):
     """
 
     def __init__(self,
-                 quant_config: QuantizationConfig
+                 quant_config: QuantizationConfig,
+                 fw_info: FrameworkInfo):
         """
         Initialize a ScaleEqualization object.
         Args:
             quant_config: Quantization configuration.
+            fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
             groups of layers by how they should be quantized, etc.)
         """
 
-        super().__init__(quant_config=quant_config, matcher_instance=MATCHER_WITH_PAD,
+        super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER_WITH_PAD,
                          kernel_str=KERNEL, bias_str=BIAS)
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py
CHANGED
@@ -29,7 +29,6 @@ from model_compression_toolkit.core.common import BaseNode, Graph
 from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
 from model_compression_toolkit.core.common.substitutions.shift_negative_activation import apply_shift_negative_correction
-from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.core.pytorch.constants import PAD, VALUE, PADDING, BIAS, USE_BIAS
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 
@@ -215,13 +214,15 @@ def is_padding_node_and_node_has_padding(pad_node_to_consider: BaseNode,
 
 
 def pytorch_apply_shift_negative_correction(graph: Graph,
-                                            core_config: CoreConfig
+                                            core_config: CoreConfig,
+                                            fw_info: FrameworkInfo) -> Graph:
     """
     Apply shift negative correction (SNC) on a graph built from a Pytorch model.
 
     Args:
         graph: Graph to apply SNC on.
        core_config: Quantization configuration.
+       fw_info: FrameworkInfo object with information about the specific framework's module.
 
     Returns:
         Graph after SNC.
@@ -229,6 +230,7 @@ def pytorch_apply_shift_negative_correction(graph: Graph,
     snc_node, linear_node, bypass_node, pad_node = shift_negative_activation_node_matchers()
     return apply_shift_negative_correction(graph,
                                            core_config,
+                                           fw_info,
                                            snc_node,
                                            linear_node,
                                            bypass_node,
@@ -240,5 +242,4 @@ def pytorch_apply_shift_negative_correction(graph: Graph,
                                            PADDING,
                                            BIAS,
                                            USE_BIAS,
-                                           get_activation_quantization_fn_factory,
                                            params_search_quantization_fn=params_search_quantization_fn)
model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py
CHANGED
@@ -23,6 +23,7 @@ from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESS
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
 from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
     HessianScoresCalculatorPytorch
 from model_compression_toolkit.logger import Logger
@@ -91,14 +92,22 @@ class WeightsHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
         for i, ipt_node in enumerate(self.hessian_request.target_nodes):  # Per Interest point weights tensor
 
             # Check if the target node's layer type is supported.
-            if not ipt_node.
+            if not DEFAULT_PYTORCH_INFO.is_kernel_op(ipt_node.type):
                 Logger.critical(f"Hessian information with respect to weights is not supported for "
                                 f"{ipt_node.type} layers.")  # pragma: no cover
 
-
+            # Get the weight attributes for the target node type
+            weights_attributes = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(ipt_node.type)
+
+            # Get the weight tensor for the target node
+            if len(weights_attributes) != 1:  # pragma: no cover
+                Logger.critical(f"Currently, Hessian scores with respect to weights are supported only for nodes with a "
+                                f"single weight attribute. {len(weights_attributes)} attributes found.")
+
+            weights_tensor = getattr(getattr(model, ipt_node.name), weights_attributes[0])
 
             # Get the output channel index
-            output_channel_axis = ipt_node.
+            output_channel_axis, _ = DEFAULT_PYTORCH_INFO.kernel_channels_mapping.get(ipt_node.type)
             shape_channel_axis = [i for i in range(len(weights_tensor.shape))]
             if self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
                 shape_channel_axis.remove(output_channel_axis)
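
A numeric sketch of the channel-axis handling in the hunk above: for a Conv2d kernel of shape (out_ch, in_ch, kH, kW) and output_channel_axis = 0 (from kernel_channels_mapping), per-output-channel scores reduce over all remaining axes. The mean-of-squares reduction below is a stand-in; the real calculator computes Hessian-based scores:

import torch

weights_tensor = torch.randn(8, 3, 3, 3)       # Conv2d weight: (out_ch, in_ch, kH, kW)
output_channel_axis = 0                        # DEFAULT_PYTORCH_INFO maps Conv2d -> (0, 1)
shape_channel_axis = [i for i in range(len(weights_tensor.shape))]
shape_channel_axis.remove(output_channel_axis) # reduce over axes (1, 2, 3)
per_channel_scores = weights_tensor.pow(2).mean(dim=tuple(shape_channel_axis))
print(per_channel_scores.shape)                # torch.Size([8]) - one score per output channel
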
model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py
CHANGED
@@ -20,7 +20,6 @@ from model_compression_toolkit.core.common.mixed_precision.configurable_quantize
     verify_candidates_descending_order, init_activation_quantizers
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
-from model_compression_toolkit.core.pytorch.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.logger import Logger
 from mct_quantizers import QuantizationMethod
 from mct_quantizers import QuantizationTarget
@@ -68,7 +67,7 @@ class ConfigurableActivationQuantizer(BasePyTorchInferableQuantizer):
             Logger.critical("Unsupported configuration: Mixing candidates with differing activation quantization states (enabled/disabled).")  # pragma: no cover
 
         # Setting layer's activation
-        self.activation_quantizers = init_activation_quantizers(self.node_q_cfg
+        self.activation_quantizers = init_activation_quantizers(self.node_q_cfg)
         self.active_quantization_config_index = max_candidate_idx  # initialize with first config as default
 
     def set_active_activation_quantizer(self, index: Optional[int]):
|