mct-nightly 1.7.1.31122022.post351-py3-none-any.whl → 1.8.0.1042023.post423-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py (removed):

```diff
@@ -1,263 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from typing import Dict, Any, List
-
-import numpy as np
-import tensorflow as tf
-
-from model_compression_toolkit import GumbelConfig
-from model_compression_toolkit.core.keras.quantizer.base_quantizer import BaseTrainableQuantizer
-from model_compression_toolkit.core.common.defaultdict import DefaultDict
-from model_compression_toolkit.core import common
-from model_compression_toolkit.gptq.keras.quantizer import kernel_functions
-from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.gumbel_softmax import sample_gumbel
-from model_compression_toolkit.gptq.common import gptq_constants
-from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
-from tensorflow.python.framework.tensor_shape import TensorShape
-from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
-
-P_INIT = 0.01
-GR_SHIFT_BASE = 2
-
-
-def init_aux_var(ceil_indicator: np.ndarray, w_shape: List[int], m: int, p: float = P_INIT) -> np.ndarray:
-    """
-    This function generate a random pi matrix for Gumbel Rounding such that the search start at the rounding point.
-    Args:
-        ceil_indicator: An array of indicator if the value should be ceil or floor.
-        w_shape(List[int]): A list of integers that represent the shape of the weights tensor to be quantization.
-        p(float): A floating point number that represent the probability of non round options of pi matrix.
-        m(int): An integer that define the number of shift.
-
-    Returns: A numpy array of pi tensor
-
-    """
-    if m < 2:
-        common.logger.Logger.error("m must be larger than two")
-    if m % 2 != 0:
-        common.logger.Logger.error("m must be module two")
-    m_hat = m // 2 - 1
-    shift = -np.log(-np.log(1 - p))
-    n = np.random.randn(*[m, *w_shape]) * np.sqrt(np.power(np.pi, 2) / 6)
-    n = n.reshape([m, -1]).T
-    ceil_indicator = ceil_indicator.flatten()
-    n[np.arange(ceil_indicator.size), ceil_indicator + m_hat] += shift
-    n = n.T.reshape(*[m, *w_shape])
-    return n
-
-
-def _init_shift_var(m: int) -> List[int]:
-    """
-    This function generate an list of 2*m+1 from -m to m
-    Args:
-        m: An integer value the represent m
-
-    Returns: A list of size m
-
-    """
-    m_hat = m // 2
-    aux_index_shift = [-m_hat + 1 + i for i in range(m)]
-    return aux_index_shift
-
-
-class GumbelRoundingBase(BaseTrainableQuantizer):
-    def __init__(self,
-                 num_bits: int,
-                 per_axis: bool,
-                 signed: bool,
-                 symmetric: bool,
-                 power_of_two: bool,
-                 quantization_parameter_learning: bool,
-                 quantization_axis: int,
-                 gumbel_config: GumbelConfig,
-                 max_lsbs_change_map: dict = DefaultDict({}, lambda: 1),
-                 max_iteration: int = 10000):
-        """
-        A base class for GumRounding
-
-        Args:
-            num_bits: Number of bits to use for the quantization.
-            per_axis: Whether to quantize per-channel or per-tensor.
-            signed: Signedness to use for the quantization range.
-            symmetric: Whether to quantize is symmetric.
-            power_of_two: Whether to quantize is power-of-two.
-            quantization_parameter_learning: A bool flag state if the quantizer parameter are trainable
-            quantization_axis: Axis of tensor to use for the quantization.
-            gumbel_config: A class with the gumbel rounding configurations.
-            max_lsbs_change_map: a mapping between number of bits to max lsb change.
-            max_iteration: The number of iteration of gptq.
-        """
-        self.num_bits = num_bits
-        self.per_axis = per_axis
-        self.signed = signed
-        self.quantization_axis = quantization_axis
-        self.max_iteration = max_iteration
-        self.power_of_two = power_of_two
-        self.symmetric = symmetric
-        self.quantization_parameter_learning = quantization_parameter_learning
-        self.temperature_learning = gumbel_config.temperature_learning
-        self.quantizer_parameters = {}
-        self.gumbel_config = gumbel_config
-
-        self.max_lsbs_change_map = max_lsbs_change_map
-        self.max_lsbs_change = max_lsbs_change_map.get(num_bits)
-        self.m = GR_SHIFT_BASE * self.max_lsbs_change + GR_SHIFT_BASE
-
-        self.n_cycles = gumbel_config.n_cycles
-        self.minimal_temp = gumbel_config.minimal_temp
-        self.maximal_temp = gumbel_config.maximal_temp
-        self.cycle_iterations = max(1, int(self.max_iteration / self.n_cycles))
-        self.tau = None
-        self.g_t = None
-        self.p_t = None
-        scale = self.cycle_iterations / (-2 * np.log(0.001))
-
-        self.gumbel_scale = gumbel_config.gumbel_scale
-        self.gumbel_scale_per_bitwidth = gumbel_config.gumbel_scale_per_bitwidth
-
-        def tau_function(i):
-            """
-            A function the generate the gumbel temperature.
-            Args:
-                i: An int the represent the current iteration number
-
-            Returns: A temperature value.
-
-            """
-            if i < (self.cycle_iterations - 1):
-                index = ((i + 1) % self.cycle_iterations) / scale
-            else:
-                index = (i % self.cycle_iterations) / scale
-
-            x = tf.exp(-index)
-            return self.minimal_temp + (self.maximal_temp - self.minimal_temp) * x
-
-        self.tau_function = tau_function
-        self.w_shape = None
-        self.update_gumbel_param = True
-
-    def enable_update(self):
-        self.update_gumbel_param = True
-
-    def disable_update(self):
-        self.update_gumbel_param = False
-
-    def build(self, tensor_shape: TensorShape,
-              name: str,
-              layer: QuantizeWrapper) -> Dict[str, tf.Variable]:
-        """
-        Add min and max variables to layer.
-        Args:
-            tensor_shape: Tensor shape the quantizer quantize.
-            name: Prefix of variables names.
-            layer: Layer to add the variables to. The variables are saved
-            in the layer's scope.
-
-        Returns:
-            Dictionary of new variables.
-        """
-        w_shape = kernel_functions.get_kernel(layer.weights).shape
-        self.w_shape = w_shape
-
-        ar_iter = layer.add_weight(
-            name + gptq_constants.GPTQ_ITER,
-            shape=(),
-            initializer=tf.keras.initializers.Constant(0.0),
-            trainable=False)
-
-        temp_tensor = layer.add_weight(
-            name + gptq_constants.TEMP,
-            shape=[1, *self.w_shape],
-            initializer=tf.keras.initializers.Constant(self.maximal_temp),
-            trainable=True)
-
-        shift_tensor = layer.add_weight(name + gptq_constants.AUXSHIFT,
-                                        shape=self.m,
-                                        initializer=tf.keras.initializers.Constant(0.0),
-                                        trainable=False)
-        shift_tensor.assign(_init_shift_var(self.m))
-
-        self.quantizer_parameters = {gptq_constants.GPTQ_ITER: ar_iter,
-                                     gptq_constants.AUXSHIFT: shift_tensor,
-                                     gptq_constants.TEMP: temp_tensor}
-        return self.quantizer_parameters
-
-    def get_aux_variable(self) -> tf.Tensor:
-        return self.quantizer_parameters[gptq_constants.AUXVAR]
-
-    def get_trainable_parameters(self) -> List[tf.Tensor]:
-        """
-        A function to get a list trainable of trainable parameters of the quantizer for GPTQ retraining
-
-        Returns:
-            A list of trainable Tensors
-
-        """
-        return [t for t in self.quantizer_parameters.values() if t.trainable]
-
-    def __eq__(self, other: Any) -> bool:
-        """
-        Check if equals to another object.
-        Args:
-            other: Other object to compare.
-
-        Returns:
-            Whether they are equal or not.
-        """
-        if not isinstance(other, GumbelRoundingBase):
-            return False
-
-        return (self.num_bits == other.num_bits and
-                self.per_axis == other.per_axis and
-                self.symmetric == other.symmetric)
-
-    def __ne__(self, other: Any) -> bool:
-        """
-        Check if not equals to another object.
-        Args:
-            other: Other object to compare.
-
-        Returns:
-            Whether they are differ or not.
-        """
-        return not self.__eq__(other)
-
-    def get_config(self) -> Dict[str, Any]:
-        """
-        Returns: Configuration of TrainableQuantizer.
-        """
-
-        return {
-            'num_bits': self.num_bits,
-            'per_axis': self.per_axis,
-            'symmetric': self.symmetric,
-            'power_of_two': self.power_of_two
-        }
-
-    def update_iteration(self, training, ar_iter):
-        if self.temperature_learning:
-            self.tau = qutils.ste_clip(self.quantizer_parameters[gptq_constants.TEMP], self.maximal_temp,
-                                       self.minimal_temp)
-        else:
-            self.tau = self.tau_function(ar_iter)
-        if self.update_gumbel_param and training:
-            ar_iter.assign_add(1.0)
-            self.g_t = sample_gumbel([self.m, *self.w_shape])
-
-    def get_temperature_variable(self):
-        return self.quantizer_parameters[gptq_constants.TEMP]
-
-    def get_gumbel_probability(self):
-        return self.p_t
```
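For context on the hunk above: when the temperature is not learned, the removed base class anneals the Gumbel temperature with a cyclic exponential schedule, re-heating it to `maximal_temp` at the start of each cycle so the rounding search can escape the near-one-hot regime it reaches as `tau` approaches `minimal_temp`. A minimal NumPy sketch of that schedule (illustrative only; not part of the package, and `tau_schedule` is our name) mirrors the removed `tau_function`:

```python
import numpy as np

def tau_schedule(i, cycle_iterations, minimal_temp, maximal_temp):
    # Mirrors the removed tau_function: exponential decay within a cycle,
    # restarted (re-heated to maximal_temp) at every cycle boundary.
    scale = cycle_iterations / (-2 * np.log(0.001))
    if i < cycle_iterations - 1:
        index = ((i + 1) % cycle_iterations) / scale
    else:
        index = (i % cycle_iterations) / scale
    return minimal_temp + (maximal_temp - minimal_temp) * np.exp(-index)

# Temperature falls from ~1.0 toward 0.5 inside each 100-iteration cycle.
temps = [tau_schedule(i, cycle_iterations=100, minimal_temp=0.5, maximal_temp=1.0)
         for i in range(300)]
```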
model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py (removed):

```diff
@@ -1,75 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import tensorflow as tf
-
-
-def sample_gumbel(shape, eps=1e-6) -> tf.Tensor:
-    """
-    A function that sample a tensor of i.i.d gumbel random variable.
-    Args:
-        shape: The tensor output shape
-        eps: A small number for numeric stability.
-
-    Returns: A tensor of i.i.d gumbel random variable.
-
-    """
-    u = tf.random.uniform(shape)
-    return -tf.math.log(-tf.math.log(u + eps) + eps)
-
-
-def gumbel_softmax(in_pi: tf.Tensor, in_tau: tf.Tensor, in_gumbel: tf.Tensor, eps: float = 1e-6, axis=0,
-                   gumbel_scale: float = 1.0) -> tf.Tensor:
-    """
-    A gumbel softmax function.
-    Args:
-        in_pi: A tensor of log probability.
-        in_tau: A temperature tensor.
-        in_gumbel: A tensor of gumbel random variable.
-        eps: A small number for numeric stability.
-        axis: A integer representing the axis of which the gumbel softmax applyed on.
-        gumbel_scale: A normalization factor for the gumbel tensor values
-
-    Returns: A gumbel softmax probability tensor.
-
-    """
-    return tf.nn.softmax((tf.nn.log_softmax(in_pi, axis=axis) + gumbel_scale * in_gumbel) / (in_tau + eps), axis=axis)
-
-
-def ste_gumbel(in_prob: tf.Tensor) -> tf.Tensor:
-    """
-    This function apply ste on the output of the gumbel softmax.
-    Args:
-        in_prob:A tensor of probability
-
-    Returns: A Tensor of ohe hot vector with STE.
-
-    """
-
-    delta = tf.stop_gradient(select_gumbel(in_prob) - in_prob)
-    return in_prob + delta
-
-
-def select_gumbel(in_prob: tf.Tensor) -> tf.Tensor:
-    """
-    This function apply ste on the output of the gumbel softmax.
-    Args:
-        in_prob: A tensor of probability.
-
-    Returns: A Tensor of ohe hot vector
-
-    """
-    max_index = tf.argmax(in_prob, axis=0)
-    one_hot_prob = tf.one_hot(max_index, depth=in_prob.shape[0], axis=0)
-    return one_hot_prob + 0 * in_prob
```
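The deleted helpers above are a standard Gumbel-softmax implementation: perturb log-probabilities with Gumbel noise, divide by a temperature, and softmax; `ste_gumbel` then snaps the result to one-hot on the forward pass while keeping the soft gradient (straight-through estimator). A self-contained NumPy sketch of the same math (ours, for illustration; the package's versions operate on TF tensors):

```python
import numpy as np

def sample_gumbel(shape, eps=1e-6):
    # G = -log(-log(U)) with U ~ Uniform(0, 1), as in the removed TF helper.
    u = np.random.uniform(size=shape)
    return -np.log(-np.log(u + eps) + eps)

def gumbel_softmax(pi, tau, g, eps=1e-6, axis=0, gumbel_scale=1.0):
    # softmax((log_softmax(pi) + scale * G) / tau); small tau -> near one-hot.
    m = pi.max(axis=axis, keepdims=True)
    log_p = pi - m - np.log(np.exp(pi - m).sum(axis=axis, keepdims=True))
    z = (log_p + gumbel_scale * g) / (tau + eps)
    z -= z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Four rounding candidates per weight entry; each column sums to 1.
logits = np.zeros((4, 3))
probs = gumbel_softmax(logits, tau=0.1, g=sample_gumbel(logits.shape))
```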
model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py (removed):

```diff
@@ -1,266 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import tensorflow as tf
-import numpy as np
-
-from model_compression_toolkit import GumbelConfig
-from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
-from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.base_gumbel_rounding import GumbelRoundingBase, \
-    init_aux_var
-from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
-from tensorflow.python.framework.tensor_shape import TensorShape
-from model_compression_toolkit.core.common.defaultdict import DefaultDict
-from typing import Dict, Any, List
-from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.gumbel_softmax import gumbel_softmax, ste_gumbel
-from model_compression_toolkit.core.common.constants import THRESHOLD, GUMBEL_MAX_ITER, MIN_THRESHOLD
-from model_compression_toolkit.gptq.common import gptq_constants
-from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import max_power_of_two
-from model_compression_toolkit.gptq.keras.quantizer.ste_rounding.symmetric_ste import symmetric_quantizer
-
-
-def gumbel_rounding_symmetric_quantizer(input_tensor: tf.Tensor,
-                                        auxvar_tensor: tf.Variable,
-                                        max_tensor: tf.Tensor,
-                                        num_bits: int,
-                                        signed: bool,
-                                        power_of_two: bool) -> tf.Tensor:
-    """
-    Quantize a tensor symmetrically with maximum LSBs shift.
-    Args:
-        input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
-        auxvar_tensor: Tensor that manifests the bit shift the weight due to gptq.
-        max_tensor: Tensor with max values to compute the threshold.
-        num_bits: Num of bits to use.
-        signed: Signedness of the quantization range.
-        power_of_two: Whether the threshold should be constrained or not.
-
-    Returns:
-        A quantized tensor.
-    """
-
-    if power_of_two:
-        max_tensor = qutils.power_of_two_max(max_tensor)
-    delta = qutils.calculate_delta(max_tensor, num_bits, signed)
-    input_tensor = tf.stop_gradient(input_tensor)
-    input_tensor_int = tf.floor(input_tensor / delta)
-    tensor_q = input_tensor_int + auxvar_tensor
-    min_int = -int(signed) * (2 ** (num_bits - int(signed)))
-    max_int = (2 ** (num_bits - int(signed))) - 1
-    return delta * qutils.clip(tensor_q, max_val=max_int, min_val=min_int)
-
-
-class SymmetricGumbelRounding(GumbelRoundingBase):
-    """
-    Trainable constrained quantizer to quantize a layer inputs.
-    """
-    PTQ_THRESHOLD = "_ptq_threshold"
-    SCALE_PTQ = "_scale"
-
-    def __init__(self, num_bits: int,
-                 per_axis: bool,
-                 signed: bool,
-                 power_of_two: bool,
-                 quantization_parameter_learning: bool,
-                 threshold_values: np.ndarray,
-                 gumbel_config: GumbelConfig,
-                 quantization_axis: int = -1,
-                 max_lsbs_change_map: dict = DefaultDict({}, lambda: 1),
-                 max_iteration: int = GUMBEL_MAX_ITER):
-        """
-        Initialize a TrainableWeightQuantizer object with parameters to use
-        for the quantization.
-
-        Args:
-            num_bits: Number of bits to use for the quantization.
-            per_axis: Whether to quantize per-channel or per-tensor.
-            signed: Signedness to use for the quantization range.
-            threshold_values: Threshold to use for the quantization.
-            gumbel_config: A class with the gumbel rounding configurations.
-            quantization_axis: Axis of tensor to use for the quantization.
-            power_of_two: Whether the threshold should be constrained or not.
-            max_lsbs_change_map: a mapping between number of bits to max lsb change.
-            max_iteration: The number of iteration of gptq.
-        """
-        super().__init__(num_bits, per_axis, signed, True, power_of_two, quantization_parameter_learning,
-                         quantization_axis, gumbel_config,
-                         max_lsbs_change_map,
-                         max_iteration)
-        self.threshold_shape = np.asarray(threshold_values).shape
-        self.threshold_values = np.reshape(np.asarray(threshold_values), [-1]) if self.per_axis else float(
-            threshold_values)
-        self.k_threshold = len(self.threshold_values) if self.per_axis else 1
-
-    def build(self,
-              tensor_shape: TensorShape,
-              name: str,
-              layer: QuantizeWrapper) -> Dict[str, tf.Variable]:
-        """
-        Add min and max variables to layer.
-        Args:
-            tensor_shape: Tensor shape the quantizer quantize.
-            name: Prefix of variables names.
-            layer: Layer to add the variables to. The variables are saved
-            in the layer's scope.
-
-        Returns:
-            Dictionary of new variables.
-        """
-        super().build(tensor_shape, name, layer)
-
-        if self.per_axis:
-            input_shape = tensor_shape
-            n_axis = len(input_shape)
-            quantization_axis = n_axis + self.quantization_axis if self.quantization_axis < 0 else \
-                self.quantization_axis
-            reshape_shape = [self.k_threshold if i == quantization_axis else 1 for i in range(n_axis)]
-        else:
-            reshape_shape = [self.k_threshold]
-
-        ptq_threshold_tensor = layer.add_weight(
-            name + self.PTQ_THRESHOLD,
-            shape=reshape_shape,
-            initializer=tf.keras.initializers.Constant(1.0),
-            trainable=False)
-        ptq_threshold_tensor.assign(self.threshold_values.reshape(reshape_shape))
-
-        auxvar_tensor = layer.add_weight(
-            name + gptq_constants.AUXVAR,
-            shape=[self.m, *self.w_shape],
-            initializer=tf.keras.initializers.Constant(0.0),
-            trainable=True)
-        w = getattr(layer.layer, name)
-
-        q_error = w - symmetric_quantizer(w,
-                                          ptq_threshold_tensor,
-                                          num_bits=self.num_bits,
-                                          signed=True,
-                                          power_of_two=self.power_of_two)
-
-        ceil_indicator = (q_error < 0).numpy().astype("int")  # Negative error means the choose point is rounded to ceil.
-        auxvar_tensor.assign(init_aux_var(ceil_indicator, self.w_shape, self.m))
-
-        self.quantizer_parameters.update({gptq_constants.AUXVAR: auxvar_tensor,
-                                          self.PTQ_THRESHOLD: ptq_threshold_tensor})
-
-        if self.quantization_parameter_learning and not self.power_of_two:
-            scale = layer.add_weight(
-                name + self.SCALE_PTQ,
-                shape=self.k_threshold,
-                initializer=tf.keras.initializers.Constant(1.0),
-                trainable=True)
-            self.quantizer_parameters.update({self.SCALE_PTQ: scale})
-
-        return self.quantizer_parameters
-
-    def get_quantization_variable(self) -> List[tf.Tensor]:
-        """
-        This function return a list of quantizer parameters.
-        Returns: A list of the quantizer parameters
-
-        """
-        if self.quantization_parameter_learning and not self.power_of_two:
-            return [self.quantizer_parameters[self.SCALE_PTQ]]
-        else:
-            return []
-
-    def __call__(self, inputs: tf.Tensor,
-                 training: bool,
-                 weights: Dict[str, tf.Variable],
-                 **kwargs: Dict[str, Any]):
-        """
-        Quantize a tensor.
-        Args:
-            inputs: Input tensor to quantize.
-            training: Whether the graph is in training mode.
-            weights: Dictionary of weights the quantizer can use to quantize the tensor.
-            **kwargs: Additional variables the quantizer may receive.
-
-        Returns:
-            The quantized tensor.
-        """
-
-        auxvar = weights[gptq_constants.AUXVAR]
-        ar_iter = weights[gptq_constants.GPTQ_ITER]
-        ptq_threshold_tensor = weights[self.PTQ_THRESHOLD]
-        aux_index_shift = weights[gptq_constants.AUXSHIFT]
-        self.update_iteration(training, ar_iter)
-        if self.per_axis:
-            input_shape = inputs.shape
-            n_axis = len(input_shape)
-            quantization_axis = n_axis + self.quantization_axis if self.quantization_axis < 0 else \
-                self.quantization_axis
-            reshape_shape = [-1 if i == quantization_axis else 1 for i in range(n_axis)]
-
-            reshape_shape_aux_ind = [-1, *[1 for _ in range(n_axis)]]
-            #####################################################
-            # Gumbel Softmax
-            #####################################################
-            if training:
-                gumbel_scale = self.gumbel_scale if self.gumbel_scale_per_bitwidth is None \
-                    else self.gumbel_scale_per_bitwidth.get(self.num_bits, self.gumbel_scale)
-                p_t = gumbel_softmax(auxvar, self.tau, self.g_t, gumbel_scale=gumbel_scale)
-            else:
-                p_t = gumbel_softmax(auxvar, self.minimal_temp, 0)
-                p_t = ste_gumbel(p_t)
-            self.p_t = p_t
-            #####################################################
-            # Calculate v hat and threshold hat
-            #####################################################
-            ptq_threshold_tensor_hat = tf.reshape(ptq_threshold_tensor, reshape_shape)
-            auxvar_hat = tf.reduce_sum(p_t * tf.reshape(aux_index_shift, reshape_shape_aux_ind), axis=0)
-            #####################################################
-            # Quantized Input
-            #####################################################
-            q_tensor = gumbel_rounding_symmetric_quantizer(inputs, auxvar_hat,
-                                                           ptq_threshold_tensor_hat,
-                                                           self.num_bits,
-                                                           self.signed,
-                                                           self.power_of_two)
-            if self.quantization_parameter_learning and not self.power_of_two:
-                scale = tf.reshape(self.quantizer_parameters[self.SCALE_PTQ], reshape_shape)
-                q_tensor *= scale
-
-            return q_tensor
-        else:
-            return gumbel_rounding_symmetric_quantizer(inputs, auxvar,
-                                                       ptq_threshold_tensor,
-                                                       self.num_bits,
-                                                       self.signed,
-                                                       self.power_of_two)
-
-    def get_quant_config(self, layer) -> Dict[str, np.ndarray]:
-        """
-        Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
-
-        Args:
-            layer: quantized layer
-
-        Returns:
-            A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
-            Keys must match NodeQuantizationConfig attributes
-
-        """
-
-        if self.power_of_two:
-            old_threshold = self.quantizer_parameters[self.PTQ_THRESHOLD]
-            old_threshold = max_power_of_two(old_threshold, MIN_THRESHOLD)
-        else:
-            old_threshold = self.quantizer_parameters[self.PTQ_THRESHOLD]
-            if self.quantization_parameter_learning:
-                scale = tf.reshape(self.quantizer_parameters[self.SCALE_PTQ], self.threshold_shape)
-                old_threshold = old_threshold * scale
-            old_threshold = old_threshold.numpy()
-        old_threshold = old_threshold.reshape(self.threshold_shape)
-        return {THRESHOLD: old_threshold}
```
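The arithmetic at the heart of the deleted quantizer: floor each weight onto the integer grid, add a rounding shift (during training, the expectation of the Gumbel-softmax over the candidate shifts), then clip and rescale. A minimal NumPy sketch of that step, assuming a plain signed symmetric grid; the function name is ours and `aux` stands in for the expected shift `auxvar_hat`:

```python
import numpy as np

def gumbel_rounding_symmetric_quantize(w, threshold, aux, num_bits=8, signed=True):
    # delta is the LSB size of a signed symmetric grid over [-threshold, threshold).
    delta = threshold / (2 ** (num_bits - int(signed)))
    min_int = -int(signed) * (2 ** (num_bits - int(signed)))
    max_int = (2 ** (num_bits - int(signed))) - 1
    # floor() fixes the base grid point; aux adds the learned whole-LSB shift.
    q = np.floor(w / delta) + aux
    return delta * np.clip(q, min_int, max_int)

w = np.array([0.30, -0.72, 0.05])
aux = np.array([1.0, 0.0, 1.0])  # e.g. round the 1st and 3rd weights up
print(gumbel_rounding_symmetric_quantize(w, threshold=1.0, aux=aux))
```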