mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
|
@@ -12,32 +12,206 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
|
|
15
|
+
import copy
|
|
16
16
|
from typing import Callable
|
|
17
|
+
from functools import partial
|
|
18
|
+
|
|
19
|
+
from model_compression_toolkit.core.common.constants import FOUND_TORCH, PYTORCH
|
|
17
20
|
|
|
18
|
-
from model_compression_toolkit
|
|
21
|
+
from model_compression_toolkit import CoreConfig
|
|
22
|
+
from model_compression_toolkit.core import common
|
|
19
23
|
from model_compression_toolkit.core.common import Logger
|
|
20
|
-
from model_compression_toolkit.core.common.constants import PYTORCH
|
|
21
|
-
from model_compression_toolkit.core.common.target_platform import TargetPlatformCapabilities
|
|
22
|
-
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
23
24
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
24
|
-
from model_compression_toolkit import
|
|
25
|
+
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
26
|
+
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
27
|
+
MixedPrecisionQuantizationConfigV2
|
|
28
|
+
from model_compression_toolkit.core.common.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
|
29
|
+
from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
|
|
30
|
+
from model_compression_toolkit.ptq.runner import ptq_runner
|
|
31
|
+
|
|
25
32
|
|
|
26
33
|
if FOUND_TORCH:
|
|
34
|
+
import torch.nn as nn
|
|
35
|
+
from torch.nn import Module
|
|
27
36
|
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
|
28
37
|
from model_compression_toolkit.core.pytorch.constants import DEFAULT_TP_MODEL
|
|
29
|
-
from
|
|
30
|
-
|
|
38
|
+
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
|
|
39
|
+
from model_compression_toolkit.qat.common.qat_config import _is_qat_applicable
|
|
40
|
+
from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
|
|
41
|
+
from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
|
|
42
|
+
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
31
43
|
from model_compression_toolkit import get_target_platform_capabilities
|
|
44
|
+
from model_compression_toolkit.qat.common.qat_config import QATConfig
|
|
45
|
+
from model_compression_toolkit.qat.pytorch.quantizer.quantization_builder import quantization_builder
|
|
32
46
|
# Default TargetPlatformCapabilities for PyTorch, built from the framework's DEFAULT_TP_MODEL.
# Used as the default `target_platform_capabilities` argument of the QAT init facade.
DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
|
|
33
47
|
|
|
34
|
-
|
|
48
|
+
|
|
49
|
+
def qat_wrapper(n: common.BaseNode, module: nn.Module, qat_config: QATConfig):
    """
    Wrap a PyTorch module with QAT quantizers when its graph node supports quantization aware training.

    Args:
        n: A node of the MCT graph that the module was built from.
        module: A Pytorch module.
        qat_config (QATConfig): QAT configuration used to build the node's quantizers.

    Returns:
        A PytorchQuantizationWrapper around ``module`` holding the node's weights and activation
        quantizers, or ``module`` itself when QAT is not applicable to the node.
    """
    if not _is_qat_applicable(n, DEFAULT_PYTORCH_INFO):
        # Nothing to train for this node: keep the module untouched.
        return module

    w_quantizers, act_quantizers = quantization_builder(n, qat_config, DEFAULT_PYTORCH_INFO)
    return qi.PytorchQuantizationWrapper(module, w_quantizers, act_quantizers)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def pytorch_quantization_aware_training_init(in_model: Module,
                                             representative_data_gen: Callable,
                                             target_kpi: KPI = None,
                                             core_config: CoreConfig = CoreConfig(),
                                             qat_config: QATConfig = QATConfig(),
                                             fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                                             target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
    """
    Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized
    with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
    quantized using symmetric quantization thresholds (power of two).
    The model is first optimized using several transformations (e.g. BatchNormalization folding to
    preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
    being collected for each layer's output (and input, depends on the quantization configuration).
    For each possible bit width (per layer) a threshold is then being calculated using the collected
    statistics. Then, if given a mixed precision config in the core_config, using an ILP solver we find
    a mixed-precision configuration, and set a bit-width for each layer. The model is built with fake_quant
    nodes for quantizing activation. Weights are kept as float and are quantized online while training by the
    quantization wrapper's weight quantizer.
    In order to limit the maximal model's size, a target KPI with weights_memory set (in bytes) needs
    to be passed.

    Args:
        in_model (Module): Pytorch model to quantize.
        representative_data_gen (Callable): Dataset used for initial calibration.
        target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
        core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
        qat_config (QATConfig): QAT configuration
        fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Pytorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
        target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Pytorch model according to.

    Returns:

        A quantized model.
        User information that may be needed to handle the quantized model (its
        ``mixed_precision_cfg`` holds the bit-width configuration chosen by the core run).

    Raises:
        Whatever ``common.Logger.error`` does when mixed precision is enabled but
        ``core_config.mixed_precision_config`` is not a MixedPrecisionQuantizationConfigV2.

    Examples:

        Import MCT:

        >>> import model_compression_toolkit as mct

        Import a Pytorch model:

        >>> from torchvision.models import mobilenet_v2
        >>> model = mobilenet_v2(pretrained=True)

        Create a random dataset generator, for required number of calibration iterations (num_calibration_batches):
        In this example a random dataset of 10 batches each containing 4 images is used.

        >>> import numpy as np
        >>> num_calibration_batches = 10
        >>> def repr_datagen():
        >>>     for _ in range(num_calibration_batches):
        >>>         yield [np.random.random((4, 3, 224, 224))]

        Create a MCT core config, containing the quantization configuration:

        >>> config = mct.CoreConfig()

        Pass the model, the representative dataset generator and the configuration (and optionally a
        target KPI) to get a quantized model. Now the model contains quantizer wrappers for fine-tuning the weights:

        >>> quantized_model, quantization_info = mct.pytorch_quantization_aware_training_init(model, repr_datagen, core_config=config)

    For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.

    """

    if core_config.mixed_precision_enable:
        if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
            common.Logger.error("Given quantization config to mixed-precision facade is not of type "
                                "MixedPrecisionQuantizationConfigV2. Please use pytorch_post_training_quantization API, "
                                "or pass a valid mixed precision configuration.")

        common.Logger.info("Using experimental mixed-precision quantization. "
                           "If you encounter an issue please file a bug.")

    tb_w = _init_tensorboard_writer(fw_info)

    fw_impl = PytorchImplementation()

    # Use the caller-supplied fw_info so the core run agrees with ptq_runner and the model
    # builder below (all of which already receive fw_info).
    tg, bit_widths_config = core_runner(in_model=in_model,
                                        representative_data_gen=representative_data_gen,
                                        core_config=core_config,
                                        fw_info=fw_info,
                                        fw_impl=fw_impl,
                                        tpc=target_platform_capabilities,
                                        target_kpi=target_kpi,
                                        tb_w=tb_w)

    tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)

    # Bind the QAT configuration so the builder can call the wrapper with (node, module) only.
    _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)

    qat_model, user_info = PyTorchModelBuilder(graph=tg, fw_info=fw_info, wrapper=_qat_wrapper).build_model()

    user_info.mixed_precision_cfg = bit_widths_config

    return qat_model, user_info
