mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
from typing import Dict
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.core.common import max_power_of_two
|
|
21
|
+
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
22
|
+
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
23
|
+
from model_compression_toolkit.gptq.common.gptq_config import RoundingType
|
|
24
|
+
from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
|
|
25
|
+
BasePytorchGPTQTrainableQuantizer
|
|
26
|
+
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
|
|
27
|
+
from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
|
|
28
|
+
from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
|
|
29
|
+
SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
|
|
30
|
+
from model_compression_toolkit.core.common.constants import THRESHOLD, MIN_THRESHOLD
|
|
31
|
+
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
|
|
32
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
|
|
33
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
|
|
34
|
+
get_threshold_reshape_shape
|
|
35
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def soft_rounding_symmetric_quantizer(input_tensor: torch.Tensor,
                                      auxvar_tensor: torch.Tensor,
                                      threshold_tensor: torch.Tensor,
                                      num_bits: int,
                                      signed: bool,
                                      power_of_two: bool) -> torch.Tensor:
    """
    Fake-quantize a tensor on a symmetric grid whose rounding is learned by GPTQ.

    The input is divided by the quantization step and floored (outside autograd). The
    learned auxiliary variable is added on top, and the sum is clipped to the integer
    range with ``qutils.ste_clip``. The result is then multiplied back by the step size.

    Args:
        input_tensor: Tensor to quantize. Its values are not changed during GPTQ.
        auxvar_tensor: Learned rounding offset, in units of the quantization step, added
            to the floored integer values.
        threshold_tensor: Threshold values used to compute the quantization step.
        num_bits: Number of quantization bits.
        signed: Whether the quantization range is signed.
        power_of_two: If True, the threshold is first constrained with ``qutils.power_of_two_max``.

    Returns:
        The fake-quantized tensor.
    """
    thr = qutils.power_of_two_max(threshold_tensor) if power_of_two else threshold_tensor
    step = qutils.calculate_delta(thr, num_bits, signed)

    # The floored grid values are computed without autograd tracking, so no gradient
    # flows back through this division/floor.
    with torch.no_grad():
        floored = torch.floor(input_tensor / step)

    upper = 2 ** (num_bits - int(signed))
    lower = -int(signed) * upper
    clipped = qutils.ste_clip(floored + auxvar_tensor, min_val=lower, max_val=upper - 1)
    return step * clipped
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
|
|
72
|
+
quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
|
|
73
|
+
quantizer_type=RoundingType.SoftQuantizer)
|
|
74
|
+
class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
|
|
75
|
+
"""
|
|
76
|
+
Trainable symmetric quantizer to optimize the rounding of the quantized values using a soft quantization method.
|
|
77
|
+
"""
|
|
78
|
+
|
|
79
|
+
def __init__(self,
             quantization_config: TrainableQuantizerWeightsConfig,
             quantization_parameter_learning: bool = False):
    """
    Build a soft-rounding GPTQ weights quantizer for symmetric / power-of-two quantization.

    Args:
        quantization_config: Trainable weights quantizer configuration.
        quantization_parameter_learning (Bool): Whether to also learn a scale for the threshold.
    """
    super().__init__(quantization_config)

    cfg = quantization_config
    self.num_bits = cfg.weights_n_bits
    self.per_channel = cfg.weights_per_channel_threshold

    thresholds = cfg.weights_quantization_params[THRESHOLD]
    thresholds_arr = np.asarray(thresholds)
    # Remember the original shape so flattened values can be restored later.
    self.threshold_shape = thresholds_arr.shape
    if self.per_channel:
        # Per-channel thresholds are kept as a flat vector.
        self.threshold_values = np.reshape(thresholds_arr, [-1])
    else:
        # A per-tensor threshold is kept as a single Python float.
        self.threshold_values = float(thresholds)

    self.quantization_axis = cfg.weights_channels_axis
    self.power_of_two = cfg.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
    self.quantization_parameter_learning = quantization_parameter_learning

    # gamma and zeta are the stretch parameters of the rectified sigmoid used for soft rounding.
    # See: https://arxiv.org/pdf/2004.10568.pdf
    self.gamma = SOFT_ROUNDING_GAMMA
    self.zeta = SOFT_ROUNDING_ZETA

    # Starts empty.
    self.quantizer_parameters = {}
