mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
from model_compression_toolkit.core.common.constants import FOUND_TORCH
|
|
19
|
+
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
20
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer \
|
|
21
|
+
import mark_quantizer, \
|
|
22
|
+
QuantizationTarget
|
|
23
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
|
|
24
|
+
import MULTIPLIER_N_BITS, EPS
|
|
25
|
+
|
|
26
|
+
if FOUND_TORCH:
|
|
27
|
+
import torch
|
|
28
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizer_utils \
|
|
29
|
+
import to_torch_tensor, get_working_device, lut_quantizer
|
|
30
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers \
|
|
31
|
+
.base_lut_symmetric_inferable_quantizer import BaseLUTSymmetricInferableQuantizer
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@mark_quantizer(quantization_target=QuantizationTarget.Weights,
|
|
35
|
+
quantization_method=[QuantizationMethod.LUT_SYM_QUANTIZER],
|
|
36
|
+
quantizer_type=None)
|
|
37
|
+
class WeightsLUTSymmetricInferableQuantizer(BaseLUTSymmetricInferableQuantizer):
|
|
38
|
+
"""
|
|
39
|
+
Class for quantizing weights using a lut symmetric quantizer
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
def __init__(self,
|
|
43
|
+
num_bits: int,
|
|
44
|
+
cluster_centers: np.ndarray,
|
|
45
|
+
threshold: np.ndarray,
|
|
46
|
+
per_channel: bool,
|
|
47
|
+
channel_axis: int = None,
|
|
48
|
+
multiplier_n_bits: int = MULTIPLIER_N_BITS,
|
|
49
|
+
eps: float = EPS):
|
|
50
|
+
"""
|
|
51
|
+
Initialize the quantizer with the specified parameters.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
num_bits: number of bits to use for quantization
|
|
55
|
+
cluster_centers: the cluster centers to assign the weights
|
|
56
|
+
threshold: threshold for quantizing weights
|
|
57
|
+
per_channel: whether to use per-channel quantization
|
|
58
|
+
channel_axis: Axis of input to apply per-channel quantization on
|
|
59
|
+
multiplier_n_bits: Number of bits that determines the quantization range
|
|
60
|
+
eps: Small value for numerical stability in division
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
super(WeightsLUTSymmetricInferableQuantizer, self).__init__(threshold=threshold,
|
|
64
|
+
num_bits=num_bits,
|
|
65
|
+
cluster_centers=cluster_centers,
|
|
66
|
+
signed=True,
|
|
67
|
+
multiplier_n_bits=multiplier_n_bits,
|
|
68
|
+
eps=eps)
|
|
69
|
+
|
|
70
|
+
if per_channel:
|
|
71
|
+
assert channel_axis is not None, f'Channel axis is missing in per channel quantization'
|
|
72
|
+
assert len(
|
|
73
|
+
threshold) >= 1, f'In per-channel quantization threshold should be of length >= 1 but is ' \
|
|
74
|
+
f'{len(threshold)}'
|
|
75
|
+
else:
|
|
76
|
+
assert len(
|
|
77
|
+
threshold) == 1, f'In per-tensor quantization threshold should be of length 1 but is ' \
|
|
78
|
+
f'{len(threshold)}'
|
|
79
|
+
|
|
80
|
+
self.per_channel = per_channel
|
|
81
|
+
self.channel_axis = channel_axis
|
|
82
|
+
|
|
83
|
+
self.threshold = to_torch_tensor(self.threshold).to(get_working_device())
|
|
84
|
+
self.cluster_centers = to_torch_tensor(self.cluster_centers).to(get_working_device())
|
|
85
|
+
|
|
86
|
+
def __call__(self, inputs: torch.Tensor) -> torch.Tensor:
|
|
87
|
+
"""
|
|
88
|
+
Quantize the given inputs using the quantizer parameters.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
inputs: input tensor to quantize
|
|
92
|
+
|
|
93
|
+
Returns:
|
|
94
|
+
quantized tensor.
|
|
95
|
+
"""
|
|
96
|
+
inputs.requires_grad = False
|
|
97
|
+
return lut_quantizer(inputs, cluster_centers=self.cluster_centers, signed=True,
|
|
98
|
+
threshold=self.threshold, multiplier_n_bits=self.multiplier_n_bits, eps=self.eps)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
else:
|
|
102
|
+
class WeightsLUTSymmetricInferableQuantizer: # pragma: no cover
|
|
103
|
+
def __init__(self, *args, **kwargs):
|
|
104
|
+
raise Exception('Installing torch is mandatory '
|
|
105
|
+
'when using WeightsLUTSymmetricInferableQuantizer. '
|
|
106
|
+
'Could not find torch package.')
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
from model_compression_toolkit.core.common.constants import FOUND_TORCH
|
|
19
|
+
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
20
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
|
|
21
|
+
QuantizationTarget
|
|
22
|
+
|
|
23
|
+
if FOUND_TORCH:
|
|
24
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.weights_inferable_quantizers.weights_symmetric_inferable_quantizer import \
|
|
25
|
+
WeightsSymmetricInferableQuantizer
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@mark_quantizer(quantization_target=QuantizationTarget.Weights,
|
|
29
|
+
quantization_method=[QuantizationMethod.POWER_OF_TWO],
|
|
30
|
+
quantizer_type=None)
|
|
31
|
+
class WeightsPOTInferableQuantizer(WeightsSymmetricInferableQuantizer):
|
|
32
|
+
"""
|
|
33
|
+
Class for quantizing weights using power-of-two quantizer
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
def __init__(self,
|
|
37
|
+
num_bits: int,
|
|
38
|
+
threshold: np.ndarray,
|
|
39
|
+
per_channel: bool,
|
|
40
|
+
channel_axis: int = None
|
|
41
|
+
):
|
|
42
|
+
"""
|
|
43
|
+
Initialize the quantizer with the specified parameters.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
num_bits: number of bits to use for quantization
|
|
47
|
+
threshold: threshold for quantizing activations
|
|
48
|
+
per_channel: whether to use per-channel quantization
|
|
49
|
+
channel_axis: Axis of input to apply per-channel quantization on.
|
|
50
|
+
"""
|
|
51
|
+
# target of Weights quantization
|
|
52
|
+
super(WeightsPOTInferableQuantizer, self).__init__(num_bits=num_bits,
|
|
53
|
+
threshold=threshold,
|
|
54
|
+
per_channel=per_channel,
|
|
55
|
+
channel_axis=channel_axis)
|
|
56
|
+
|
|
57
|
+
is_threshold_pot = np.all(np.round(np.log2(threshold.flatten()))==np.log2(threshold.flatten()))
|
|
58
|
+
assert is_threshold_pot, f'Expected threshold to be power of 2 but is {threshold}'
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
else:
|
|
62
|
+
class WeightsPOTInferableQuantizer:  # pragma: no cover
    """Placeholder that fails loudly when the torch package is unavailable."""

    def __init__(self, *args, **kwargs):
        # Any attempt to construct the quantizer without torch installed
        # raises immediately with an explanatory message.
        _msg = ('Installing torch is mandatory '
                'when using WeightsPOTInferableQuantizer. '
                'Could not find torch package.')
        raise Exception(_msg)
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
from model_compression_toolkit.core.common.constants import FOUND_TORCH
|
|
19
|
+
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
20
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer, \
|
|
21
|
+
QuantizationTarget
|
|
22
|
+
|
|
23
|
+
if FOUND_TORCH:
|
|
24
|
+
import torch
|
|
25
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizer_utils import to_torch_tensor, \
|
|
26
|
+
get_working_device
|
|
27
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.base_symmetric_inferable_quantizer import \
|
|
28
|
+
BaseSymmetricInferableQuantizer
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.SYMMETRIC],
                quantizer_type=None)
class WeightsSymmetricInferableQuantizer(BaseSymmetricInferableQuantizer):
    """
    Symmetric fake-quantizer for weights (zero-point fixed at 0).
    """

    def __init__(self,
                 num_bits: int,
                 threshold: np.ndarray,
                 per_channel: bool,
                 channel_axis: int = None
                 ):
        """
        Build a symmetric weights quantizer.

        Args:
            num_bits: number of bits to use for quantization
            threshold: threshold for quantizing weights
            per_channel: whether to use per-channel quantization
            channel_axis: Axis of input to apply per-channel quantization on.
        """
        # Weights are always quantized as signed values in the symmetric scheme.
        super(WeightsSymmetricInferableQuantizer, self).__init__(threshold=threshold,
                                                                 num_bits=num_bits,
                                                                 signed=True)

        # Sanity-check the threshold length against the chosen granularity.
        if per_channel:
            assert channel_axis is not None, f'Channel axis is missing in per channel quantization'
            assert len(threshold) >= 1, f'In per-channel quantization threshold should be of length >= 1 but is {len(threshold)}'
        else:
            assert len(threshold) == 1, f'In per-tensor quantization threshold should be of length 1 but is {len(threshold)}'

        self.per_channel = per_channel
        self.channel_axis = channel_axis

        # Place the quantization parameters on the working device once, up front.
        device = get_working_device()
        self.scales = to_torch_tensor(self.scales).to(device)
        # Symmetric quantization has no offset: all zero-points are 0.
        self.zero_points = torch.zeros(len(threshold), dtype=torch.int32).to(device)

    def __call__(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Quantize the given inputs using the quantizer parameters.

        Args:
            inputs: input tensor to quantize

        Returns:
            quantized tensor.
        """
        # Quantized weights are constants at inference time; no gradient tracking.
        inputs.requires_grad = False
        if self.per_channel:
            quantized = torch.fake_quantize_per_channel_affine(inputs,
                                                               self.scales,
                                                               self.zero_points,
                                                               axis=self.channel_axis,
                                                               quant_min=self.min_quantized_domain,
                                                               quant_max=self.max_quantized_domain)
        else:
            quantized = torch.fake_quantize_per_tensor_affine(inputs,
                                                              self.scales,
                                                              self.zero_points,
                                                              quant_min=self.min_quantized_domain,
                                                              quant_max=self.max_quantized_domain)
        return quantized
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
else:
|
|
100
|
+
class WeightsSymmetricInferableQuantizer:  # pragma: no cover
    """Placeholder that fails loudly when the torch package is unavailable."""

    def __init__(self, *args, **kwargs):
        # Construction always fails: this quantizer requires torch at runtime.
        _msg = ('Installing torch is mandatory '
                'when using WeightsSymmetricInferableQuantizer. '
                'Could not find torch package.')
        raise Exception(_msg)
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
from model_compression_toolkit.core.common.constants import FOUND_TORCH
|
|
19
|
+
from model_compression_toolkit.core.common.logger import Logger
|
|
20
|
+
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
21
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget, \
|
|
22
|
+
mark_quantizer
|
|
23
|
+
|
|
24
|
+
if FOUND_TORCH:
|
|
25
|
+
import torch
|
|
26
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizer_utils import get_working_device, \
|
|
27
|
+
fix_range_to_include_zero, to_torch_tensor
|
|
28
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers.base_uniform_inferable_quantizer import \
|
|
29
|
+
BaseUniformInferableQuantizer
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@mark_quantizer(quantization_target=QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.UNIFORM],
                quantizer_type=None)
