mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
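The hunks reproduced below all come from the Keras GPTQ flow, which this release reworks around the new quantizers_infrastructure package: the Gumbel-rounding quantizers and SAM optimizer are removed, STE and soft-rounding quantizers take their place, and the GPTQ facade moves under the mct.gptq namespace. A usage sketch of the relocated API, based on the updated docstrings in quantization_facade.py below (model and repr_datagen are placeholders for your Keras model and representative dataset generator):

    import model_compression_toolkit as mct

    # The GPTQ config builder now lives under the mct.gptq namespace.
    gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=5)

    # So does the GPTQ quantization facade itself.
    quantized_model, quantization_info = \
        mct.gptq.keras_gradient_post_training_quantization_experimental(
            model, repr_datagen, gptq_config)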
model_compression_toolkit/gptq/keras/gptq_training.py +73 -39

@@ -12,16 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Union

 import tensorflow as tf
-from
+from keras import Model
+from tensorflow.keras.layers import Layer
 from tqdm import tqdm

 # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
-from model_compression_toolkit.
+from model_compression_toolkit.core.common.user_info import UserInformation
+from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
 from packaging import version

+from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+from model_compression_toolkit.gptq.keras.quantizer.quantization_builder import quantization_builder
+from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
+
 if version.parse(tf.__version__) < version.parse("2.6"):
     from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
 else:

@@ -31,15 +37,14 @@ from model_compression_toolkit.core import common
 from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
 from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.gptq.keras.graph_info import
-
+from model_compression_toolkit.gptq.keras.graph_info import get_weights_for_loss, get_gptq_trainable_parameters
+from model_compression_toolkit.gptq.keras.quantizer.regularization_factory import get_regularization
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 import numpy as np
 import copy
 from model_compression_toolkit.core.keras.constants import BIAS, USE_BIAS
-from model_compression_toolkit
-from model_compression_toolkit.gptq.keras.optimizers.sam_optimizer import SAM
+from model_compression_toolkit import quantizers_infrastructure as qi


 class KerasGPTQTrainer(GPTQTrainer):

@@ -77,11 +82,10 @@ class KerasGPTQTrainer(GPTQTrainer):
         self.loss_list = []
         self.input_scale = 1

-        trainable_weights, bias_weights, trainable_threshold
+        trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters(
             self.fxp_model,
             fw_info,
-            add_bias=gptq_config.train_bias
-            is_gumbel=gptq_config.is_gumbel)
+            add_bias=gptq_config.train_bias)

         self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)

@@ -96,29 +100,70 @@ class KerasGPTQTrainer(GPTQTrainer):
         trainable_quantization_parameters = trainable_threshold
         self.optimizer_with_param = self.get_optimizer_with_param(flattened_trainable_weights,
                                                                   flattened_bias_weights,
-                                                                  trainable_quantization_parameters
-
-
+                                                                  trainable_quantization_parameters)
+        self.has_params_to_train = np.sum(
+            [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0

         if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
             common.Logger.error("Input scale mismatch between float and GPTQ networks")  # pragma: no cover
         else:
             self.input_scale = self.gptq_user_info.input_scale

-        self.weights_for_average_loss = self.
+        self.weights_for_average_loss = self.compute_hessian_based_weights(representative_data_gen)
+
+        self.reg_func = get_regularization(self.gptq_config, representative_data_gen)

-    def
+    def _is_gptq_applicable(self,
+                            node: common.BaseNode) -> bool:
+        """
+        A function for deciding if a layer should be fine-tuned during GPTQ.
+
+        Args:
+            node (BaseNode): Node for quantization decision
+
+        Returns:
+            A boolean whether the layer is to be wrapped with a QuantizeWrapper
+        """
+
+        if node.is_weights_quantization_enabled() and not self.fw_info.is_kernel_op(node.type):
+            common.Logger.error(f"GPTQ Error: Quantizing node {node.name} of type {node.type} "
+                                f"without a kernel isn't supported")
+        return node.is_weights_quantization_enabled()
+
+    def gptq_wrapper(self, n: common.BaseNode, layer: Layer) -> Union[qi.KerasQuantizationWrapper, Layer]:
+        """
+        A function which takes a computational graph node and a keras layer and perform the quantization wrapping.
+
+        Args:
+            n: A node of mct graph.
+            layer: A keras layer
+
+        Returns: Wrapped layer if the layer should be wrap, otherwise returns the layer as is.
+
+        """
+        if self._is_gptq_applicable(n):
+            weights_quantizers, activation_quantizers = quantization_builder(n, self.gptq_config)
+            return qi.KerasQuantizationWrapper(layer,
+                                               weights_quantizers=weights_quantizers,
+                                               activation_quantizers=activation_quantizers)
+        else:
+            return layer
+
+    def build_gptq_model(self) -> Tuple[Model, UserInformation]:
         """
         Build the GPTQ model with QuantizationWrappers
+
         Returns:
             Quantized graph for GPTQ fine-tuning, GPTQ graph user info
         """

-
-
-
-
-
+        gptq_model, gptq_user_info = KerasModelBuilder(graph=self.graph_quant,
+                                                       append2output=self.compare_points,
+                                                       fw_info=self.fw_info,
+                                                       return_float_outputs=True,
+                                                       wrapper=self.gptq_wrapper).build_model()
+
+        return gptq_model, gptq_user_info

     def compute_gradients(self, in_y_float: List[tf.Tensor], input_data: List[np.ndarray],
                           in_optimizer_with_param: List,

@@ -149,18 +194,9 @@ class KerasGPTQTrainer(GPTQTrainer):
                                          self.compare_points_std,
                                          self.weights_for_average_loss)

-
-            gumbel_prob = get_gumbel_probability(self.fxp_model)
-            gumbel_reg = 0
-            for p in gumbel_prob:
-                entropy = -tf.reduce_mean(
-                    tf.reduce_sum(p * tf.math.log(tf.maximum(p,
-                                                             self.gptq_config.eps)),
-                                  axis=0))
+            reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)

-
-            gumbel_reg /= len(gumbel_prob)
-            loss_value += self.gptq_config.quantizer_config.gumbel_entropy_regularization * gumbel_reg
+            loss_value += reg_value

             # Use the gradient tape to automatically retrieve
             # the gradients of the trainable variables with respect to the loss.

@@ -179,9 +215,6 @@ class KerasGPTQTrainer(GPTQTrainer):
            representative_data_gen: Dataset to use for inputs of the models.
         """
         compute_gradients = self.compute_gradients
-        if self.gptq_config.sam_optimization:
-            sam = SAM(self.fxp_model, self.compute_gradients, self.optimizer_with_param, self.gptq_config.rho)
-            compute_gradients = sam.compute_gradients

         # ----------------------------------------------
         # Training loop

@@ -237,7 +270,8 @@ class KerasGPTQTrainer(GPTQTrainer):
         for data in tqdm(data_function()):
             input_data = [d * self.input_scale for d in data]

-            loss_value_step, grads = self.nano_training_step(input_data, in_compute_gradients,
+            loss_value_step, grads = self.nano_training_step(input_data, in_compute_gradients,
+                                                             in_optimizer_with_param, is_training)
             # Run one step of gradient descent by updating
             # the value of the variables to minimize the loss.
             for i, (o, p) in enumerate(in_optimizer_with_param):

@@ -258,16 +292,17 @@ class KerasGPTQTrainer(GPTQTrainer):
         graph = copy.copy(self.graph_quant)

         for layer in self.fxp_model.layers:
-            if isinstance(layer,
-                    layer.quantize_config, WeightQuantizeConfig):
+            if isinstance(layer, KerasQuantizationWrapper):
                 node = graph.find_node_by_name(layer.layer.name)
                 if len(node) == 0 and isinstance(layer.layer, TensorFlowOpLayer):
                     node = graph.find_node_by_name('_'.join(layer.layer.name.split('_')[3:]))
                 if len(node) != 1:
                     common.Logger.error(f"Can't update GPTQ graph due to missing layer named: {layer.layer.name}")
                 node = node[0]
+                kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
+                                                                      fw_info=self.fw_info)
                 weights, weight_quant_config, activation_quant_config = \
-                    layer.
+                    layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
                 for weight_attr, weight in weights.items():
                     node.set_weights_by_keys(weight_attr, weight.numpy())
                 for config_attr, config_value in weight_quant_config.items():

@@ -281,4 +316,3 @@ class KerasGPTQTrainer(GPTQTrainer):
             node.set_weights_by_keys(BIAS, new_bias)

         return graph
-
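The compute_gradients hunk above swaps the hard-coded Gumbel entropy penalty for a pluggable regularizer: get_regularization(...) is stored as self.reg_func and invoked once per step as self.reg_func(self.fxp_model, self.gptq_config.regularization_factor). A hedged sketch of the contract implied by that call site; the body is a stand-in, not MCT's actual soft-rounding regularizer:

    import tensorflow as tf

    def example_reg_func(fxp_model: tf.keras.Model, regularization_factor: float) -> tf.Tensor:
        # Must return a scalar tensor; the trainer adds it to the GPTQ loss
        # via `loss_value += reg_value`.
        penalty = tf.constant(0.0)  # placeholder penalty over the wrapped model
        return regularization_factor * penalty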
model_compression_toolkit/gptq/keras/graph_info.py +24 -43

@@ -13,22 +13,21 @@
 # limitations under the License.
 # ==============================================================================

-
 import tensorflow as tf
-from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
 from typing import Tuple, List
-
 from model_compression_toolkit.core.keras.constants import USE_BIAS
-from model_compression_toolkit.gptq.keras.quantizer import WeightQuantizeConfig
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from tensorflow.keras.models import Model
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup


-def
-
-
-
-        List[tf.Variable], List[tf.Variable], List[tf.Variable], List[tf.Variable], List[tf.Variable]):
+def get_gptq_trainable_parameters(fxp_model: Model,
+                                  fw_info: FrameworkInfo,
+                                  add_bias: bool = False) -> (
+        List[tf.Variable], List[tf.Variable], List[tf.Variable]):
     """
     Get trainable parameters from all layers in a model

@@ -36,7 +35,6 @@ def get_trainable_parameters(fxp_model: Model,
         fxp_model: Model to get its trainable parameters.
         fw_info: Framework information needed for keras kernel ops list.
         add_bias: Whether to include biases of the model (if there are) or not.
-        is_gumbel: Whether the fxp model is quantized using Gumbel Rounding

     Returns:
         A list of trainable variables in a model. Each item is a list of a layers weights.

@@ -45,15 +43,17 @@
     trainable_weights: List[tf.Tensor] = []
     trainable_threshold: List[tf.Tensor] = []
     bias_weights: List[List[tf.Tensor]] = []
-
+
     for layer in fxp_model.layers:
-        if isinstance(layer,
-
-
-
-
-
-
+        if isinstance(layer, KerasQuantizationWrapper):
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_KERAS_INFO)
+
+            # collect trainable weights per quantizer
+            quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
+            quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
+            trainable_weights.append(quantizer_trainable_weights)
+            trainable_threshold.extend(quantizer_trainable_threshold)

             if add_bias:
                 kernel_ops_attrs = fw_info.kernel_ops_attributes_mapping.get(type(layer.layer))

@@ -61,27 +61,8 @@ def get_trainable_parameters(fxp_model: Model,
                     and layer.layer.get_config().get(USE_BIAS)
                 if use_bias is not None and use_bias:
                     bias_weights.append([layer.layer.bias])
-        trainable_weights.append(layer_trainable_weights)
-        trainable_threshold.extend(layer_trainable_threshold)

-    return trainable_weights, bias_weights, trainable_threshold
-
-
-def get_gumbel_probability(fxp_model: Model) -> List[tf.Tensor]:
-    """
-    This function return the gumbel softmax probability of GumRounding
-    Args:
-        fxp_model: A model to be quantized with GumRounding
-
-    Returns: A list of tensors.
-
-    """
-    gumbel_prob_aux: List[tf.Tensor] = []
-    for layer in fxp_model.layers:
-        if isinstance(layer, QuantizeWrapper) and isinstance(
-                layer.quantize_config, WeightQuantizeConfig):
-            gumbel_prob_aux.append(layer.quantize_config.get_gumbel_probability())
-    return gumbel_prob_aux
+    return trainable_weights, bias_weights, trainable_threshold


 def get_weights_for_loss(fxp_model: Model) -> Tuple[List[list], List[list]]:

@@ -99,14 +80,14 @@ def get_weights_for_loss(fxp_model: Model) -> Tuple[List[list], List[list]]:
     flp_weights_list = []
     fxp_weights_list = []
     for layer in fxp_model.layers:
-        if isinstance(layer,
-                layer.quantize_config, WeightQuantizeConfig):
+        if isinstance(layer, KerasQuantizationWrapper):

             # collect pairs of float and quantized weights per layer
             _layer_flp_weights, _layer_fxp_weights = [], []
-            for weight,
-                _layer_flp_weights.append(
-                _layer_fxp_weights.append(quantizer(
+            for weight, quantizer_vars, quantizer in layer.get_weights_vars():
+                _layer_flp_weights.append(quantizer_vars)
+                _layer_fxp_weights.append(quantizer(training=False, inputs=quantizer_vars))
+
             flp_weights_list.append(_layer_flp_weights)
             fxp_weights_list.append(_layer_fxp_weights)
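With the QuantizeWrapper/WeightQuantizeConfig pair gone, get_gptq_trainable_parameters now walks KerasQuantizationWrapper layers and pulls variables out of each kernel quantizer by VariableGroup, returning three lists instead of five. A usage sketch, assuming fxp_model is the wrapped model produced by build_gptq_model in the previous file:

    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
    from model_compression_toolkit.gptq.keras.graph_info import get_gptq_trainable_parameters

    # Per-layer kernel variables, per-layer bias variables, and a flat list of
    # quantization-parameter (threshold) variables.
    trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters(
        fxp_model, DEFAULT_KERAS_INFO, add_bias=True)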
model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18

@@ -85,26 +85,18 @@ if common.constants.FOUND_TF:

         Create a GradientPTQConfigV2 to run for 5 epochs:

-        >>> gptq_conf = mct.get_keras_gptq_config(n_epochs=5)
+        >>> gptq_conf = mct.gptq.get_keras_gptq_config(n_epochs=5)

         Other Tensorflow optimizers can be passed:

-        >>> gptq_conf = mct.get_keras_gptq_config(n_epochs=3, optimizer=tf.keras.optimizers.Nadam())
+        >>> gptq_conf = mct.gptq.get_keras_gptq_config(n_epochs=3, optimizer=tf.keras.optimizers.Nadam())

         The configuration can be passed to :func:`~model_compression_toolkit.keras_post_training_quantization` in order to quantize a keras model using gptq.

         """
         bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
-
-
-            optimizer,
-            optimizer_rest=optimizer_rest,
-            loss=loss,
-            log_function=log_function,
-            train_bias=True,
-            quantization_parameters_learning=True,
-            optimizer_bias=bias_optimizer,
-            optimizer_quantization_parameter=optimizer_quantization_parameter)
+        return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
+                                   log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer)


     def keras_gradient_post_training_quantization_experimental(in_model: Model,

@@ -183,11 +175,11 @@ if common.constants.FOUND_TF:

         Create GPTQ config:

-        >>> gptq_config = mct.get_keras_gptq_config(n_epochs=1)
+        >>> gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=1)

         Pass the model with the representative dataset generator to get a quantized model:

-        >>> quantized_model, quantization_info = mct.keras_gradient_post_training_quantization_experimental(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config)
+        >>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization_experimental(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config)

         """
         KerasModelValidation(model=in_model,

@@ -196,8 +188,8 @@ if common.constants.FOUND_TF:
         if core_config.mixed_precision_enable:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
                 common.Logger.error("Given quantization config to mixed-precision facade is not of type "
-                                    "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization
-                                    "or pass a valid mixed precision configuration.")
+                                    "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
+                                    "API, or pass a valid mixed precision configuration.")  # pragma: no cover

             common.Logger.info("Using experimental mixed-precision quantization. "
                                "If you encounter an issue please file a bug.")

@@ -243,10 +235,10 @@ else:
     def get_keras_gptq_config(*args, **kwargs):
         Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                         'when using keras_post_training_quantization_mixed_precision. '
-                        'Could not find Tensorflow package.')
+                        'Could not find Tensorflow package.')  # pragma: no cover


     def keras_gradient_post_training_quantization_experimental(*args, **kwargs):
         Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                         'when using keras_gradient_post_training_quantization_experimental. '
-                        'Could not find Tensorflow package.')
+                        'Could not find Tensorflow package.')  # pragma: no cover
model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1

@@ -13,4 +13,5 @@
 # limitations under the License.
 # ==============================================================================

-
+import model_compression_toolkit.gptq.keras.quantizer.ste_rounding.symmetric_ste
+import model_compression_toolkit.gptq.keras.quantizer.soft_rounding.symmetric_soft_quantizer
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0 (new file)

@@ -0,0 +1,112 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from abc import abstractmethod
+from typing import Union, Dict, List
+
+from model_compression_toolkit.core.common import Logger
+from model_compression_toolkit.core.common.constants import FOUND_TF
+from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
+
+from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
+    TrainableQuantizerActivationConfig
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer
+
+if FOUND_TF:
+    import tensorflow as tf
+
+    from model_compression_toolkit.quantizers_infrastructure import BaseKerasTrainableQuantizer, \
+        KerasQuantizationWrapper
+
+    class BaseKerasGPTQTrainableQuantizer(BaseKerasTrainableQuantizer):
+        """
+        A base class for trainable Keras quantizer for GPTQ.
+        """
+
+        def __init__(self,
+                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
+            """
+            Initializes BaseKerasGPTQTrainableQuantizer object.
+
+            Args:
+                quantization_config: quantizer config class contains all the information about a quantizer configuration.
+            """
+
+            super().__init__(quantization_config)
+
+
+        def update_layer_quantization_params(self, layer: KerasQuantizationWrapper
+                                             ) -> (Dict[str, tf.Tensor], Dict[str, Dict], Dict):
+            """
+            A Function to calculate the needed change in attributes in NodeQuantizationConfig after retraining.
+
+            Args:
+                layer: A wrapped Keras layer.
+
+            Returns:
+                3 dictionaries describing the change in layer's weights, weights config, activation config
+                that changed during GPTQ retraining.
+                Keys must match NodeQuantizationConfig attributes
+
+            """
+            weights = {}
+            for weight, quantizer_vars, quantizer in layer.get_weights_vars():
+                if not isinstance(quantizer, BaseTrainableQuantizer):
+                    Logger.error(f"Expecting a GPTQ trainable quantizer, "  # pragma: no cover
+                                 f"but got {type(quantizer)} which is not callable.")
+                weights.update({weight: quantizer(training=False, inputs=quantizer_vars)})
+
+            quant_config = {WEIGHTS_QUANTIZATION_PARAMS: self.get_quant_config()}
+
+            return weights, quant_config, {}
+
+        def get_aux_variable(self) -> List[tf.Tensor]:
+            """
+            This function return a list with the quantizer's quantization auxiliary variables.
+
+            Returns: A list with the quantization auxiliary variables.
+
+            """
+
+            return []  # pragma: no cover
+
+        def get_quantization_variable(self) -> List[tf.Tensor]:
+            """
+            This function return a list with the quantizer's quantization parameters variables.
+
+            Returns: A list with the quantization parameters.
+
+            """
+
+            return []  # pragma: no cover
+
+        @abstractmethod
+        def get_quant_config(self):
+            """
+            Returns the config used to edit NodeQuantizationConfig after GPTQ retraining.
+
+            Returns:
+                A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
+                Keys must match NodeQuantizationConfig attributes.
+
+            """
+            raise NotImplemented(f'{self.__class__.__name__} have to implement the '  # pragma: no cover
+                                 f'quantizer\'s get_quant_config.')
+
+else:
+    class BaseKerasGPTQTrainableQuantizer:  # pragma: no cover
+        def __init__(self, *args, **kwargs):
+            Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                            'when using BaseKerasGPTQTrainableQuantizer. '
+                            'Could not find Tensorflow package.')  # pragma: no cover
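Concrete GPTQ quantizers (the ste_rounding and soft_rounding modules imported by quantizer/__init__.py above) subclass this base; get_quant_config is the one abstract hook. A hypothetical skeleton, assuming mct-nightly 1.8.0 with TensorFlow installed; real subclasses in this release also register themselves for lookup by quantization_builder, which this sketch omits:

    from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import \
        BaseKerasGPTQTrainableQuantizer

    class ExampleGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
        def get_quant_config(self):
            # Report quantization parameters changed during GPTQ retraining;
            # keys must match NodeQuantizationConfig attributes.
            return {}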
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14

@@ -26,6 +26,19 @@ def ste_ceil(x: tf.Tensor) -> tf.Tensor:
     return error + x


+def safe_log(x: tf.Tensor, eps: float) -> tf.Tensor:
+    """
+    Computes log function of x unless x is smaller than some small value, so the log function would not fail.
+
+    Args:
+        x: input variable.
+        eps: limit value.
+
+    Returns: log of x where x > eps, else, log of eps.
+    """
+    return tf.math.log(tf.maximum(x, eps))
+
+
 def ste_round(x: tf.Tensor) -> tf.Tensor:
     """
     Return the rounded values of a tensor.

@@ -59,20 +72,6 @@ def calculate_delta(max_tensor: tf.Tensor,
     return max_tensor / (2 ** (num_bits - int(signed)))


-def adjustable_steps(x: tf.Variable, t: float) -> tf.Tensor:
-    """
-    A function to gradually quantize a float variable to an integer of values [-1, 0 ,1]
-    Args:
-        x: input float variable
-        t: temperature to control quantization
-
-    Returns:
-        semi-quantized variable
-
-    """
-    return tf.sigmoid(tf.add(x, 1) / t) + tf.sigmoid(tf.add(x, -1) / t) - 1
-
-
 def ste_clip(x: [tf.Tensor, tf.Variable], max_val=1, min_val=None) -> tf.Tensor:
     """
     clip a variable between fixed values such that min_val<=output<=max_val
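safe_log factors out the clipped-log idiom that the removed Gumbel entropy code inlined as tf.math.log(tf.maximum(p, eps)). A small self-contained check of the helper's behavior:

    import tensorflow as tf

    def safe_log(x: tf.Tensor, eps: float) -> tf.Tensor:
        # Same definition as the helper added above.
        return tf.math.log(tf.maximum(x, eps))

    p = tf.constant([1.0, 0.5, 0.0])
    print(safe_log(p, eps=1e-8).numpy())  # log(0) clips to log(eps) instead of -inf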
model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0 (new file)

@@ -0,0 +1,78 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Dict, List, Tuple
+
+from model_compression_toolkit.gptq import GradientPTQConfigV2
+from model_compression_toolkit.core import common
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
+    get_inferable_quantizer_kwargs
+from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
+from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
+    get_inferable_quantizer_class
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import \
+    BaseKerasInferableQuantizer
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
+    get_trainable_quantizer_weights_config
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
+    get_trainable_quantizer_class
+
+
+def quantization_builder(n: common.BaseNode,
+                         gptq_config: GradientPTQConfigV2
+                         ) -> Tuple[Dict[str, BaseKerasGPTQTrainableQuantizer], List[BaseKerasInferableQuantizer]]:
+    """
+    Build quantizers for a node according to its quantization configuration and
+    a global NoOpQuantizeConfig object.
+
+    Args:
+        n: Node to build its QuantizeConfig.
+        gptq_config (GradientPTQConfigV2): GradientPTQConfigV2 configuration.
+
+    Returns:
+        A dictionary which maps the weights kernel attribute to a quantizer for GPTQ training.
+        Note that we return a dictionary although there is only a single attribute that is being mapped to a quantizer,
+        to be compatible with the quantization infrastructure template.
+    """
+
+    weights_quantizers = {}
+    if n.is_weights_quantization_enabled():
+        quant_method = n.final_weights_quantization_cfg.weights_quantization_method
+
+        quantizer_class = get_trainable_quantizer_class(quant_target=QuantizationTarget.Weights,
+                                                        quantizer_type=gptq_config.rounding_type,
+                                                        quant_method=quant_method,
+                                                        quantizer_base_class=BaseKerasGPTQTrainableQuantizer)
+        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=n.type,
+                                                              fw_info=DEFAULT_KERAS_INFO)
+
+        weights_quantizers.update({kernel_attribute: quantizer_class(get_trainable_quantizer_weights_config(n),
+                                                                     **gptq_config.gptq_quantizer_params_override)})
+
+    activation_quantizers = []
+    if n.is_activation_quantization_enabled():
+        quant_method = n.final_activation_quantization_cfg.activation_quantization_method
+
+        quantizer_class = get_inferable_quantizer_class(quant_target=QuantizationTarget.Activation,
+                                                        quant_method=quant_method,
+                                                        quantizer_base_class=BaseKerasInferableQuantizer)
+
+        kwargs = get_inferable_quantizer_kwargs(n, QuantizationTarget.Activation)
+
+        activation_quantizers.append(quantizer_class(**kwargs))
+
+    return weights_quantizers, activation_quantizers
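quantization_builder is the per-node factory that gptq_wrapper (gptq_training.py above) calls before wrapping a layer. A hedged usage sketch, where node, layer, and gptq_config are assumed to come from an MCT graph, its matching Keras layer, and a GradientPTQConfigV2:

    from model_compression_toolkit import quantizers_infrastructure as qi
    from model_compression_toolkit.gptq.keras.quantizer.quantization_builder import quantization_builder

    weights_quantizers, activation_quantizers = quantization_builder(node, gptq_config)
    wrapped_layer = qi.KerasQuantizationWrapper(layer,
                                                weights_quantizers=weights_quantizers,
                                                activation_quantizers=activation_quantizers)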