|
|
109
|
+
|
|
110
|
+
def initialize_quantization(self,
                            tensor_shape: torch.Size,
                            name: str,
                            layer: qi.PytorchQuantizationWrapper):
    """
    Register this quantizer's parameters on the wrapped layer and record them as quantizer variables.

    Parameters registered on ``layer``:
      * ``{name}_{PTQ_THRESHOLD}``: the PTQ threshold (not trainable), group QPARAMS.
      * ``{name}_{AUXVAR}``: the soft-rounding variable (trainable), group WEIGHTS. It is
        initialized so that the rectified sigmoid reproduces the fractional part of the
        weights divided by the quantization step.
      * ``{name}_{SCALE_PTQ}``: a threshold scale initialized to ones (trainable), group
        QPARAMS. Only registered when ``quantization_parameter_learning`` is set.

    Args:
        tensor_shape: Shape of the quantized tensor (not used by this implementation).
        name: Tensor name; used as the prefix of the registered parameter names.
        layer: Wrapper of the layer to quantize; its ``layer.layer.weight`` is read.
    """
    if self.per_channel:
        threshold_tensor = to_torch_tensor(self.threshold_values)
    else:
        threshold_tensor = torch.tensor(self.threshold_values)
    layer.register_parameter(f"{name}_{PTQ_THRESHOLD}",
                             nn.Parameter(threshold_tensor, requires_grad=False))

    # This is a one-off initialization, so no autograd graph through the layer weight is needed.
    with torch.no_grad():
        w = layer.layer.weight
        delta = qutils.calculate_delta(threshold_tensor.reshape(self.threshold_shape), self.num_bits, signed=True)
        w_clipped_normed = torch.clip(w / delta, -2 ** (self.num_bits - 1), 2 ** (self.num_bits - 1) - 1)
        rest = w_clipped_normed - torch.floor(w_clipped_normed)  # fractional part of the rounding, in [0, 1)
        # Invert the rectified sigmoid so that sigmoid(alpha) * (zeta - gamma) + gamma == rest.
        # rest is non-negative, so (rest - gamma) > 0 whenever gamma < 0 and the division is safe.
        # NOTE(review): the log argument is positive only if rest < zeta; this relies on
        # SOFT_ROUNDING_ZETA > 1 — confirm against gptq_constants.
        alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1)

    layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(alpha, requires_grad=True))

    # Save the quantizer-added parameters for later calculations.
    self.add_quantizer_variable(PTQ_THRESHOLD, layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"), VariableGroup.QPARAMS)
    self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)

    if self.quantization_parameter_learning:
        # Multiplicative threshold scale, starting at ones (identity scaling).
        if self.per_channel:
            scale_init = torch.ones(self.threshold_values.shape)
        else:
            scale_init = torch.ones(1)
        layer.register_parameter(f"{name}_{SCALE_PTQ}",
                                 nn.Parameter(to_torch_tensor(scale_init), requires_grad=True))
        self.add_quantizer_variable(SCALE_PTQ, layer.get_parameter(f"{name}_{SCALE_PTQ}"), VariableGroup.QPARAMS)
|
|
153
|
+
|
|
154
|
+
def get_soft_targets(self) -> torch.Tensor:
    """
    Map the learnable auxiliary variable through a stretched ("rectified") sigmoid.

    The sigmoid output is stretched linearly onto the interval (gamma, zeta) and then clamped to [0, 1].

    Returns:
        A tensor of soft rounding targets in [0, 1], with the same shape as the auxiliary variable.
    """
    stretch = self.zeta - self.gamma
    stretched = self.gamma + stretch * torch.sigmoid(self.get_quantizer_variable(AUXVAR))
    return torch.clip(stretched, min=0, max=1)