|
|
166
|
+
|
|
167
|
+
def pytorch_quantization_aware_training_finalize(in_model: Module):
    """
    Convert a model fine-tuned by the user into a network whose QuantizeWrappers hold
    InferableQuantizers, which quantize both the layers' weights and outputs.

    The input model is not modified: a deep copy of it is converted and returned.

    Args:
        in_model (Module): Pytorch model with QuantizeWrappers (e.g. the model returned by
            pytorch_quantization_aware_training_init after fine-tuning) to convert.

    Returns:
        A quantized model with QuantizeWrappers and InferableQuantizers.

    Examples:

        Import MCT:

        >>> import model_compression_toolkit as mct

        Import a Pytorch model:

        >>> from torchvision.models import mobilenet_v2
        >>> model = mobilenet_v2(pretrained=True)

        Create a random dataset generator (Pytorch image models expect NCHW inputs):

        >>> import numpy as np
        >>> def repr_datagen(): yield [np.random.random((1, 3, 224, 224))]

        Create a MCT core config, containing the quantization configuration:

        >>> config = mct.CoreConfig()

        Pass the model, the representative dataset generator and the configuration to get a
        quantized model:

        >>> quantized_model, quantization_info = mct.pytorch_quantization_aware_training_init(model, repr_datagen, core_config=config)

        Use the quantized model for fine-tuning. Finally, convert its trainable quantizers into
        inferable ones to get a quantized model ready for inference.

        >>> quantized_model = mct.pytorch_quantization_aware_training_finalize(quantized_model)

    """
    # Work on a copy so the caller's fine-tuned model keeps its trainable quantizers.
    exported_model = copy.deepcopy(in_model)

    # Only the model's direct children are inspected.
    # NOTE(review): this assumes the QAT model keeps its wrappers as top-level modules;
    # wrappers nested deeper would be left unconverted - confirm against the model builder.
    for layer in exported_model.children():
        if isinstance(layer, PytorchQuantizationWrapper):
            layer.convert_to_inferable_quantizers()

    return exported_model