class WeightsUniformInferableQuantizer(BaseUniformInferableQuantizer):
    """
    Class for quantizing weights using a uniform (affine) quantizer.

    Scale and zero-point are derived once at construction time from the
    given [min_range, max_range] interval (adjusted to include zero).
    """

    def __init__(self,
                 num_bits: int,
                 min_range: np.ndarray,
                 max_range: np.ndarray,
                 per_channel: bool,
                 channel_axis: int = None
                 ):
        """
        Initialize the quantizer with the specified parameters.

        Args:
            num_bits: number of bits to use for quantization
            min_range: min quantization range for quantizing weights
            max_range: max quantization range for quantizing weights
            per_channel: whether to use per-channel quantization
            channel_axis: Axis of input to apply per-channel quantization on.
        """
        # NOTE(review): the base class receives the raw (pre-zero-fix) ranges,
        # while scale/zero-point below use the fixed ranges — confirm this is intended.
        super(WeightsUniformInferableQuantizer, self).__init__(num_bits=num_bits,
                                                               min_range=min_range,
                                                               max_range=max_range)

        # Align min/max numpy arrays so they are torch Tensors on the working device
        min_range = to_torch_tensor(min_range).to(get_working_device())
        max_range = to_torch_tensor(max_range).to(get_working_device())

        self.per_channel = per_channel
        self.channel_axis = channel_axis

        # Shift the range so that zero is exactly representable by the grid.
        min_range, max_range = fix_range_to_include_zero(min_range,
                                                         max_range,
                                                         num_bits)
        # Compute the step size of quantized values.
        self.scales = (max_range - min_range) / (2 ** num_bits - 1)
        self.zero_points = -(
                min_range / self.scales).int()  # zp has to be positive, and a <=0, so we multiply by -1

        self.scales = self.scales.to(get_working_device())
        self.zero_points = self.zero_points.to(get_working_device())

    def __call__(self,
                 inputs: torch.Tensor) -> torch.Tensor:
        """
        Weight fake quantizer
        Args:
            inputs: weights to quantize.

        Returns:
            quantized weights
        """
        # Weights are constants at inference time; disable gradient tracking.
        inputs.requires_grad = False
        if self.per_channel:
            # Per-channel variant requires 1-D scale/zero-point tensors, hence flatten().
            return torch.fake_quantize_per_channel_affine(inputs,
                                                          self.scales.flatten(),
                                                          self.zero_points.flatten(),
                                                          axis=self.channel_axis,
                                                          quant_min=self.min_quantized_domain,
                                                          quant_max=self.max_quantized_domain)
        # NOTE(review): per-tensor path passes single-element tensors as scale and
        # zero-point; relies on the torch overload accepting tensor arguments — confirm
        # against the supported torch versions.
        return torch.fake_quantize_per_tensor_affine(inputs,
                                                     self.scales,
                                                     self.zero_points,
                                                     quant_min=self.min_quantized_domain,
                                                     quant_max=self.max_quantized_domain)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