|
|
164
|
+
|
|
165
|
+
def get_quant_config(self) -> Dict[str, np.ndarray]:
    """
    Build the attributes to write back into NodeQuantizationConfig once GPTQ retraining ends.

    Returns:
        A dictionary whose keys match NodeQuantizationConfig attributes (only THRESHOLD here), holding the
        retrained threshold reshaped to ``self.threshold_shape``.
    """
    threshold = np.resize(torch_tensor_to_numpy(self.get_quantizer_variable(PTQ_THRESHOLD)),
                          self.threshold_shape)
    if self.power_of_two:
        # Constrain the threshold to a power of two (MIN_THRESHOLD presumably acts as a floor).
        threshold = max_power_of_two(threshold, MIN_THRESHOLD)
    elif self.quantization_parameter_learning:
        # Fold the learned scale factor into the threshold.
        learned_scale = torch.reshape(self.get_quantizer_variable(SCALE_PTQ), self.threshold_shape)
        threshold = threshold * torch_tensor_to_numpy(learned_scale)
    return {THRESHOLD: threshold.reshape(self.threshold_shape)}
|
|
184
|
+
|
|
185
|
+
def __call__(self,
             inputs: nn.Parameter,
             training: bool) -> torch.Tensor:
    """
    Fake-quantize a tensor with learned soft rounding.

    During training the continuous soft rounding targets are used. Outside training they are hardened by
    thresholding at 0.5. For per-channel quantization, the threshold and (when learned) the scale are
    reshaped so they broadcast against ``inputs`` along the quantization axis.

    Args:
        inputs: Input tensor to quantize.
        training: Whether in training mode or not.

    Returns:
        The quantized tensor.
    """
    aux_param = self.get_quantizer_variable(AUXVAR)
    threshold = self.get_quantizer_variable(PTQ_THRESHOLD)

    # Soft rounding targets; hardened to {0, 1} (in the auxiliary variable's dtype) outside training.
    rounding = self.get_soft_targets()
    if not training:
        rounding = (rounding >= 0.5).to(aux_param.dtype)

    broadcast_shape = None
    if self.per_channel:
        broadcast_shape = get_threshold_reshape_shape(inputs.shape,
                                                      quant_axis=self.quantization_axis,
                                                      quant_axis_dim=-1)
        threshold = torch.reshape(threshold, broadcast_shape)

    q_tensor = soft_rounding_symmetric_quantizer(input_tensor=inputs,
                                                 auxvar_tensor=rounding,
                                                 threshold_tensor=threshold,
                                                 num_bits=self.num_bits,
                                                 signed=True,
                                                 power_of_two=self.power_of_two)

    # A learned scale is only applied when the threshold is not constrained to a power of two.
    if self.quantization_parameter_learning and not self.power_of_two:
        scale = self.get_quantizer_variable(SCALE_PTQ)
        if broadcast_shape is not None:
            scale = torch.reshape(scale, broadcast_shape)
        q_tensor *= scale

    return q_tensor
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
from typing import Dict
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
21
|
+
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
22
|
+
from model_compression_toolkit.gptq.common.gptq_config import RoundingType
|
|
23
|
+
from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
|
|
24
|
+
BasePytorchGPTQTrainableQuantizer
|
|
25
|
+
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
|
|
26
|
+
from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
|
|
27
|
+
from model_compression_toolkit.gptq.common.gptq_constants import SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
|
|
28
|
+
from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import fix_range_to_include_zero
|
|
29
|
+
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
|
|
30
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
|
|
31
|
+
mark_quantizer
|
|
32
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import \
|
|
33
|
+
VariableGroup
|
|
34
|
+
from model_compression_toolkit.core.common.constants import RANGE_MAX, RANGE_MIN
|
|
35
|
+
from model_compression_toolkit.qat.common.constants import FQ_MIN, FQ_MAX
|
|
36
|
+
|
|
37
|
+
def soft_rounding_unifrom_quantizer(input_tensor: torch.Tensor,
                                    auxvar_tensor: torch.Tensor,
                                    min_range: torch.Tensor,
                                    max_range: torch.Tensor,
                                    num_bits: int) -> torch.Tensor:
    """
    Quantize a tensor uniformly for GPTQ soft-rounding quantizers.

    First, the range is adjusted so that zero lies on the quantization grid. Each value is then expressed
    as a floored integer offset from ``min_range``, computed without gradient. The soft rounding targets
    are added, the result is clipped to the integer grid ``[0, 2**num_bits - 1]``, and it is mapped back
    via ``delta * q + min_range``.

    Args:
        input_tensor: Tensor to quantize. Its values are not changed during GPTQ (the floor is computed
            under ``torch.no_grad``).
        auxvar_tensor: Soft rounding targets added to the floored integer values; they manifest the
            rounding decision learned during GPTQ training.
        min_range: Tensor with min values of the quantization range.
        max_range: Tensor with max values of the quantization range.
        num_bits: Number of bits to use.

    Returns:
        A quantized tensor.
    """
    # Adjust the quantization range so the quantization grid includes zero.
    min_range, max_range = fix_range_to_include_zero(min_range, max_range, num_bits)
    delta = qutils.calculate_delta_uniform(max_range, min_range, num_bits)

    # Grid index relative to the start of the range. Without subtracting min_range (and adding it back
    # below), every value under zero would be clipped to grid index 0 and the output would never reach
    # the negative part of the range.
    with torch.no_grad():
        input_tensor_int = torch.floor((input_tensor - min_range) / delta)
    tensor_q = input_tensor_int + auxvar_tensor

    # Clip to the unsigned integer grid and map back to the real-valued range.
    return delta * qutils.ste_clip(tensor_q,
                                   min_val=0,
                                   max_val=2 ** num_bits - 1) + min_range
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.UNIFORM],
                quantizer_type=RoundingType.SoftQuantizer)