|
|
214
|
+
|
|
41
215
|
|
|
42
216
|
else:
|
|
43
217
|
# If torch is not installed,
|
|
@@ -45,4 +219,9 @@ else:
|
|
|
45
219
|
def pytorch_quantization_aware_training_init(*args, **kwargs):
    """
    Placeholder defined when torch is not installed.

    Accepts any arguments and reports, through Logger.critical, that Pytorch must be
    installed to use pytorch_quantization_aware_training_init.
    """
    missing_torch_msg = ('Installing Pytorch is mandatory '
                         'when using pytorch_quantization_aware_training_init. '
                         'Could not find the torch package.')
    Logger.critical(missing_torch_msg)  # pragma: no cover
|
|
223
|
+
|
|
224
|
+
def pytorch_quantization_aware_training_finalize(*args, **kwargs):
    """
    Placeholder defined when torch is not installed.

    Accepts any arguments and reports, through Logger.critical, that Pytorch must be
    installed to use pytorch_quantization_aware_training_finalize.
    """
    missing_torch_msg = ('Installing Pytorch is mandatory '
                         'when using pytorch_quantization_aware_training_finalize. '
                         'Could not find the torch package.')
    Logger.critical(missing_torch_msg)  # pragma: no cover
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import model_compression_toolkit.qat.pytorch.quantizer.ste_rounding.symmetric_ste
|
|
17
|
+
import model_compression_toolkit.qat.pytorch.quantizer.ste_rounding.uniform_ste
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import Union
|
|
16
|
+
|
|
17
|
+
from model_compression_toolkit.core.common.logger import Logger
|
|
18
|
+
from model_compression_toolkit.core.common.constants import FOUND_TORCH
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
|
|
21
|
+
TrainableQuantizerActivationConfig
|
|
22
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
|
|
23
|
+
BasePytorchTrainableQuantizer
|
|
24
|
+
|
|
25
|
+
if FOUND_TORCH:

    class BasePytorchQATTrainableQuantizer(BasePytorchTrainableQuantizer):
        """
        A base class for trainable Pytorch quantizers for QAT.
        """

        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            """
            Initializes BasePytorchQATTrainableQuantizer object.

            Args:
                quantization_config: quantizer config class contains all the information about a quantizer configuration.
            """
            super().__init__(quantization_config)