else:
|
|
105
|
+
class WeightsUniformInferableQuantizer:  # pragma: no cover
    """
    Placeholder used when torch is not installed.

    The original implementation called Logger.error here, unlike the sibling
    fallbacks (WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer)
    which raise directly. Raising guarantees construction fails regardless of
    logger configuration and keeps all three fallbacks consistent.
    """

    def __init__(self, *args, **kwargs):
        raise Exception('Installing torch is mandatory '
                        'when using WeightsUniformInferableQuantizer. '
                        'Could not find torch package.')
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from abc import abstractmethod
|
|
16
|
+
from enum import Enum
|
|
17
|
+
from typing import Union, List, Any
|
|
18
|
+
from inspect import signature
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.core import common
|
|
21
|
+
from model_compression_toolkit.core.common import Logger
|
|
22
|
+
|
|
23
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer, \
|
|
24
|
+
QuantizationTarget
|
|
25
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
|
|
26
|
+
TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig
|
|
27
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import QUANTIZATION_METHOD, \
|
|
28
|
+
QUANTIZATION_TARGET
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Dictionary keys for entries in BaseTrainableQuantizer.quantizer_parameters:
# each registered variable is stored as {VAR: <variable>, GROUP: <VariableGroup>}.
VAR = 'var'
GROUP = 'group'
|
|
33
|
+
|
|
34
|
+
class VariableGroup(Enum):
    """
    An enum for choosing a trainable-variable group:

    0. WEIGHTS
    1. QPARAMS
    """
    WEIGHTS = 0
    QPARAMS = 1
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class BaseTrainableQuantizer(BaseInferableQuantizer):
    """
    Base class for trainable quantizers.

    Validates the provided quantization config against the static properties set
    by the @mark_quantizer decorator, and declares the abstract operations every
    concrete trainable quantizer must implement.
    """

    def __init__(self,
                 quantization_config: Union[TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig]):
        """
        This class is a base quantizer which validates the provided quantization config
        and defines an abstract function which any quantizer needs to implement.

        Args:
            quantization_config: quantizer config class contains all the information about the quantizer configuration.
        """

        # Verify the quantizer class that inherits this class only has a config
        # argument followed by key-word arguments (with defaults), so it can be
        # re-instantiated generically from its config.
        for i, (k, v) in enumerate(self.get_sig().parameters.items()):
            if i == 0:
                if v.annotation not in [TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]:
                    common.Logger.error("First parameter must be either TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig")  # pragma: no cover
            elif v.default is v.empty:
                common.Logger.error(f"Parameter {k} doesn't have a default value")  # pragma: no cover

        super(BaseTrainableQuantizer, self).__init__()
        self.quantization_config = quantization_config

        # Inherited class should be decorated with the @mark_quantizer decorator,
        # which defines the following static properties.
        static_quantization_method = getattr(self, QUANTIZATION_METHOD, None)
        static_quantization_target = getattr(self, QUANTIZATION_TARGET, None)

        if static_quantization_method is None or static_quantization_target is None:
            # Fix: original message was missing a space between the two sentences.
            Logger.error("A quantizer class that inherit from BaseTrainableQuantizer is not defined appropriately. "
                         "Either it misses the @mark_quantizer decorator or the decorator is not used correctly.")

        # The config type must match the declared quantization target, and the
        # config's quantization method must be one this quantizer supports.
        if static_quantization_target == QuantizationTarget.Weights:
            self.validate_weights()
            if self.quantization_config.weights_quantization_method not in static_quantization_method:
                common.Logger.error(
                    f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.weights_quantization_method}')
        elif static_quantization_target == QuantizationTarget.Activation:
            self.validate_activation()
            if self.quantization_config.activation_quantization_method not in static_quantization_method:
                common.Logger.error(
                    f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.activation_quantization_method}')
        else:
            common.Logger.error(
                f'Unknown Quantization Part:{static_quantization_target}')  # pragma: no cover

        # Registry of trainable variables: name -> {VAR: variable, GROUP: VariableGroup}.
        self.quantizer_parameters = {}

    @classmethod
    def get_sig(cls):
        """Return the inspect.Signature of the (concrete) quantizer's constructor."""
        return signature(cls)

    def initialize_quantization(self,
                                tensor_shape,
                                name: str,
                                layer):
        """
        This initializes the quantizer parameters given the parameter name and shape.

        Args:
            tensor_shape: tensor shape
            name: tensor name
            layer: layer to quantized

        Returns: None

        """
        # Fix: `raise NotImplemented` raises a TypeError in Python 3
        # (NotImplemented is not an exception); NotImplementedError is the
        # correct exception for an unimplemented method.
        raise NotImplementedError  # pragma: no cover

    def __call__(self,
                 input2quantize,
                 training: bool):
        """
        Quantize a tensor.

        Args:
            input2quantize: Input tensor to quantize.
            training: Whether the graph is in training mode.

        Returns:
            The quantized tensor.
        """
        raise NotImplementedError  # pragma: no cover

    def activation_quantization(self) -> bool:
        """

        Returns: A boolean stating is this activation quantizer

        """
        return isinstance(self.quantization_config, TrainableQuantizerActivationConfig)

    def weights_quantization(self) -> bool:
        """

        Returns: A boolean stating is this weights quantizer

        """
        return isinstance(self.quantization_config, TrainableQuantizerWeightsConfig)

    def validate_weights(self) -> None:
        """
        This function validates that the quantization config is a weights config.
        """
        if self.activation_quantization() or not self.weights_quantization():
            common.Logger.error('Expect weight quantization got activation')

    def validate_activation(self) -> None:
        """
        This function validates that the quantization config is an activation config.
        """
        if not self.activation_quantization() or self.weights_quantization():
            common.Logger.error('Expect activation quantization got weight')

    def convert2inferable(self) -> BaseInferableQuantizer:
        """
        Convert quantizer to inferable quantizer.

        Returns:
            BaseInferableQuantizer object.
        """
        raise NotImplementedError  # pragma: no cover

    def add_quantizer_variable(self, name: str, variable: Any, group: VariableGroup = VariableGroup.WEIGHTS):
        """
        Add a quantizer variable to the quantizer_parameters dictionary.
        """
        self.quantizer_parameters.update({name: {VAR: variable, GROUP: group}})

    def get_quantizer_variable(self, name: str) -> Any:
        """
        Get a quantizer variable by name.

        Args:
            name: variable name

        Returns:
            trainable variable
        """
        if name in self.quantizer_parameters:
            return self.quantizer_parameters[name][VAR]
        else:
            # Fix: original message had a grammar error ("is not exist in quantizers parameters").
            common.Logger.error(f'Variable {name} does not exist in quantizer parameters!')  # pragma: no cover

    @abstractmethod
    def get_trainable_variables(self, group: VariableGroup) -> List[Any]:
        """
        Get trainable parameters with specific group from quantizer

        Args:
            group: Enum of variable group

        Returns:
            List of trainable variables
        """
        raise NotImplementedError  # pragma: no cover
|