class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
    """
    Trainable uniform quantizer to optimize the rounding of the quantized values using a soft quantization method.
    """

    def __init__(self,
                 quantization_config: TrainableQuantizerWeightsConfig,
                 quantization_parameter_learning: bool = False):
        """
        Construct a Pytorch model that utilizes a fake weight quantizer of soft-quantizer for uniform quantization.

        Args:
            quantization_config: Trainable weights quantizer config.
            quantization_parameter_learning (bool): Whether to learn the min/max ranges or not.
        """
        super().__init__(quantization_config)
        self.num_bits = quantization_config.weights_n_bits
        self.per_channel = quantization_config.weights_per_channel_threshold

        self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
        self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]

        self.quantization_axis = quantization_config.weights_channels_axis
        self.quantization_parameter_learning = quantization_parameter_learning

        # gamma and zeta are stretch parameters for computing the rectified sigmoid function.
        # See: https://arxiv.org/pdf/2004.10568.pdf
        self.gamma = SOFT_ROUNDING_GAMMA
        self.zeta = SOFT_ROUNDING_ZETA

    def initialize_quantization(self,
                                tensor_shape: torch.Size,
                                name: str,
                                layer: qi.PytorchQuantizationWrapper):
        """
        Add quantizer parameters (min/max range and the soft-rounding auxiliary variable) to the layer
        and to the quantizer parameters dictionary.

        Args:
            tensor_shape: tensor shape of the quantized tensor.
            name: Tensor name.
            layer: Layer to quantize.
        """
        # Add min and max variables to layer.
        if self.per_channel:
            min_values = to_torch_tensor(self.min_values)
            max_values = to_torch_tensor(self.max_values)
        else:
            min_values = torch.tensor(self.min_values)
            max_values = torch.tensor(self.max_values)

        layer.register_parameter(name + "_" + FQ_MIN,
                                 nn.Parameter(min_values, requires_grad=self.quantization_parameter_learning))
        layer.register_parameter(name + "_" + FQ_MAX,
                                 nn.Parameter(max_values, requires_grad=self.quantization_parameter_learning))

        # Initialize the auxiliary variable so the soft rounding target of every weight equals the
        # fractional part of its position on the quantization grid. Use the same zero-including range and
        # min offset as soft_rounding_unifrom_quantizer, so negative weights are not clipped to the grid
        # start and the initial soft-quantized weights match the float weights inside the range.
        w = layer.layer.weight
        with torch.no_grad():
            grid_min, grid_max = fix_range_to_include_zero(min_values, max_values, self.num_bits)
            delta = qutils.calculate_delta_uniform(grid_max, grid_min, self.num_bits)
            w_clipped_normed = torch.clip((w - grid_min) / delta, 0, 2 ** self.num_bits - 1)
            rest = w_clipped_normed - torch.floor(w_clipped_normed)  # rest of rounding [0, 1)
            # rest >= 0 and gamma is expected to be negative, so (rest - gamma) is non-zero.
            alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1)  # => sigmoid(alpha) = rest
        layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(alpha, requires_grad=True))

        # Save the quantizer parameters
        self.add_quantizer_variable(FQ_MIN, layer.get_parameter(name + "_" + FQ_MIN), VariableGroup.QPARAMS)
        self.add_quantizer_variable(FQ_MAX, layer.get_parameter(name + "_" + FQ_MAX), VariableGroup.QPARAMS)
        self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)

    def get_soft_targets(self) -> torch.Tensor:
        """
        Computes the rectified sigmoid function for the quantization target parameters.

        Returns:
            A tensor with the soft rounding targets values, clamped to [0, 1].
        """
        scaled_sigmoid = torch.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma
        return torch.clip(scaled_sigmoid, min=0, max=1)

    def get_quant_config(self) -> Dict[str, np.ndarray]:
        """
        Returns the config used to edit NodeQuantizationConfig after GPTQ retraining.

        Returns:
            A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
            Keys must match NodeQuantizationConfig attributes.
        """
        min_values = torch_tensor_to_numpy(self.get_quantizer_variable(FQ_MIN))
        max_values = torch_tensor_to_numpy(self.get_quantizer_variable(FQ_MAX))
        return {RANGE_MIN: min_values,
                RANGE_MAX: max_values}

    def __call__(self,
                 inputs: nn.Parameter,
                 training: bool) -> torch.Tensor:
        """
        Quantize a tensor.

        Args:
            inputs: Input tensor to quantize.
            training: whether in training mode or not. Outside training the soft rounding targets are
                hardened by thresholding at 0.5.

        Returns:
            quantized tensor
        """
        auxvar = self.get_quantizer_variable(AUXVAR)
        min_range = self.get_quantizer_variable(FQ_MIN)
        max_range = self.get_quantizer_variable(FQ_MAX)

        # Soft Rounding
        aux_var = self.get_soft_targets()
        if not training:
            aux_var = (aux_var >= 0.5).to(auxvar.dtype)

        # Quantized Input
        q_tensor = soft_rounding_unifrom_quantizer(input_tensor=inputs,
                                                   auxvar_tensor=aux_var,
                                                   min_range=min_range,
                                                   max_range=max_range,
                                                   num_bits=self.num_bits)
        return q_tensor
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
from typing import Dict
|
|
18
|
+
import numpy as np
|
|
19
|
+
from model_compression_toolkit.core.common.defaultdict import DefaultDict
|
|
20
|
+
|
|
21
|
+
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
22
|
+
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
23
|
+
from model_compression_toolkit.gptq.common.gptq_config import RoundingType
|
|
24
|
+
from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
|
|
25
|
+
BasePytorchGPTQTrainableQuantizer
|
|
26
|
+
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
|
|
27
|
+
from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
|
|
28
|
+
from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD, MAX_LSB_CHANGE
|
|
29
|
+
from model_compression_toolkit.core.common.constants import THRESHOLD
|
|
30
|
+
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
|
|
31
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
|
|
32
|
+
mark_quantizer
|
|
33
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
34
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
|
|
35
|
+
get_threshold_reshape_shape
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def pertubation_symmetric_quantizer(input_tensor: torch.Tensor,
                                    auxvar_tensor: nn.Parameter,
                                    max_tensor: torch.Tensor,
                                    num_bits: int,
                                    signed: bool,
                                    power_of_two: bool,
                                    max_lsbs_change: int = MAX_LSB_CHANGE) -> torch.Tensor:
    """
    Quantize a tensor symmetrically, perturbed by a learnable shift bounded to a maximal number of LSBs.

    Args:
        input_tensor: Tensor to quantize. Its values are not changed during GPTQ (it is rounded and detached).
        auxvar_tensor: Learnable additive shift of the values, clipped to
            ``[-max_lsbs_change * delta, max_lsbs_change * delta]``.
        max_tensor: Tensor with max values to compute the threshold.
        num_bits: Number of bits to use.
        signed: Signedness of the quantization range.
        power_of_two: Whether the threshold should be constrained to a power of two or not.
        max_lsbs_change: Maximal number of LSBs that the auxiliary variable is allowed to shift a value.

    Returns:
        A quantized tensor.
    """
    if power_of_two:
        max_tensor = qutils.power_of_two_max(max_tensor)
    delta = to_torch_tensor(qutils.calculate_delta(max_tensor, num_bits, signed))
    max_tensor_change = delta * max_lsbs_change

    # Integer grid bounds: [-2^(n-1), 2^(n-1) - 1] when signed, [0, 2^n - 1] when unsigned.
    min_int = -int(signed) * (2 ** (num_bits - int(signed)))
    max_int = (2 ** (num_bits - int(signed))) - 1

    # Bound the perturbation and express it in LSB (delta) units.
    tensor_clipped = qutils.ste_clip(auxvar_tensor, min_val=-max_tensor_change, max_val=max_tensor_change) / delta
    # No gradient flows through the rounded input itself.
    input_tensor_int = torch.round(input_tensor / delta).detach()

    # A single STE rounding suffices; rounding is idempotent, so nesting ste_round twice was redundant.
    tensor_q = qutils.ste_round(input_tensor_int + tensor_clipped)

    return delta * qutils.ste_clip(tensor_q, max_val=max_int, min_val=min_int)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                quantizer_type=RoundingType.STE)
