mct-nightly 1.7.1.31122022.post351-py3-none-any.whl → 1.8.0.1042023.post423-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py CHANGED
@@ -14,35 +14,69 @@
 # ==============================================================================
 from typing import Any
 
-from keras.engine.input_layer import InputLayer
-from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapperV2
 
 from model_compression_toolkit.core.common import Logger
-from model_compression_toolkit.
-    SUPPORTED_QUANTIZATION_CONFIG
-from model_compression_toolkit.exporter.model_wrapper.keras.extended_quantize_wrapper import ExtendedQuantizeWrapper
+from model_compression_toolkit.core.common.constants import FOUND_TF
 
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
 
-def is_keras_layer_exportable(layer: Any) -> bool:
-    """
-    Check whether a Keras layer is a valid exportable layer or not.
-    """
-    # Keras Input layers are not wrapped
-    if isinstance(layer, InputLayer):
-        return True
+if FOUND_TF:
+    from keras.engine.input_layer import InputLayer
+    from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
+
+    def is_keras_layer_exportable(layer: Any) -> bool:
+        """
+        Check whether a Keras layer is a valid exportable layer or not.
+
+        Args:
+            layer: Keras layer to check if considered to be valid for exporting.
+
+        Returns:
+            Check whether a Keras layer is a valid exportable layer or not.
+        """
+        # Keras Input layers are not wrapped
+        if isinstance(layer, InputLayer):
+            return True
+
+        valid_layer = isinstance(layer, KerasQuantizationWrapper)
+        if not valid_layer:
+            Logger.error(
+                f'Exportable layer must be wrapped using KerasQuantizationWrapper, but layer {layer.name} is of type '
+                f'{type(layer)}')  # pragma: no cover
+
+        valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
+        if not valid_weights_quantizers:
+            Logger.error(
+                f'KerasQuantizationWrapper must have a weights_quantizers but has a '
+                f'{type(layer.weights_quantizers)} object')  # pragma: no cover
+
+        for _, weights_quantizer in layer.weights_quantizers.items():
+            if not isinstance(weights_quantizer, BaseInferableQuantizer):
+                Logger.error(
+                    f'weights_quantizer must be a BaseInferableQuantizer object but has a '
+                    f'{type(weights_quantizer)} object')  # pragma: no cover
+
+        valid_activation_quantizers = isinstance(layer.activation_quantizers, list)
+        if not valid_activation_quantizers:
+            Logger.error(
+                f'KerasQuantizationWrapper must have a activation_quantizers list but has a '
+                f'{type(layer.activation_quantizers)} object')  # pragma: no cover
+
+        for activation_quantizers in layer.activation_quantizers:
+            if not isinstance(activation_quantizers, BaseInferableQuantizer):
+                Logger.error(
+                    f'activation_quantizers must be a BaseInferableQuantizer object but has a '
+                    f'{type(activation_quantizers)} object')  # pragma: no cover
+
+        quantizers = layer.activation_quantizers + list(layer.weights_quantizers.values())
+        is_valid_quantizers = all([isinstance(x, BaseInferableQuantizer) for x in quantizers])
+        if not is_valid_quantizers:
+            Logger.error(f'Found a quantizer that is not of type BaseInferableQuantizer')  # pragma: no cover
+
+        return True
+else:
+    def is_keras_layer_exportable(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                     'when using is_keras_layer_exportable. '
+                     'Could not find Tensorflow package.')
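The new validate_layer.py turns export validation into a per-layer predicate: it returns True for InputLayers and correctly wrapped layers, and reports each violation through Logger.error. A minimal usage sketch, assuming `model` is an exportable Keras model produced by MCT (the helper function below is an illustration, not part of this diff):

    # Hypothetical sketch: validate all layers of an MCT exportable Keras model.
    from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import \
        is_keras_layer_exportable

    def assert_model_exportable(model) -> bool:
        # Each call returns True for InputLayers and properly wrapped layers,
        # and reports via Logger.error otherwise.
        return all(is_keras_layer_exportable(layer) for layer in model.layers)
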
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py CHANGED
@@ -13,132 +13,49 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import List, Any, Tuple
-
-import torch
-
+from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common import
-from model_compression_toolkit.core.common.
-from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
-from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
-    PytorchModel
-from model_compression_toolkit.core.pytorch.back2framework.quantization_wrapper.quantized_layer_wrapper import \
-    QuantizedLayerWrapper
-from model_compression_toolkit.core.pytorch.constants import BUFFER, CONSTANT
-from model_compression_toolkit.core.pytorch.reader.node_holders import BufferHolder, ConstantHolder
-from model_compression_toolkit.core.pytorch.utils import get_working_device
-
-from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantize_config import get_quantization_config
-
-
-def get_fully_quantized_pytorch_model(graph: Graph):
-    """
-    Convert graph to fully quantized PyTorch model.
-
-    Fully quantized PyTorch model.
-    """
-    return FullyQuantizedPyTorchModelBuilder(graph=graph).build_model()
-
-
-class FullyQuantizedPyTorchModel(PytorchModel):
-    """
-    PyTorch model with all quantization information.
-    """
-
-    def __init__(self,
-                 graph: common.Graph):
-        """
-
-        Args:
-
-        """
-
-    def _add_modules(self):
-        """
-        Add nodes in graph as modules.
-
-        """
-        for n in self.node_sort:
-            if n.type == BufferHolder:
-                self.add_module(n.name, node_builder(n))
-                self.get_submodule(n.name).register_buffer(n.name,
-                                                           torch.Tensor(n.get_weights_by_keys(BUFFER)).to(get_working_device()))
-            elif n.type == ConstantHolder:
-                self.add_module(n.name, node_builder(n))
-                self.get_submodule(n.name).register_buffer(n.name,
-                                                           torch.Tensor(n.get_weights_by_keys(CONSTANT)).to(get_working_device()))
-
-            else:
-                # Create a wrapper based on the corresponding quantization config.
-                layer_wrapper = QuantizedLayerWrapper(n, get_quantization_config(n))
-                # Add the wrapped layer to the model.
-                self.add_module(n.name, layer_wrapper)
-
-    def _get_op_func(self,
-                     node: BaseNode,
-                     configurable_nodes_names: List[str]) -> Any:
-        """
-        Get the operator corresponding to the passed node.
-
-        Args:
-            node: Node to get its op.
-            configurable_nodes_names: List of nodes that are configurable.
-
-        Returns:
-            Operator (module) of the node.
-        """
-        return getattr(self, node.name)
-
-    def
-        node: BaseNode,
-        input_tensors: List[torch.Tensor]) -> List[torch.Tensor]:
-        """
-
-        Args:
-            input_tensors: Input tensors of the node.
-
-        Returns:
-
-
-class FullyQuantizedPyTorchModelBuilder(PyTorchModelBuilder):
-    """
-    Fully-Quantized PyTorch model.
-    """
-
-    def __init__(self,
-                 graph: common.Graph):
-        """
-
-        Args:
-            graph: Graph to build the model from.
-        """
-
-        super().__init__(graph)
-
-    def build_model(self) -> Tuple[PytorchModel, UserInformation]:
-        """
-        Build a PyTorch fully quantized model and return it.
-        Returns: Fully quantized PyTorch model and user information.
-
-        """
-        return FullyQuantizedPyTorchModel(self.graph), self.graph.user_info
+from model_compression_toolkit.core.common import Graph, Logger
+from model_compression_toolkit.core.common.constants import FOUND_TORCH
+
+if FOUND_TORCH:
+    import torch
+    from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
+    from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizers import \
+        get_quantization_quantizers
+
+    def fully_quantized_wrapper(node: common.BaseNode, module: torch.nn.Module) -> qi.PytorchQuantizationWrapper:
+        """
+        A function which takes a computational graph node and a pytorch module and
+        perform the quantization wrapping
+
+        Args:
+            node: A node of mct graph.
+            module: A Pytorch module
+
+        Returns: Wrapped layer
+        """
+        weight_quantizers, activation_quantizers = get_quantization_quantizers(node)
+        wrapped_layer = qi.PytorchQuantizationWrapper(module, weight_quantizers, activation_quantizers)
+        return wrapped_layer
+
+    def get_exportable_pytorch_model(graph: Graph):
+        """
+        Convert graph to fully quantized PyTorch model.
+
+        Args:
+            graph: Graph to convert to a PyTorch model.
+
+        Returns:
+            Fully quantized PyTorch model.
+        """
+        return PyTorchModelBuilder(graph=graph,
+                                   wrapper=fully_quantized_wrapper).build_model()
+else:
+    def get_exportable_pytorch_model(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing torch is mandatory '
+                     'when using get_exportable_pytorch_model. '
+                     'Could not find PyTorch package.')
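The key change above is architectural: instead of a dedicated FullyQuantizedPyTorchModel subclass, the generic PyTorchModelBuilder now receives a wrapper callback that decides how each node's module is wrapped. A sketch of the same callback pattern in isolation (everything here other than the names taken from the diff is an assumption):

    # Illustrative sketch of the wrapper-callback pattern: the builder walks the
    # graph and calls wrapper(node, module) for every instantiated module,
    # letting the exporter attach quantizers without subclassing the model.
    def identity_wrapper(node, module):
        # A no-op wrapper; fully_quantized_wrapper above instead returns
        # qi.PytorchQuantizationWrapper(module, weight_qs, activation_qs).
        return module

    # Presumed usage, mirroring get_exportable_pytorch_model above:
    # model, user_info = PyTorchModelBuilder(graph=graph, wrapper=identity_wrapper).build_model()
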
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py CHANGED
@@ -12,31 +12,83 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import numpy as np
 
-from typing import
+from typing import Dict, Any
 
 from model_compression_toolkit.core.common import BaseNode, Logger
-from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
-
+from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
+    SCALE_PER_CHANNEL, CLUSTER_CENTERS
 from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
+    get_inferable_quantizer_class
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
+    constants as qi_inferable_quantizers_constants, BasePyTorchInferableQuantizer
+import numpy as np
 
-# Supporting other quantizer types in the future
-from model_compression_toolkit.exporter.model_wrapper.pytorch.quantizers.fq_quantizer import FakeQuantQuantizer
-from model_compression_toolkit.exporter.model_wrapper.pytorch.quantizers.uniform_weights_quantizer import \
-    UniformWeightsQuantizer
-import torch
 
+def get_weights_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
+    # Get the weights quantization configuration for the node
+    node_w_qc = node.final_weights_quantization_cfg
+    quantization_method = node_w_qc.weights_quantization_method
+
+    # Return the appropriate quantization parameters based on the quantization method
+    if quantization_method in [QuantizationMethod.POWER_OF_TWO,
+                               QuantizationMethod.SYMMETRIC]:
+        return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
+                qi_inferable_quantizers_constants.THRESHOLD: node_w_qc.weights_quantization_params[THRESHOLD].flatten(),
+                qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
+
+    elif quantization_method in [QuantizationMethod.UNIFORM]:
+        return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
+                qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                qi_inferable_quantizers_constants.MIN_RANGE: node_w_qc.weights_quantization_params[RANGE_MIN].flatten(),
+                qi_inferable_quantizers_constants.MAX_RANGE: node_w_qc.weights_quantization_params[RANGE_MAX].flatten(),
+                qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
+
+    elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER, QuantizationMethod.LUT_SYM_QUANTIZER]:
+        return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
+                qi_inferable_quantizers_constants.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS].flatten(),
+                qi_inferable_quantizers_constants.THRESHOLD: node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten(),
+                qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
+    # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
+
+    else:
+        Logger.critical(f'Not supported quantization method for weights inferable quantizers.')  # pragma: no cover
+
+
+def get_activation_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
+    # Get the activation quantization configuration for the node
+    node_qc = node.final_activation_quantization_cfg
+    quantization_method = node_qc.activation_quantization_method
+
+    # Return the appropriate quantization parameters based on the quantization method
+    if quantization_method in [QuantizationMethod.POWER_OF_TWO,
+                               QuantizationMethod.SYMMETRIC]:
+        return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.activation_n_bits,
+                qi_inferable_quantizers_constants.THRESHOLD: np.asarray([node_qc.activation_quantization_params[THRESHOLD]]),
+                qi_inferable_quantizers_constants.SIGNED: node_qc.activation_quantization_params.get(SIGNED)}
+
+    elif quantization_method in [QuantizationMethod.UNIFORM]:
+        return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.activation_n_bits,
+                qi_inferable_quantizers_constants.MIN_RANGE: np.asarray([node_qc.activation_quantization_params[RANGE_MIN]]),
+                qi_inferable_quantizers_constants.MAX_RANGE: np.asarray([node_qc.activation_quantization_params[RANGE_MAX]])}
+
+    elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER]:
+        return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.activation_n_bits,
+                qi_inferable_quantizers_constants.CLUSTER_CENTERS: np.asarray(
+                    [node_qc.activation_quantization_params[CLUSTER_CENTERS]]),
+                qi_inferable_quantizers_constants.THRESHOLD: np.asarray(
+                    [node_qc.activation_quantization_params[THRESHOLD]]),
+                qi_inferable_quantizers_constants.SIGNED: node_qc.activation_quantization_params.get(SIGNED)}
+    # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
+    else:
+        Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
 
 
-def get_weights_quantizer_for_node(node: BaseNode) ->
+def get_weights_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQuantizer:
     """
     Get weights quantizer for a node.
 
@@ -48,41 +100,20 @@ def get_weights_quantizer_for_node(node: BaseNode) -> List[Callable]:
 
     """
     if node.final_weights_quantization_cfg is None:
-        Logger.critical(f'Can not set quantizer for a node with no final weights quantization configuration')
-
+        Logger.critical(f'Can not set quantizer for a node with no final weights quantization configuration')  # pragma:
+        # no cover
     node_w_qc = node.final_weights_quantization_cfg
     weights_quantization_method = node_w_qc.weights_quantization_method
 
-    if weights_quantization_method in [QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC]:
-        weight_thresholds = node_w_qc.weights_quantization_params.get(THRESHOLD)
-        assert weight_thresholds is not None
-        if weights_quantization_method == QuantizationMethod.POWER_OF_TWO:
-            is_threshold_pot = np.all([int(np.log2(x)) == np.log2(x) for x in weight_thresholds.flatten()])
-            if not is_threshold_pot:
-                Logger.error(f'Expected threshold to be power of 2 but is {weight_thresholds}')
-
-        min_range = -weight_thresholds
-        max_range = weight_thresholds - calculate_delta(weight_thresholds,
-                                                        n_bits=node_w_qc.weights_n_bits,
-                                                        signed=True)
-
-    else:
-        Logger.error(f'For now fully quantized models support only {SUPPORTED_WEIGHT_QUANTIZER_TYPES} for weights quantization, but found {weights_quantization_method}')
+    quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights,
+                                                      weights_quantization_method,
+                                                      BasePyTorchInferableQuantizer)
+    kwargs = get_weights_inferable_quantizer_kwargs(node)
 
-    return
-        max_range=max_range,
-        min_range=min_range,
-        quantization_method=node_w_qc.weights_quantization_method,
-        per_channel=node_w_qc.weights_per_channel_threshold,
-        output_channels_axis=node_w_qc.weights_channels_axis
-        )]
+    return quantier_for_node(**kwargs)
 
 
-def get_activations_quantizer_for_node(node: BaseNode) ->
+def get_activations_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQuantizer:
     """
     Get activation quantizer for a node.
 
@@ -93,43 +124,16 @@ def get_activations_quantizer_for_node(node: BaseNode) -> List[Callable]:
     Quantizer for the node's activations.
 
     """
-
     if node.final_activation_quantization_cfg is None:
-        Logger.critical(f'Can not set quantizer for a node with no final activation quantization configuration')
-
+        Logger.critical(f'Can not set quantizer for a node with no final activation quantization configuration')  #
+        # pragma: no cover
     node_act_qc = node.final_activation_quantization_cfg
     activation_quantization_method = node_act_qc.activation_quantization_method
 
-    activation_thresholds = node_act_qc.activation_quantization_params.get(THRESHOLD)
-
-    if activation_quantization_method in [QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC]:
-        if activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
-            is_threshold_pot = np.all([int(np.log2(x)) == np.log2(x) for x in activation_thresholds.flatten()])
-            if not is_threshold_pot:
-                Logger.error(f'Expected threshold to be power of 2 but is {node_act_qc.activation_quantization_params.get(THRESHOLD)}')
-
-        if node_act_qc.activation_quantization_params.get(SIGNED):
-            min_range = -activation_thresholds
-            max_range = activation_thresholds - calculate_delta(
-                activation_thresholds,
-                n_bits=node_act_qc.activation_n_bits,
-                signed=node_act_qc.activation_quantization_params.get(SIGNED))
-
-    elif activation_quantization_method in [QuantizationMethod.UNIFORM]:
-        min_range = node_act_qc.activation_quantization_params.get(RANGE_MIN)
-        max_range = node_act_qc.activation_quantization_params.get(RANGE_MAX)
-
-    else:
-        Logger.error(f'For now fully quantized models support only {SUPPORTED_ACTIVATION_QUANTIZER_TYPES} for activation quantization, but found {activation_quantization_method}')
+    quantizer_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
+                                                       activation_quantization_method,
+                                                       BasePyTorchInferableQuantizer)
+    kwargs = get_activation_inferable_quantizer_kwargs(node)
 
-    return [FakeQuantQuantizer(nbits=node_act_qc.activation_n_bits,
-                               min_range=min_range,
-                               max_range=max_range,
-                               quantization_method=activation_quantization_method)]
+    return quantizer_for_node(**kwargs)
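The refactor above replaces hand-rolled fake-quant construction with a lookup: get_inferable_quantizer_class resolves a quantizer class from a (target, quantization method) pair, and the kwargs helpers supply the method-specific parameters. A generic, self-contained sketch of that dispatch idea (the registry dict and names below are illustrative, not MCT's actual implementation):

    # Illustrative registry-dispatch sketch, not MCT's real get_inferable_quantizer_class.
    from enum import Enum

    class Target(Enum):
        WEIGHTS = 0
        ACTIVATION = 1

    QUANTIZER_REGISTRY = {}  # maps (Target, quantization method) -> quantizer class

    def resolve_quantizer(target, method, node_kwargs):
        quantizer_cls = QUANTIZER_REGISTRY.get((target, method))
        if quantizer_cls is None:
            raise ValueError(f'No inferable quantizer registered for {target}/{method}')
        # Mirrors `quantier_for_node(**kwargs)` in the diff above.
        return quantizer_cls(**node_kwargs)
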
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py ADDED
@@ -0,0 +1,47 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Tuple, List, Dict
+from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \
+    get_activations_quantizer_for_node, get_weights_quantizer_for_node
+
+
+def get_quantization_quantizers(node: BaseNode) -> Tuple[Dict, List]:
+    """
+    Create quantizers to wrap a layer for its corresponding node.
+
+    Args:
+        node: Node to create quantizers for.
+
+    Returns:
+        weight_quantizers: A dictionary between a weight's name to its quantizer.
+        activation_quantizers: A list of activations quantization, one for each layer output.
+
+    """
+    weight_quantizers = {}
+    if node.is_weights_quantization_enabled():
+        weight_attrs = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(node.type)
+        weight_quantizer = get_weights_quantizer_for_node(node)
+        for attr in weight_attrs:
+            weight_quantizers[attr] = weight_quantizer
+
+    activation_quantizers = []
+    if node.is_activation_quantization_enabled():
+        num_of_outputs = len(node.output_shape) if isinstance(node.output_shape, list) else 1
+        activation_quantizers = [get_activations_quantizer_for_node(node)] * num_of_outputs
+
+
+    return weight_quantizers, activation_quantizers
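This new helper produces exactly the two arguments PytorchQuantizationWrapper expects. A short wiring sketch, assuming `node` is a BaseNode from an optimized MCT graph and `module` its torch module (both assumptions, mirroring fully_quantized_wrapper from the builder diff earlier):

    # Hypothetical wiring of get_quantization_quantizers into a wrapper.
    from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizers import \
        get_quantization_quantizers
    from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper

    def wrap_module(node, module):
        # One shared weights quantizer per kernel attribute, one activation
        # quantizer per layer output, as built by the helper above.
        weight_quantizers, activation_quantizers = get_quantization_quantizers(node)
        return PytorchQuantizationWrapper(module, weight_quantizers, activation_quantizers)
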
model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py ADDED
@@ -0,0 +1,44 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Any
+
+from model_compression_toolkit.core.common import Logger
+from model_compression_toolkit.core.common.constants import FOUND_TORCH
+
+if FOUND_TORCH:
+    from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
+        BasePyTorchInferableQuantizer
+    def is_pytorch_layer_exportable(layer: Any) -> bool:
+        """
+        Check whether a torch Module is a valid exportable module or not.
+
+        Args:
+            layer: PyTorch module to check if considered to be valid for exporting.
+
+        Returns:
+            Check whether a PyTorch layer is a valid exportable layer or not.
+        """
+        if isinstance(layer, PytorchQuantizationWrapper):
+            quantizers = list(layer.weights_quantizers.values())
+            quantizers.extend(layer.activation_quantizers)
+            if all([isinstance(q, BasePyTorchInferableQuantizer) for q in quantizers]):
+                return True
+            return False
+else:
+    def is_pytorch_layer_exportable(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing torch is mandatory '
+                     'when using is_pytorch_layer_exportable. '
+                     'Could not find PyTorch package.')
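Unlike its Keras counterpart, which reports each violation through Logger.error, this PyTorch validator simply returns a boolean for wrapped layers. A sketch of using it to audit a model (`model` and the helper below are assumptions, presumably a model from get_exportable_pytorch_model):

    # Hypothetical audit: list wrapped modules that fail the exportability check.
    from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import \
        is_pytorch_layer_exportable
    from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper

    def find_non_exportable(model):
        return [name for name, m in model.named_modules()
                if isinstance(m, PytorchQuantizationWrapper) and not is_pytorch_layer_exportable(m)]
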
model_compression_toolkit/gptq/__init__.py CHANGED
@@ -12,3 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GradientPTQConfigV2
+from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization_experimental
+from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
+from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization_experimental
+from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config