else:
    class BasePytorchQATTrainableQuantizer(BasePytorchTrainableQuantizer):
        """
        Placeholder defined when torch is not found (FOUND_TORCH is False).

        Instantiating it reports, through Logger.critical, that Pytorch must be installed.
        """

        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            """
            Args:
                quantization_config: quantizer config class contains all the information about a quantizer configuration.
            """
            super().__init__(quantization_config)
            Logger.critical('Installing Pytorch is mandatory '
                            'when using BasePytorchQATTrainableQuantizer. '
                            'Could not find torch package.')  # pragma: no cover
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import List, Dict, Tuple
|
|
16
|
+
|
|
17
|
+
from model_compression_toolkit.core import common
|
|
18
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
19
|
+
from model_compression_toolkit.qat.common.qat_config import QATConfig
|
|
20
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
|
|
21
|
+
get_trainable_quantizer_quantization_candidates, get_trainable_quantizer_weights_config, \
|
|
22
|
+
get_trainable_quantizer_activation_config
|
|
23
|
+
from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
|
|
24
|
+
from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
|
|
25
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
|
|
26
|
+
get_trainable_quantizer_class
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def quantization_builder(n: common.BaseNode,
                         qat_config: QATConfig,
                         fw_info: FrameworkInfo,
                         ) -> Tuple[Dict[str, BasePytorchQATTrainableQuantizer],
                                    List[BasePytorchQATTrainableQuantizer]]:
    """
    Build the trainable QAT quantizers of a node according to its quantization configuration.

    When the node has more than one candidate quantization configuration, the weights and
    activation candidates are extracted and forwarded to the quantizer-config builders;
    otherwise None is forwarded.

    Args:
        n: Node to build quantizers for.
        qat_config (QATConfig): QAT configuration. Its training methods select the quantizer
            classes, and its weight/activation ``*_quantizer_params_override`` dicts are passed
            as keyword arguments to the quantizer constructors.
        fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).

    Returns:
        weights_quantizers: A dictionary mapping each weight attribute name to its quantizer
            (empty if weights quantization is disabled for the node).
        activation_quantizers: A list holding the node's activation quantizer
            (empty if activation quantization is disabled for the node).
    """
    if len(n.candidates_quantization_cfg) > 1:
        wq_cand, aq_cand = get_trainable_quantizer_quantization_candidates(n)
    else:
        wq_cand, aq_cand = None, None

    weight_quantizers = {}
    if n.is_weights_quantization_enabled():
        quant_method = n.final_weights_quantization_cfg.weights_quantization_method
        quantizer_class = get_trainable_quantizer_class(QuantizationTarget.Weights,
                                                        qat_config.weight_training_method,
                                                        quant_method,
                                                        BasePytorchQATTrainableQuantizer)
        for attr in fw_info.get_kernel_op_attributes(n.type):
            # Each attribute gets its own quantizer, built from its own config object.
            weight_quantizers[attr] = quantizer_class(get_trainable_quantizer_weights_config(n, wq_cand),
                                                      **qat_config.weight_quantizer_params_override)

    activation_quantizers = []
    if n.is_activation_quantization_enabled():
        quant_method = n.final_activation_quantization_cfg.activation_quantization_method
        quantizer_class = get_trainable_quantizer_class(QuantizationTarget.Activation,
                                                        qat_config.activation_training_method,
                                                        quant_method,
                                                        BasePytorchQATTrainableQuantizer)

        activation_quantizers = [quantizer_class(get_trainable_quantizer_activation_config(n, aq_cand),
                                                 **qat_config.activation_quantizer_params_override)]

    return weight_quantizers, activation_quantizers
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import Tuple
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def ste_round(x: torch.Tensor) -> torch.Tensor:
    """
    Round a tensor using a straight-through estimator.

    The forward value is torch.round(x). The rounding residual is detached, so in the
    backward pass the gradient with respect to x flows through as if the op were the identity.

    Args:
        x: input tensor.

    Returns:
        Tensor with the rounded values of x.
    """
    rounding_residual = torch.round(x) - x
    return x + rounding_residual.detach()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def ste_clip(x: torch.Tensor, min_val=-1.0, max_val=1.0) -> torch.Tensor:
    """
    Clip a tensor to [min_val, max_val] using a straight-through estimator.

    The forward value satisfies min_val <= output <= max_val. The clipping residual is
    detached, so the gradient with respect to x passes through unchanged, including for
    clipped elements.

    Args:
        x: input tensor.
        min_val: lower clipping bound.
        max_val: upper clipping bound.

    Returns:
        Clipped tensor.
    """
    clipped = torch.clip(x, min=min_val, max=max_val)
    return x + (clipped - x).detach()
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def fix_range_to_include_zero(range_min: torch.Tensor,
                              range_max: torch.Tensor,
                              n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Adjust the quantization range so that 0.0 is representable on the quantization grid.

    The adjustment is element-wise, so for per-channel quantization range_min and range_max
    should be shaped to broadcast along the channel axis. Three cases are handled:
      * range_min > 0: the range becomes [0, range_max].
      * range_max < 0: the range becomes [range_min, 0].
      * otherwise: range_min is snapped to the nearest multiple of the grid step
        (range_max - range_min) / (2 ** n_bits - 1), and range_max is shifted by the same
        amount, keeping the range width unchanged.

    Args:
        range_min: min bound of the quantization range (before adjustment).
        range_max: max bound of the quantization range (before adjustment).
        n_bits: Number of bits to quantize the tensor.

    Returns:
        A tuple (adjusted range_min, adjusted range_max).
    """
    min_positive = range_min > 0
    max_negative = range_max < 0
    mid_range = torch.logical_not(torch.logical_or(min_positive, max_negative))

    scale = (range_max - range_min) / (2 ** n_bits - 1)
    mid_min = scale * torch.round(range_min / scale)
    mid_max = range_max - range_min + mid_min

    # Select per element with torch.where rather than multiplying by float masks. A non-finite
    # value in an unselected branch (e.g. 0 * inf for a zero-width positive range) then no longer
    # turns the result into NaN, and the outputs keep the inputs' dtype instead of being promoted
    # by float32 masks.
    zeros = torch.zeros_like(mid_min)
    min_range_adj = torch.where(mid_range, mid_min, torch.where(max_negative, range_min, zeros))
    max_range_adj = torch.where(mid_range, mid_max, torch.where(min_positive, range_max, zeros))
    return min_range_adj, max_range_adj
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def symmetric_quantizer(tensor_data: torch.Tensor,
                        threshold: torch.Tensor,
                        n_bits: int,
                        sign: bool = False) -> torch.Tensor:
    """
    Symmetric fake-quantization of a tensor, given a threshold and a number of bits.

    Values are mapped to an integer grid with step threshold / 2 ** (n_bits - sign),
    rounded and clipped with straight-through estimators, and mapped back to the
    original scale.

    Args:
        tensor_data: Tensor values to quantize.
        threshold: threshold for quantization.
        n_bits: Number of bits to quantize the tensor.
        sign: whether the quantization grid is signed (one bit is then used for the sign).

    Returns:
        Quantized data.
    """
    signed_bits = int(sign)
    levels = 2 ** (n_bits - signed_bits)
    step = threshold / levels

    # Integer grid bounds: [-levels, levels - 1] when signed, [0, levels - 1] otherwise.
    lower, upper = -signed_bits * levels, levels - 1

    grid_points = ste_clip(ste_round(tensor_data / step), min_val=lower, max_val=upper)
    return step * grid_points
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def uniform_quantizer(tensor_data: torch.Tensor,
                      range_min: torch.Tensor,
                      range_max: torch.Tensor,
                      n_bits: int) -> torch.Tensor:
    """
    Uniform fake-quantization of a tensor over a (min, max) range with a given number of bits.

    The range is first adjusted by fix_range_to_include_zero. Values are then mapped to the
    integer grid [0, 2 ** n_bits - 1], rounded and clipped with straight-through estimators,
    and mapped back into the adjusted range.

    Args:
        tensor_data: Tensor values to quantize.
        range_min: minimum bound of the range for quantization (or array of min values per channel).
        range_max: maximum bound of the range for quantization (or array of max values per channel).
        n_bits: Number of bits to quantize the tensor.

    Returns:
        Quantized data.
    """
    num_steps = 2 ** n_bits - 1
    low, high = fix_range_to_include_zero(range_min, range_max, n_bits)
    step = (high - low) / num_steps

    grid_index = ste_round((tensor_data - low) / step)
    grid_index = ste_clip(grid_index, min_val=0, max_val=num_steps)
    return step * grid_index + low
|