class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
    """
    Trainable symmetric (or power-of-two) quantizer that quantizes layer weights while learning a bounded
    additive perturbation of each weight through a Straight Through Estimator (STE).
    """

    def __init__(self,
                 quantization_config: TrainableQuantizerWeightsConfig,
                 max_lsbs_change_map: dict = DefaultDict({}, lambda: 1)):
        """
        Construct a Pytorch model that utilize a fake weight quantizer of STE (Straight Through Estimator) for symmetric quantizer.

        Args:
            quantization_config: Trainable weights quantizer config.
            max_lsbs_change_map: Mapping from number of bits to the maximal number of LSBs the auxiliary
                variable may shift a weight. It is only read here (``.get``), never mutated.
        """
        super().__init__(quantization_config)
        self.num_bits = quantization_config.weights_n_bits
        self.per_channel = quantization_config.weights_per_channel_threshold

        threshold_values = quantization_config.weights_quantization_params[THRESHOLD]
        self.threshold_shape = np.asarray(threshold_values).shape
        # Per-channel thresholds are kept flattened; a per-tensor threshold is kept as a float.
        self.threshold_values = np.reshape(np.asarray(threshold_values), [-1]) if self.per_channel else float(
            threshold_values)

        self.quantization_axis = quantization_config.weights_channels_axis
        self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
        max_lsbs_change = max_lsbs_change_map.get(self.num_bits)
        # A plain dict without an entry for num_bits yields None; fall back to the module default.
        self.max_lsbs_change = MAX_LSB_CHANGE if max_lsbs_change is None else max_lsbs_change

    def initialize_quantization(self,
                                tensor_shape: torch.Size,
                                name: str,
                                layer: qi.PytorchQuantizationWrapper):
        """
        Add quantizer parameters to the quantizer parameters dictionary

        Args:
            tensor_shape: tensor shape of the quantized tensor.
            name: Tensor name.
            layer: Layer to quantize.
        """
        if self.per_channel:
            threshold_tensor = to_torch_tensor(self.threshold_values)
        else:
            threshold_tensor = torch.tensor(self.threshold_values, requires_grad=False)
        layer.register_parameter(f"{name}_{PTQ_THRESHOLD}",
                                 nn.Parameter(threshold_tensor, requires_grad=False))

        # The auxiliary variable holds one learnable perturbation per quantized weight, so it must have the
        # quantized tensor's shape. Using the threshold shape would share a single shift across a whole
        # channel (or across the whole tensor for per-tensor quantization).
        layer.register_parameter(f"{name}_{AUXVAR}",
                                 nn.Parameter(to_torch_tensor(torch.zeros(tensor_shape)),
                                              requires_grad=True))

        # save the quantizer added parameters for later calculations
        self.add_quantizer_variable(PTQ_THRESHOLD, layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"), VariableGroup.QPARAMS)
        self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)

    def get_quant_config(self) -> Dict[str, np.ndarray]:
        """
        Returns the config used to edit NodeQuantizationConfig after GPTQ retraining

        Returns:
            A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
            Keys must match NodeQuantizationConfig attributes
        """
        old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
        return {THRESHOLD: torch_tensor_to_numpy(old_threshold).reshape(self.threshold_shape)}

    def __call__(self,
                 inputs: nn.Parameter,
                 training: bool) -> torch.Tensor:
        """
        Quantize a tensor.

        Args:
            inputs: Input tensor to quantize.
            training: whether in training mode or not (not used by this quantizer).

        Returns:
            quantized tensor
        """
        auxvar = self.get_quantizer_variable(AUXVAR)
        ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)

        if self.per_channel:
            # Reshape the flat per-channel threshold so it broadcasts along the quantization axis.
            reshape_shape = get_threshold_reshape_shape(inputs.shape,
                                                        quant_axis=self.quantization_axis,
                                                        quant_axis_dim=-1)
            ptq_threshold_tensor = torch.reshape(ptq_threshold_tensor, reshape_shape)

        # Both per-channel and per-tensor quantization honour the configured LSB-change bound
        # (previously the per-tensor path silently used the module default).
        return pertubation_symmetric_quantizer(inputs,
                                               auxvar,
                                               ptq_threshold_tensor,
                                               self.num_bits,
                                               signed=True,
                                               power_of_two=self.power_of_two,
                                               max_lsbs_change=self.max_lsbs_change)
|
|
@@ -125,8 +125,8 @@ if FOUND_TF:
|
|
|
125
125
|
if core_config.mixed_precision_enable:
|
|
126
126
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
|
|
127
127
|
common.Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
128
|
-
"MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization
|
|
129
|
-
"or pass a valid mixed precision configuration.")
|
|
128
|
+
"MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
|
|
129
|
+
"API, or pass a valid mixed precision configuration.") # pragma: no cover
|
|
130
130
|
|
|
131
131
|
common.Logger.info("Using experimental mixed-precision quantization. "
|
|
132
132
|
"If you encounter an issue please file a bug.")
|
|
@@ -171,4 +171,4 @@ else:
|
|
|
171
171
|
def keras_post_training_quantization_experimental(*args, **kwargs):
|
|
172
172
|
Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
|
|
173
173
|
'when using keras_post_training_quantization_experimental. '
|
|
174
|
-
'Could not find Tensorflow package.')
|
|
174
|
+
'Could not find Tensorflow package.') # pragma: no cover
|