mct-nightly 2.4.0.20250616.616__py3-none-any.whl → 2.4.0.20250618.606__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/RECORD +120 -120
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +2 -5
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
- model_compression_toolkit/core/common/framework_implementation.py +10 -22
- model_compression_toolkit/core/common/framework_info.py +105 -68
- model_compression_toolkit/core/common/graph/base_graph.py +15 -42
- model_compression_toolkit/core/common/graph/base_node.py +103 -42
- model_compression_toolkit/core/common/graph/functional_node.py +18 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
- model_compression_toolkit/core/common/model_collector.py +10 -20
- model_compression_toolkit/core/common/model_validation.py +1 -4
- model_compression_toolkit/core/common/network_editors/actions.py +14 -38
- model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
- model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
- model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
- model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
- model_compression_toolkit/core/common/pruning/pruner.py +1 -6
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
- model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
- model_compression_toolkit/core/graph_prep_runner.py +2 -16
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/keras/default_framework_info.py +138 -87
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
- model_compression_toolkit/core/keras/keras_implementation.py +15 -35
- model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
- model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +0 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/quantization_prep_runner.py +4 -9
- model_compression_toolkit/core/runner.py +5 -15
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
- model_compression_toolkit/gptq/common/gptq_training.py +1 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
- model_compression_toolkit/gptq/keras/graph_info.py +4 -6
- model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/runner.py +1 -7
- model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
- model_compression_toolkit/ptq/runner.py +1 -4
- model_compression_toolkit/qat/common/qat_config.py +2 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
- model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
- model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
- model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
- model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
- model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
- model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,6 @@ from model_compression_toolkit.core import FrameworkInfo, common
|
|
23
23
|
from model_compression_toolkit.core.common import BaseNode
|
24
24
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
25
25
|
from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
|
26
|
-
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
27
26
|
from model_compression_toolkit.core.pytorch.mixed_precision.configurable_activation_quantizer import \
|
28
27
|
ConfigurableActivationQuantizer
|
29
28
|
from model_compression_toolkit.core.pytorch.mixed_precision.configurable_weights_quantizer import \
|
@@ -38,14 +37,12 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
|
|
38
37
|
def __init__(self,
|
39
38
|
graph: common.Graph,
|
40
39
|
append2output=None,
|
41
|
-
fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
|
42
40
|
return_float_outputs: bool = False):
|
43
41
|
"""
|
44
42
|
|
45
43
|
Args:
|
46
44
|
graph: Graph to build the model from.
|
47
45
|
append2output: Nodes to append to model's output.
|
48
|
-
fw_info: Information about the specific framework of the model that is built.
|
49
46
|
return_float_outputs: Whether the model returns float tensors or not.
|
50
47
|
"""
|
51
48
|
|
@@ -53,7 +50,6 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
|
|
53
50
|
|
54
51
|
super().__init__(graph,
|
55
52
|
append2output,
|
56
|
-
fw_info,
|
57
53
|
return_float_outputs,
|
58
54
|
wrapper=self.mixed_precision_wrapper,
|
59
55
|
get_activation_quantizer_holder_fn=self.mixed_precision_activation_holder)
|
@@ -77,17 +73,16 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
|
|
77
73
|
ValueError: if kernel attribute is quantized but not configurable.
|
78
74
|
"""
|
79
75
|
|
80
|
-
kernel_attr
|
81
|
-
if kernel_attr is None or not n.is_weights_quantization_enabled(kernel_attr):
|
76
|
+
if n.kernel_attr is None or not n.is_weights_quantization_enabled(n.kernel_attr):
|
82
77
|
return layer
|
83
|
-
if not n.is_configurable_weight(kernel_attr): # pragma: no cover
|
78
|
+
if not n.is_configurable_weight(n.kernel_attr): # pragma: no cover
|
84
79
|
raise ValueError(f'Weight wrapper is not expected to be created for non-configurable weight of node {n}.')
|
85
80
|
return PytorchQuantizationWrapper(layer,
|
86
81
|
weights_quantizers={
|
87
|
-
kernel_attr: ConfigurableWeightsQuantizer(
|
82
|
+
n.kernel_attr: ConfigurableWeightsQuantizer(
|
88
83
|
**self._get_weights_configurable_quantizer_kwargs(n,
|
89
|
-
kernel_attr),
|
90
|
-
kernel_attr=kernel_attr)})
|
84
|
+
n.kernel_attr),
|
85
|
+
kernel_attr=n.kernel_attr)})
|
91
86
|
|
92
87
|
def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) -> Dict[str, Any]:
|
93
88
|
"""
|
@@ -147,14 +142,13 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
|
|
147
142
|
# activation number of bits (in reversed order).
|
148
143
|
# since only kernel attribute is quantized in weights mixed precision,
|
149
144
|
# if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
|
150
|
-
n.sort_node_candidates(
|
145
|
+
n.sort_node_candidates()
|
151
146
|
|
152
147
|
max_candidate_idx = n.find_max_candidate_index()
|
153
148
|
|
154
|
-
kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
|
155
149
|
activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
|
156
150
|
'max_candidate_idx': max_candidate_idx,
|
157
|
-
'kernel_attr': kernel_attr})] \
|
151
|
+
'kernel_attr': n.kernel_attr})] \
|
158
152
|
* num_of_outputs
|
159
153
|
|
160
154
|
# Holder by definition uses a single quantizer for the activation quantization
|
@@ -177,7 +171,7 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
|
|
177
171
|
# creating a mapping between graph nodes and model's layers for mixed precision configurability
|
178
172
|
model_layers = dict(model.named_children())
|
179
173
|
conf_node2layers = {n.name: self._find_layers_in_model_by_node(n, model_layers)
|
180
|
-
for n in self.graph.get_configurable_sorted_nodes(
|
174
|
+
for n in self.graph.get_configurable_sorted_nodes()}
|
181
175
|
|
182
176
|
return model, user_info, conf_node2layers
|
183
177
|
|
@@ -230,8 +224,7 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
|
|
230
224
|
|
231
225
|
"""
|
232
226
|
# Only layers with kernel op are considered weights configurable
|
233
|
-
|
234
|
-
weights_quant = False if kernel_attr is None else n.is_weights_quantization_enabled(kernel_attr)
|
227
|
+
weights_quant = False if n.kernel_attr is None else n.is_weights_quantization_enabled(n.kernel_attr)
|
235
228
|
act_quant = n.is_activation_quantization_enabled()
|
236
229
|
|
237
230
|
if weights_quant and not act_quant:
|
@@ -30,7 +30,6 @@ from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
|
|
30
30
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
31
31
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
32
32
|
from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
|
33
|
-
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
34
33
|
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
|
35
34
|
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
|
36
35
|
from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
|
@@ -364,7 +363,7 @@ class PytorchModel(torch.nn.Module):
|
|
364
363
|
"""
|
365
364
|
node_to_output_tensors_dict = dict()
|
366
365
|
node_to_output_tensors_dict_float = dict()
|
367
|
-
configurable_nodes = self.graph.get_configurable_sorted_nodes_names(
|
366
|
+
configurable_nodes = self.graph.get_configurable_sorted_nodes_names()
|
368
367
|
for node in self.node_sort:
|
369
368
|
op_func = self._get_op_func(node, configurable_nodes)
|
370
369
|
input_tensors = _build_input_tensors_list(node,
|
@@ -440,7 +439,6 @@ class PyTorchModelBuilder(BaseModelBuilder):
|
|
440
439
|
def __init__(self,
|
441
440
|
graph: common.Graph,
|
442
441
|
append2output=None,
|
443
|
-
fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
|
444
442
|
return_float_outputs: bool = False,
|
445
443
|
wrapper: Callable = None,
|
446
444
|
get_activation_quantizer_holder_fn: Callable = None):
|
@@ -449,7 +447,6 @@ class PyTorchModelBuilder(BaseModelBuilder):
|
|
449
447
|
Args:
|
450
448
|
graph: Graph to build the model from.
|
451
449
|
append2output: Nodes to append to model's output.
|
452
|
-
fw_info: Information about the specific framework of the model that is built.
|
453
450
|
return_float_outputs: Whether the model returns float tensors or not.
|
454
451
|
wrapper: A function wrapper Pytorch Layers.
|
455
452
|
get_activation_quantizer_holder_fn: Function to retrieve a quantization holder for a node.
|
@@ -457,7 +454,6 @@ class PyTorchModelBuilder(BaseModelBuilder):
|
|
457
454
|
|
458
455
|
super().__init__(graph,
|
459
456
|
append2output,
|
460
|
-
fw_info,
|
461
457
|
return_float_outputs)
|
462
458
|
|
463
459
|
self.wrapper = wrapper
|
@@ -21,7 +21,6 @@ from model_compression_toolkit.core.common import BaseNode
|
|
21
21
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
22
22
|
from model_compression_toolkit.core.pytorch.back2framework.quantization_wrapper.wrapper_quantize_config import \
|
23
23
|
WrapperQuantizeConfig
|
24
|
-
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
25
24
|
from model_compression_toolkit.core.pytorch.utils import set_model, to_torch_tensor
|
26
25
|
|
27
26
|
|
@@ -93,7 +92,7 @@ class QuantizedLayerWrapper(torch.nn.Module):
|
|
93
92
|
self.layer = n.type(**framework_attr)
|
94
93
|
self.layer.load_state_dict({k: torch.Tensor(v) for k, v in n.weights.items()}, strict=False)
|
95
94
|
|
96
|
-
def _quantize_weights(self, n:BaseNode):
|
95
|
+
def _quantize_weights(self, n: BaseNode):
|
97
96
|
"""
|
98
97
|
Quantize node's weights and load them as the layer's weights.
|
99
98
|
|
@@ -104,7 +103,7 @@ class QuantizedLayerWrapper(torch.nn.Module):
|
|
104
103
|
None.
|
105
104
|
"""
|
106
105
|
|
107
|
-
self.weight_attrs =
|
106
|
+
self.weight_attrs = [n.kernel_attr]
|
108
107
|
|
109
108
|
# float_weights is a list of weights for each attribute that we want to quantize.
|
110
109
|
float_weights = [n.get_weights_by_keys(attr) for attr in self.weight_attrs]
|
@@ -23,7 +23,6 @@ from model_compression_toolkit.core.common import BaseNode
|
|
23
23
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
24
24
|
from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
|
25
25
|
PytorchModel
|
26
|
-
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
27
26
|
|
28
27
|
|
29
28
|
class QuantizedPyTorchModel(PytorchModel):
|
@@ -70,20 +69,17 @@ class QuantizedPyTorchModelBuilder(PyTorchModelBuilder):
|
|
70
69
|
def __init__(self,
|
71
70
|
graph: common.Graph,
|
72
71
|
append2output=None,
|
73
|
-
fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
|
74
72
|
return_float_outputs: bool = False):
|
75
73
|
"""
|
76
74
|
|
77
75
|
Args:
|
78
76
|
graph: Graph to build the model from.
|
79
77
|
append2output: Nodes to append to model's output.
|
80
|
-
fw_info: Information about the specific framework of the model that is built.
|
81
78
|
return_float_outputs: Whether the model returns float tensors or not.
|
82
79
|
"""
|
83
80
|
|
84
81
|
super().__init__(graph,
|
85
82
|
append2output,
|
86
|
-
fw_info,
|
87
83
|
return_float_outputs)
|
88
84
|
|
89
85
|
def build_model(self) -> Tuple[PytorchModel, UserInformation]:
|
@@ -12,13 +12,15 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from
|
16
|
-
from
|
15
|
+
from typing import Any
|
16
|
+
from functools import wraps
|
17
|
+
|
18
|
+
from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid, GELU, SELU, SiLU
|
19
|
+
from torch.nn.functional import hardsigmoid, relu, relu6, softmax, gelu, selu, silu
|
17
20
|
from torch.nn import Conv2d, ConvTranspose2d, Linear
|
18
21
|
from torch import sigmoid
|
19
22
|
|
20
|
-
from model_compression_toolkit.
|
21
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo, DEFAULT_KERNEL_ATTRIBUTES
|
23
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo, set_fw_info, ChannelAxisMapping
|
22
24
|
from mct_quantizers import QuantizationMethod
|
23
25
|
from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
|
24
26
|
from model_compression_toolkit.core.pytorch.constants import KERNEL
|
@@ -26,73 +28,97 @@ from model_compression_toolkit.core.pytorch.quantizer.fake_quant_builder import
|
|
26
28
|
symmetric_quantization, uniform_quantization
|
27
29
|
from model_compression_toolkit.core.pytorch.quantizer.lut_fake_quant import activation_lut_kmean_quantizer
|
28
30
|
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
"""
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
"""
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
31
|
+
|
32
|
+
class PyTorchInfo(FrameworkInfo):
|
33
|
+
"""
|
34
|
+
Extra field defined to handle Activation layer functions:
|
35
|
+
"""
|
36
|
+
|
37
|
+
"""
|
38
|
+
Map each layer to it's weight attribute that should get quantized.
|
39
|
+
If a layer that is not listed here is queried, None is returned.
|
40
|
+
"""
|
41
|
+
kernel_ops_attribute_mapping = {Conv2d: KERNEL,
|
42
|
+
ConvTranspose2d: KERNEL,
|
43
|
+
Linear: KERNEL}
|
44
|
+
|
45
|
+
"""
|
46
|
+
Map a layer to its kernel's output and input channels indices.
|
47
|
+
Map's values are tuples of (output_channel_index, input_channel_index).
|
48
|
+
Default value is returned for layers that are not included.
|
49
|
+
"""
|
50
|
+
kernel_channels_mapping = {Conv2d: ChannelAxisMapping(0, 1),
|
51
|
+
Linear: ChannelAxisMapping(0, 1),
|
52
|
+
ConvTranspose2d: ChannelAxisMapping(1, 0)}
|
53
|
+
|
54
|
+
"""
|
55
|
+
Map a layer to its output channel axis.
|
56
|
+
Where axis=-1 is the last axis
|
57
|
+
"""
|
58
|
+
out_channel_axis_mapping = {Conv2d: 1,
|
59
|
+
Linear: -1,
|
60
|
+
ConvTranspose2d: 1}
|
61
|
+
|
62
|
+
"""
|
63
|
+
Map from an Pytorch module to its min/max output values (if known).
|
64
|
+
The values are used for tensor min/max values initialization.
|
65
|
+
"""
|
66
|
+
_layer_min_max_mapping = {Softmax: (0, SOFTMAX_THRESHOLD),
|
67
|
+
softmax: (0, SOFTMAX_THRESHOLD),
|
68
|
+
Sigmoid: (0, 1),
|
69
|
+
sigmoid: (0, 1),
|
70
|
+
Hardsigmoid: (0, 1),
|
71
|
+
hardsigmoid: (0, 1),
|
72
|
+
ReLU: (0, None),
|
73
|
+
relu: (0, None),
|
74
|
+
ReLU6: (0, None),
|
75
|
+
relu6: (0, None),
|
76
|
+
GELU: (-0.17, None),
|
77
|
+
gelu: (-0.17, None),
|
78
|
+
SELU: (-1.76, None),
|
79
|
+
selu: (-1.76, None),
|
80
|
+
silu: (-0.279, None),
|
81
|
+
SiLU: (-0.279, None),
|
82
|
+
}
|
83
|
+
|
84
|
+
"""
|
85
|
+
Mapping from a QuantizationMethod to an activation quantizer function.
|
86
|
+
"""
|
87
|
+
activation_quantizer_mapping = {QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
|
88
|
+
QuantizationMethod.SYMMETRIC: symmetric_quantization,
|
89
|
+
QuantizationMethod.UNIFORM: uniform_quantization,
|
90
|
+
QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer}
|
91
|
+
|
92
|
+
@classmethod
|
93
|
+
def get_kernel_channels(cls, node_type: Any) -> ChannelAxisMapping:
|
94
|
+
"""
|
95
|
+
Returns node's channels mapping from kernel_channels_mapping or framework specific default value.
|
96
|
+
Args:
|
97
|
+
node_type: A node type.
|
98
|
+
|
99
|
+
Returns:
|
100
|
+
Node's channels mapping.
|
101
|
+
|
102
|
+
"""
|
103
|
+
return cls.kernel_channels_mapping.get(node_type, cls._default_channel_mapping)
|
104
|
+
|
105
|
+
@classmethod
|
106
|
+
def get_out_channel_axis(cls, node_type: Any):
|
107
|
+
"""
|
108
|
+
Returns node's output channel mapping from out_channel_axis_mapping or framework specific default value.
|
109
|
+
Args:
|
110
|
+
node_type: A node type.
|
111
|
+
|
112
|
+
Returns:
|
113
|
+
Node's output channel axis.
|
114
|
+
|
115
|
+
"""
|
116
|
+
return cls.out_channel_axis_mapping.get(node_type, 1)
|
117
|
+
|
118
|
+
|
119
|
+
def set_pytorch_info(func):
|
120
|
+
@wraps(func)
|
121
|
+
def wrapper(*args, **kwargs):
|
122
|
+
set_fw_info(PyTorchInfo)
|
123
|
+
return func(*args, **kwargs)
|
124
|
+
return wrapper
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py
CHANGED
@@ -21,19 +21,18 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
21
21
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
22
22
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
23
23
|
from model_compression_toolkit.core.pytorch.constants import IN_CHANNELS, OUT_CHANNELS, KERNEL_SIZE, KERNEL, BIAS
|
24
|
-
from model_compression_toolkit.core.common import
|
24
|
+
from model_compression_toolkit.core.common.framework_info import get_fw_info
|
25
25
|
|
26
26
|
|
27
27
|
class FunctionalConvSubstitution(common.BaseSubstitution):
|
28
28
|
"""
|
29
29
|
Substitute functional convolutions with Layers
|
30
30
|
"""
|
31
|
-
def __init__(self
|
31
|
+
def __init__(self):
|
32
32
|
"""
|
33
33
|
Matches a functional conv node
|
34
34
|
"""
|
35
35
|
func_node = NodeOperationMatcher(conv2d) | NodeOperationMatcher(conv_transpose2d)
|
36
|
-
self.fw_info = fw_info
|
37
36
|
super().__init__(matcher_instance=func_node)
|
38
37
|
|
39
38
|
def substitute(self,
|
@@ -56,7 +55,7 @@ class FunctionalConvSubstitution(common.BaseSubstitution):
|
|
56
55
|
else:
|
57
56
|
Logger.critical(f'Substitution filter mismatch. Layer {func_node.type}. Must be {type(Conv2d)} or {type(ConvTranspose2d)}.') # pragma: no cover
|
58
57
|
|
59
|
-
out_channel_index, in_channel_index =
|
58
|
+
out_channel_index, in_channel_index = get_fw_info().get_kernel_channels(new_layer)
|
60
59
|
|
61
60
|
# Create new node of layer convolution
|
62
61
|
if 1 not in func_node.weights:
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py
CHANGED
@@ -46,17 +46,15 @@ class ScaleEqualization(BaseScaleEqualization):
|
|
46
46
|
"""
|
47
47
|
|
48
48
|
def __init__(self,
|
49
|
-
quant_config: QuantizationConfig
|
50
|
-
fw_info: FrameworkInfo):
|
49
|
+
quant_config: QuantizationConfig):
|
51
50
|
"""
|
52
51
|
Initialize a ScaleEqualization object.
|
53
52
|
Args:
|
54
53
|
quant_config: Quantization configuration.
|
55
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
56
54
|
groups of layers by how they should be quantized, etc.)
|
57
55
|
"""
|
58
56
|
|
59
|
-
super().__init__(quant_config=quant_config,
|
57
|
+
super().__init__(quant_config=quant_config, matcher_instance=MATCHER,
|
60
58
|
kernel_str=KERNEL, bias_str=BIAS)
|
61
59
|
|
62
60
|
|
@@ -66,15 +64,13 @@ class ScaleEqualizationWithPad(BaseScaleEqualization):
|
|
66
64
|
"""
|
67
65
|
|
68
66
|
def __init__(self,
|
69
|
-
quant_config: QuantizationConfig
|
70
|
-
fw_info: FrameworkInfo):
|
67
|
+
quant_config: QuantizationConfig):
|
71
68
|
"""
|
72
69
|
Initialize a ScaleEqualization object.
|
73
70
|
Args:
|
74
71
|
quant_config: Quantization configuration.
|
75
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
76
72
|
groups of layers by how they should be quantized, etc.)
|
77
73
|
"""
|
78
74
|
|
79
|
-
super().__init__(quant_config=quant_config,
|
75
|
+
super().__init__(quant_config=quant_config, matcher_instance=MATCHER_WITH_PAD,
|
80
76
|
kernel_str=KERNEL, bias_str=BIAS)
|
@@ -214,15 +214,13 @@ def is_padding_node_and_node_has_padding(pad_node_to_consider: BaseNode,
|
|
214
214
|
|
215
215
|
|
216
216
|
def pytorch_apply_shift_negative_correction(graph: Graph,
|
217
|
-
core_config: CoreConfig
|
218
|
-
fw_info: FrameworkInfo) -> Graph:
|
217
|
+
core_config: CoreConfig) -> Graph:
|
219
218
|
"""
|
220
219
|
Apply shift negative correction (SNC) on a graph built from a Pytorch model.
|
221
220
|
|
222
221
|
Args:
|
223
222
|
graph: Graph to apply SNC on.
|
224
223
|
core_config: Quantization configuration.
|
225
|
-
fw_info: FrameworkInfo object with information about the specific framework's module.
|
226
224
|
|
227
225
|
Returns:
|
228
226
|
Graph after SNC.
|
@@ -230,7 +228,6 @@ def pytorch_apply_shift_negative_correction(graph: Graph,
|
|
230
228
|
snc_node, linear_node, bypass_node, pad_node = shift_negative_activation_node_matchers()
|
231
229
|
return apply_shift_negative_correction(graph,
|
232
230
|
core_config,
|
233
|
-
fw_info,
|
234
231
|
snc_node,
|
235
232
|
linear_node,
|
236
233
|
bypass_node,
|
@@ -23,7 +23,6 @@ from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESS
|
|
23
23
|
from model_compression_toolkit.core.common import Graph
|
24
24
|
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
|
25
25
|
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
|
26
|
-
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
27
26
|
from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
|
28
27
|
HessianScoresCalculatorPytorch
|
29
28
|
from model_compression_toolkit.logger import Logger
|
@@ -92,22 +91,14 @@ class WeightsHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
|
|
92
91
|
for i, ipt_node in enumerate(self.hessian_request.target_nodes): # Per Interest point weights tensor
|
93
92
|
|
94
93
|
# Check if the target node's layer type is supported.
|
95
|
-
if not
|
94
|
+
if not ipt_node.is_kernel_op:
|
96
95
|
Logger.critical(f"Hessian information with respect to weights is not supported for "
|
97
96
|
f"{ipt_node.type} layers.") # pragma: no cover
|
98
97
|
|
99
|
-
|
100
|
-
weights_attributes = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(ipt_node.type)
|
101
|
-
|
102
|
-
# Get the weight tensor for the target node
|
103
|
-
if len(weights_attributes) != 1: # pragma: no cover
|
104
|
-
Logger.critical(f"Currently, Hessian scores with respect to weights are supported only for nodes with a "
|
105
|
-
f"single weight attribute. {len(weights_attributes)} attributes found.")
|
106
|
-
|
107
|
-
weights_tensor = getattr(getattr(model, ipt_node.name), weights_attributes[0])
|
98
|
+
weights_tensor = getattr(getattr(model, ipt_node.name), ipt_node.kernel_attr)
|
108
99
|
|
109
100
|
# Get the output channel index
|
110
|
-
output_channel_axis
|
101
|
+
output_channel_axis = ipt_node.channel_axis.output
|
111
102
|
shape_channel_axis = [i for i in range(len(weights_tensor.shape))]
|
112
103
|
if self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
|
113
104
|
shape_channel_axis.remove(output_channel_axis)
|