mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
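The file list already shows the shape of the 1.8.0 release: the misspelled `qunatizers_infrastructure` package is deleted and replaced by `quantizers_infrastructure`, now split into an inferable half (quantizers that run at inference time) and a trainable half (quantizers with learnable parameters for GPTQ/QAT); the Gumbel-rounding GPTQ quantizers are removed in favor of STE and soft-rounding variants; and a new `imx500_tpc` target platform model is added. A minimal import-migration sketch follows; only the module paths are taken from the list above, so treat the exact public symbols as assumptions:

```python
# 1.7.1 shipped the infrastructure under a misspelled package name (removed here):
# from model_compression_toolkit.qunatizers_infrastructure.keras.quantize_wrapper import ...

# 1.8.0 renames it and splits it into inferable/trainable halves:
from model_compression_toolkit import quantizers_infrastructure as qi

# The PyTorch wrapper class referenced by the GPTQ diffs below is exposed at the
# package root:
wrapper_cls = qi.PytorchQuantizationWrapper
```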
model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py

@@ -13,45 +13,24 @@
 # limitations under the License.
 # ==============================================================================

-from typing import Dict, Any
+from typing import Dict, Any

 import numpy as np
 import tensorflow as tf
-
-from
-from model_compression_toolkit
-from model_compression_toolkit.
+
+from model_compression_toolkit.gptq import RoundingType
+from model_compression_toolkit import quantizers_infrastructure as qi
+from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD
+from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
 from model_compression_toolkit.core.common.constants import THRESHOLD
 from model_compression_toolkit.core.common.defaultdict import DefaultDict
-from model_compression_toolkit.gptq.keras.quantizer.
-from model_compression_toolkit.
-
-
-
-
-                        num_bits: int,
-                        signed: bool,
-                        power_of_two: bool = False) -> tf.Tensor:
-    """
-    Quantize a tensor symmetrically.
-    Args:
-        input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
-        max_tensor: Tensor with max values to compute the threshold.
-        num_bits: Num of bits to use.
-        signed: Signedness of the quantization range.
-        power_of_two: Whether the threshold should be constrained or not.
-
-    Returns:
-        A quantized tensor.
-    """
-
-    if power_of_two:
-        max_tensor = qutils.power_of_two_max(max_tensor)
-    delta = qutils.calculate_delta(max_tensor, num_bits, signed)
-    tensor_q = qutils.ste_round(input_tensor / delta)
-    min_int = -int(signed) * (2 ** (num_bits - int(signed)))
-    max_int = (2 ** (num_bits - int(signed))) - 1
-    return delta * qutils.ste_clip(tensor_q, max_val=max_int, min_val=min_int)
+from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
+from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+    get_threshold_reshape_shape
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup


 def pertubation_symmetric_quantizer(input_tensor: tf.Tensor,

@@ -63,6 +42,7 @@ def pertubation_symmetric_quantizer(input_tensor: tf.Tensor,
                                     max_lsbs_change: int = 1) -> tf.Tensor:
     """
     Quantize a tensor symmetrically with maximum LSBs shift.
+
     Args:
         input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
         auxvar_tensor: Tensor that manifests the bit shift the weight due to gptq

@@ -87,195 +67,115 @@ def pertubation_symmetric_quantizer(input_tensor: tf.Tensor,
     return delta * qutils.ste_clip(tensor_q, max_val=max_int, min_val=min_int)


-
+@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+                quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
+                quantizer_type=RoundingType.STE)
+class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
     """
-    Trainable
+    Trainable symmetric quantizer to quantize a layer weights.
     """

     def __init__(self,
-
-                 per_axis: bool,
-                 signed: bool,
-                 threshold_values: np.ndarray,
-                 quantization_axis: int = -1,
-                 power_of_two: bool = True,
+                 quantization_config: TrainableQuantizerWeightsConfig,
                  max_lsbs_change_map: dict = DefaultDict({}, lambda: 1)):
         """
-        Initialize a
-        for the quantization.
+        Initialize a STEWeightGPTQQuantizer object with parameters to use for the quantization.

         Args:
-
-            per_axis: Whether to quantize per-channel or per-tensor.
-            signed: Signedness to use for the quantization range.
-            threshold_values: Threshold to use for the quantization.
-            quantization_axis: Axis of tensor to use for the quantization.
-            power_of_two: Whether the threshold should be constrained or not.
+            quantization_config: Trainable weights quantizer config.
             max_lsbs_change_map: a mapping between number of bits to max lsb change.
         """
-
-        self.
-        self.
+        super().__init__(quantization_config)
+        self.num_bits = quantization_config.weights_n_bits
+        self.per_channel = quantization_config.weights_per_channel_threshold
+
+        threshold_values = quantization_config.weights_quantization_params[THRESHOLD]
         self.threshold_shape = np.asarray(threshold_values).shape
-        self.threshold_values = np.reshape(np.asarray(threshold_values), [-1]) if self.
+        self.threshold_values = np.reshape(np.asarray(threshold_values), [-1]) if self.per_channel else float(
            threshold_values)
-        self.quantization_axis = quantization_axis
-        self.power_of_two = power_of_two
-        self.max_lsbs_change = max_lsbs_change_map.get(num_bits)
-        self.quantizer_parameters = {}

-
-
-
-
+        self.quantization_axis = quantization_config.weights_channels_axis
+        self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
+        self.max_lsbs_change = max_lsbs_change_map.get(self.num_bits)
+
+    def initialize_quantization(self,
+                                tensor_shape: Any,
+                                name: str,
+                                layer: Any):
         """
-        Add
-        Args:
-            tensor_shape: Tensor shape the quantizer quantize.
-            name: Prefix of variables names.
-            layer: Layer to add the variables to. The variables are saved
-            in the layer's scope.
+        Add quantizer parameters to the quantizer parameters dictionary

-
-
+        Args:
+            tensor_shape: tensor shape of the quantized tensor.
+            name: Tensor name.
+            layer: Layer to quantize.
         """
-        w_shape = get_kernel(layer.weights).shape
-        ar_iter = layer.add_weight(
-            name + gptq_constants.GPTQ_ITER,
-            shape=(),
-            initializer=tf.keras.initializers.Constant(0.0),
-            trainable=False)

         ptq_threshold_tensor = layer.add_weight(
-            name
-            shape=len(self.threshold_values) if self.
+            f"{name}_{PTQ_THRESHOLD}",
+            shape=len(self.threshold_values) if self.per_channel else (),
             initializer=tf.keras.initializers.Constant(1.0),
             trainable=False)
         ptq_threshold_tensor.assign(self.threshold_values)

+        w = getattr(layer.layer, name)
         auxvar_tensor = layer.add_weight(
-            name
-            shape=
+            f"{name}_{AUXVAR}",
+            shape=list(w.shape),
             initializer=tf.keras.initializers.Constant(0.0),
             trainable=True)

         # save the quantizer added parameters for later calculations
-        self.
-
-            gptq_constants.GPTQ_ITER: ar_iter}
-        return self.quantizer_parameters
+        self.add_quantizer_variable(PTQ_THRESHOLD, ptq_threshold_tensor, VariableGroup.QPARAMS)
+        self.add_quantizer_variable(AUXVAR, auxvar_tensor, VariableGroup.WEIGHTS)

-    def __call__(self,
-
-
-                 **kwargs: Dict[str, Any]):
+    def __call__(self,
+                 inputs: tf.Tensor,
+                 training: bool):
         """
         Quantize a tensor.
+
         Args:
             inputs: Input tensor to quantize.
             training: Whether the graph is in training mode.
-            weights: Dictionary of weights the quantizer can use to quantize the tensor.
-            **kwargs: Additional variables the quantizer may receive.

         Returns:
             The quantized tensor.
         """

-        auxvar =
-        ptq_threshold_tensor =
+        auxvar = self.get_quantizer_variable(AUXVAR)
+        ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)

-        if self.
-
-
-
-            self.quantization_axis
-            reshape_shape = [-1 if i == quantization_axis else 1 for i in range(n_axis)]
+        if self.per_channel:
+            reshape_shape = get_threshold_reshape_shape(inputs.shape,
+                                                        quant_axis=self.quantization_axis,
+                                                        quant_axis_dim=-1)
             ptq_threshold_tensor = tf.reshape(ptq_threshold_tensor, reshape_shape)
-            q_tensor = pertubation_symmetric_quantizer(inputs,
+            q_tensor = pertubation_symmetric_quantizer(inputs,
+                                                       auxvar,
                                                        ptq_threshold_tensor,
                                                        self.num_bits,
-
-                                                       self.power_of_two,
+                                                       signed=True,
+                                                       power_of_two=self.power_of_two,
                                                        max_lsbs_change=self.max_lsbs_change)
             return q_tensor
         else:
-            return pertubation_symmetric_quantizer(inputs,
+            return pertubation_symmetric_quantizer(inputs,
+                                                   auxvar,
                                                    ptq_threshold_tensor,
                                                    self.num_bits,
-
-                                                   self.power_of_two)
-
-    def get_aux_variable(self) -> tf.Tensor:
-        return self.quantizer_parameters[gptq_constants.AUXVAR]
-
-    def get_config(self) -> Dict[str, Any]:
-        """
-        Returns: Configuration of TrainableQuantizer.
-        """
+                                                   signed=True,
+                                                   power_of_two=self.power_of_two)

-        return {
-            'num_bits': self.num_bits,
-            'per_axis': self.per_axis,
-            'symmetric': self.symmetric,
-            'power_of_two': self.power_of_two
-        }

-    def get_quant_config(self
+    def get_quant_config(self) -> Dict[str, np.ndarray]:
         """
         Returns the config used to edit NodeQuantizationConfig after GPTQ retraining

-        Args:
-            layer: quantized layer
-
         Returns:
             A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
             Keys must match NodeQuantizationConfig attributes

         """
-        old_threshold = self.
+        old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
         return {THRESHOLD: old_threshold.numpy().reshape(self.threshold_shape)}
-
-    def get_trainable_parameters(self):
-        """
-        A function to get a list trainable of trainable parameters of the quantizer for GPTQ retraining
-
-        Returns:
-            A list of trainable Tensors
-
-        """
-        return [t for t in self.quantizer_parameters.values() if t.trainable]
-
-    def get_quantization_variable(self) -> List[tf.Tensor]:
-        """
-        This function return a list of quantizer parameters.
-        Returns: A list of the quantizer parameters
-
-        """
-        return [self.quantizer_parameters[gptq_constants.THRESHOLD_TENSOR]]
-
-    def __eq__(self, other: Any) -> bool:
-        """
-        Check if equals to another object.
-        Args:
-            other: Other object to compare.
-
-        Returns:
-            Whether they are equal or not.
-        """
-        if not isinstance(other, STEWeightQuantizer):
-            return False
-
-        return (self.num_bits == other.num_bits and
-                self.per_axis == other.per_axis and
-                self.symmetric == other.symmetric)
-
-    def __ne__(self, other: Any) -> bool:
-        """
-        Check if not equals to another object.
-        Args:
-            other: Other object to compare.
-
-        Returns:
-            Whether they are differ or not.
-        """
-        return not self.__eq__(other)
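The hunks above replace the quantizer's hand-rolled state (the `quantizer_parameters` dict, `get_config`, `__eq__`/`__ne__`) with the new trainable-quantizer infrastructure: settings arrive in a `TrainableQuantizerWeightsConfig`, the class is registered through the `mark_quantizer` decorator, and per-layer variables are registered and fetched via `add_quantizer_variable`/`get_quantizer_variable` under a `VariableGroup`. A minimal sketch of that pattern, assuming MCT 1.8.0 is importable; the class name and the placeholder clipping body are ours, not MCT code:

```python
import tensorflow as tf
from model_compression_toolkit import quantizers_infrastructure as qi
from model_compression_toolkit.gptq import RoundingType
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup


@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.SYMMETRIC],
                quantizer_type=RoundingType.STE)
class MyWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):  # hypothetical name
    def __init__(self, quantization_config: TrainableQuantizerWeightsConfig):
        super().__init__(quantization_config)
        self.num_bits = quantization_config.weights_n_bits

    def initialize_quantization(self, tensor_shape, name, layer):
        # Variables are created on the wrapping layer and registered under a
        # group: QPARAMS for quantization parameters, WEIGHTS for aux weights.
        t = layer.add_weight(f"{name}_threshold", shape=(),
                             initializer=tf.keras.initializers.Constant(1.0),
                             trainable=False)
        self.add_quantizer_variable("threshold", t, VariableGroup.QPARAMS)

    def __call__(self, inputs: tf.Tensor, training: bool):
        threshold = self.get_quantizer_variable("threshold")
        # Placeholder body; the real STE quantizer rounds and clips with
        # straight-through gradients (see pertubation_symmetric_quantizer above).
        return tf.clip_by_value(inputs, -threshold, threshold)
```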
model_compression_toolkit/gptq/pytorch/gptq_training.py

@@ -12,24 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Union

 import numpy as np
+from torch.nn import Module
 from tqdm import tqdm
 import copy
 import torch
 from model_compression_toolkit.core.common.logger import Logger
+from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
+from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
-from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
-from model_compression_toolkit.core.pytorch.constants import BIAS
-from model_compression_toolkit.gptq.pytorch.gptq_model_builder import GPTQPytorchModelBuilder
+from model_compression_toolkit.core.pytorch.constants import BIAS
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_model, torch_tensor_to_numpy
-from model_compression_toolkit.gptq.pytorch.
-
-from model_compression_toolkit.gptq.pytorch.
+from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable_parameters, \
+    get_weights_for_loss
+from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
+from model_compression_toolkit import quantizers_infrastructure as qi
+from model_compression_toolkit.gptq.pytorch.quantizer.regularization_factory import get_regularization
+from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper


 class PytorchGPTQTrainer(GPTQTrainer):

@@ -66,11 +71,9 @@ class PytorchGPTQTrainer(GPTQTrainer):
         else:
             self.input_scale = self.gptq_user_info.input_scale

-        trainable_weights, trainable_bias, trainable_threshold
+        trainable_weights, trainable_bias, trainable_threshold = get_gptq_trainable_parameters(
             self.fxp_model,
-            add_bias=self.gptq_config.train_bias
-            quantization_parameters_learning=self.gptq_config.quantization_parameters_learning,
-            is_gumbel=self.gptq_config.is_gumbel)
+            add_bias=self.gptq_config.train_bias)

         self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
         if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(

@@ -81,10 +84,45 @@ class PytorchGPTQTrainer(GPTQTrainer):

         self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
                                                                   trainable_bias,
-                                                                  trainable_threshold
-                                                                  trainable_temperature)
+                                                                  trainable_threshold)

-        self.weights_for_average_loss = to_torch_tensor(self.
+        self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights(representative_data_gen))
+
+        self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
+
+    def _is_gptq_applicable(self,
+                            node: BaseNode) -> bool:
+        """
+        A function for deciding if a layer should be fine-tuned during GPTQ.
+        Args:
+            node (BaseNode): Node for quantization decision
+        Returns:
+            A boolean whether the layer is to be wrapped with a Quantization Wrapper.
+        """
+
+        if node.is_weights_quantization_enabled() and not self.fw_info.is_kernel_op(node.type):
+            Logger.error(f"GPTQ Error: Quantizing node {node.name} of type {node.type} "
+                         f"without a kernel isn't supported.")
+        return node.is_weights_quantization_enabled()
+
+    def gptq_wrapper(self, n: BaseNode, layer: Module) -> Union[qi.PytorchQuantizationWrapper, Module]:
+        """
+        A function which takes a computational graph node and a pytorch layer and perform the quantization wrapping.
+
+        Args:
+            n: A node of mct graph.
+            layer: A pytorch layer
+
+        Returns: Wrapped layer if the layer should be wrap, otherwise returns the layer as is.
+        """
+
+        if self._is_gptq_applicable(n):
+            weights_quantizers, activation_quantizers = quantization_builder(n, self.gptq_config)
+            return qi.PytorchQuantizationWrapper(layer,
+                                                 weights_quantizers=weights_quantizers,
+                                                 activation_quantizers=activation_quantizers)
+        else:
+            return layer

     def build_gptq_model(self):
         """

@@ -92,10 +130,13 @@ class PytorchGPTQTrainer(GPTQTrainer):
         Returns:
             Quantized graph for GPTQ fine-tuning, GPTQ graph user info
         """
-
-
-
-
+        gptq_model, gptq_user_info = PyTorchModelBuilder(graph=self.graph_quant,
+                                                         append2output=self.compare_points,
+                                                         fw_info=self.fw_info,
+                                                         wrapper=self.gptq_wrapper,
+                                                         return_float_outputs=True).build_model()
+
+        return gptq_model, gptq_user_info

     def train(self, representative_data_gen: Callable):
         """

@@ -145,14 +186,9 @@ class PytorchGPTQTrainer(GPTQTrainer):
                                          self.compare_points_std,
                                          self.weights_for_average_loss)

-
-
-
-            for p in gumbel_prob:
-                entropy = -torch.mean(torch.sum(p * torch.log(torch.maximum(p, self.gptq_config.eps*torch.ones_like(p))),dim=0))
-                gumbel_reg += entropy
-            gumbel_reg = 0 if gumbel_reg == 0 else gumbel_reg/len(gumbel_prob)
-            loss_value += self.gptq_config.quantizer_config.gumbel_entropy_regularization * gumbel_reg
+            reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)
+
+            loss_value += reg_value

             # Back-pass
             loss_value.backward()

@@ -202,20 +238,23 @@ class PytorchGPTQTrainer(GPTQTrainer):

         # Update graph after training
         for name, layer in self.fxp_model.named_modules():
-            if isinstance(layer,
+            if isinstance(layer, PytorchQuantizationWrapper):
                 node = self.graph_quant.find_node_by_name(name)
                 if len(node) != 1:
                     Logger.error(f"Can't update GPTQ graph due to missing layer named: {name}")
                 node = node[0]
-
-
-
-
-
-
-
-                node.
-
+                kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
+                                                                      fw_info=self.fw_info)
+                weights, weight_quant_config, activation_quant_config = \
+                    layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
+                for weight_attr, weight in weights.items():
+                    node.set_weights_by_keys(weight_attr, self.fw_impl.to_numpy(weight))
+                for config_attr, config_value in weight_quant_config.items():
+                    node.final_weights_quantization_cfg.set_quant_config_attr(config_attr, config_value)
+                for config_attr, config_value in activation_quant_config.items():
+                    node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
+                if self.gptq_config.train_bias and hasattr(layer.layer, BIAS):
+                    node.set_weights_by_keys(BIAS, self.fw_impl.to_numpy(getattr(layer.layer, BIAS)))

         return graph_quant


@@ -229,7 +268,7 @@ class PytorchGPTQTrainer(GPTQTrainer):

         # Fxp model: unfreeze bias trainable parameters
         for layer in self.fxp_model.modules():
-            if isinstance(layer,
-                if hasattr(layer.
-                    bias = getattr(layer.
+            if isinstance(layer, PytorchQuantizationWrapper):
+                if hasattr(layer.layer, BIAS):
+                    bias = getattr(layer.layer, BIAS)
                     bias.requires_grad = self.gptq_config.train_bias
|
|
|
1
|
+
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
from typing import List
|
|
18
|
+
from model_compression_toolkit.core.pytorch.constants import BIAS
|
|
19
|
+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
|
20
|
+
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
|
|
21
|
+
from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
|
|
22
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_gptq_trainable_parameters(fxp_model: nn.Module,
|
|
26
|
+
add_bias: bool = False,
|
|
27
|
+
) -> (List[nn.Parameter], List[nn.Parameter], List[nn.Parameter]):
|
|
28
|
+
"""
|
|
29
|
+
Get trainable parameters from all layers in a model
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
fxp_model: Model to get its trainable parameters.
|
|
33
|
+
add_bias: Whether to include biases of the model (if there are) or not.
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
A list of trainable variables in a model. Each item is a list of a layers weights.
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
trainable_aux_weights = nn.ParameterList()
|
|
40
|
+
trainable_threshold = nn.ParameterList()
|
|
41
|
+
trainable_bias = nn.ParameterList()
|
|
42
|
+
|
|
43
|
+
for layer in fxp_model.modules():
|
|
44
|
+
if isinstance(layer, PytorchQuantizationWrapper):
|
|
45
|
+
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
|
|
46
|
+
fw_info=DEFAULT_PYTORCH_INFO)
|
|
47
|
+
|
|
48
|
+
# collect trainable weights per quantizer
|
|
49
|
+
quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
|
|
50
|
+
quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
|
|
51
|
+
trainable_aux_weights.extend(quantizer_trainable_weights)
|
|
52
|
+
trainable_threshold.extend(quantizer_trainable_threshold)
|
|
53
|
+
|
|
54
|
+
if add_bias and hasattr(layer.layer, BIAS):
|
|
55
|
+
bias = getattr(layer.layer, BIAS)
|
|
56
|
+
trainable_bias.append(bias)
|
|
57
|
+
|
|
58
|
+
return trainable_aux_weights, trainable_bias, trainable_threshold
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def get_weights_for_loss(fxp_model: nn.Module) -> [List[nn.Parameter], List[torch.Tensor]]:
|
|
62
|
+
"""
|
|
63
|
+
Get all float and quantized kernels for the GPTQ loss
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
fxp_model: Model to get its float and quantized weights.
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
A list of float kernels, each item is the float kernel of the layer
|
|
70
|
+
A list of quantized kernels, each item is the quantized kernel of the layer
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
flp_weights_list, fxp_weights_list = [], []
|
|
74
|
+
for layer in fxp_model.modules():
|
|
75
|
+
if isinstance(layer, PytorchQuantizationWrapper):
|
|
76
|
+
# Collect pairs of float and quantized weights per layer
|
|
77
|
+
for weight, quantizer_vars, quantizer in layer.get_weights_vars():
|
|
78
|
+
flp_weights_list.append(quantizer_vars)
|
|
79
|
+
fxp_weights_list.append(quantizer(training=False, inputs=quantizer_vars))
|
|
80
|
+
|
|
81
|
+
return flp_weights_list, fxp_weights_list
|
|
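A usage sketch for the two helpers in the new `graph_info.py`: parameter groups for the optimizer from the trainable variables, and a reconstruction loss from the float/quantized kernel pairs. `fxp_model` is assumed to be a wrapped model as produced by `build_gptq_model` above; the learning rate and the plain MSE are illustrative choices, not MCT defaults:

```python
import torch
from model_compression_toolkit.gptq.pytorch.graph_info import (
    get_gptq_trainable_parameters,
    get_weights_for_loss,
)

aux_vars, bias_vars, threshold_vars = get_gptq_trainable_parameters(fxp_model, add_bias=True)
optimizer = torch.optim.Adam([{"params": aux_vars},
                              {"params": threshold_vars}], lr=1e-4)

flp_weights, fxp_weights = get_weights_for_loss(fxp_model)
# One simple per-kernel reconstruction objective:
loss = sum(torch.mean((f - q) ** 2) for f, q in zip(flp_weights, fxp_weights))
```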
model_compression_toolkit/gptq/pytorch/quantization_facade.py

@@ -21,6 +21,7 @@ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
 from model_compression_toolkit.core.common.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
+from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
 from model_compression_toolkit.gptq.runner import gptq_runner
 from model_compression_toolkit.core.exporter import export_model
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization

@@ -38,7 +39,7 @@ if FOUND_TORCH:
     from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
     from model_compression_toolkit.core.pytorch.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss
-    from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import
+    from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
     import torch
     from torch.nn import Module
     from torch.optim import Adam, Optimizer

@@ -71,26 +72,19 @@ if FOUND_TORCH:
         Import MCT and Create a GradientPTQConfigV2 to run for 5 epochs:

         >>> import model_compression_toolkit as mct
-        >>> gptq_conf = mct.get_pytorch_gptq_config(n_epochs=5)
+        >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=5)

         Other PyTorch optimizers can be passed with dummy params:

         >>> import torch
-        >>> gptq_conf = mct.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)]))
+        >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)]))

         The configuration can be passed to :func:`~model_compression_toolkit.pytorch_post_training_quantization` in order to quantize a pytorch model using gptq.

         """
-        bias_optimizer =
-
-
-            optimizer,
-            optimizer_rest=optimizer_rest,
-            loss=loss,
-            log_function=log_function,
-            train_bias=True,
-            optimizer_quantization_parameter=optimizer_quantization_parameter,
-            optimizer_bias=bias_optimizer)
+        bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
+        return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
+                                   log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer)


     def pytorch_gradient_post_training_quantization_experimental(model: Module,

@@ -152,15 +146,15 @@ if FOUND_TORCH:
         Pass the module, the representative dataset generator and the configuration (optional) to get a quantized module

-        >>> quantized_module, quantization_info = mct.pytorch_gradient_post_training_quantization_experimental(module, repr_datagen, core_config=config, gptq_config=gptq_conf)
+        >>> quantized_module, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization_experimental(module, repr_datagen, core_config=config, gptq_config=gptq_conf)

         """

         if core_config.mixed_precision_enable:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
                 common.Logger.error("Given quantization config to mixed-precision facade is not of type "
-                                    "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization
-                                    "or pass a valid mixed precision configuration.")
+                                    "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
+                                    "API, or pass a valid mixed precision configuration.")  # pragma: no cover

             common.Logger.info("Using experimental mixed-precision quantization. "
                                "If you encounter an issue please file a bug.")

@@ -212,10 +206,10 @@ else:
     def get_pytorch_gptq_config(*args, **kwargs):
         Logger.critical('Installing Pytorch is mandatory '
                         'when using pytorch_gradient_post_training_quantization_experimental. '
-                        'Could not find torch package.')
+                        'Could not find torch package.')  # pragma: no cover


     def pytorch_gradient_post_training_quantization_experimental(*args, **kwargs):
         Logger.critical('Installing Pytorch is mandatory '
                         'when using pytorch_gradient_post_training_quantization_experimental. '
-                        'Could not find the torch package.')
+                        'Could not find the torch package.')  # pragma: no cover
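Taken together, the facade hunks move the public GPTQ entry points under the `mct.gptq` namespace. A minimal end-to-end call assembled from the docstring examples above; `module` and `repr_datagen` stand for a user-supplied `torch.nn.Module` and a representative data generator:

```python
import model_compression_toolkit as mct

gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
quantized_module, quantization_info = \
    mct.gptq.pytorch_gradient_post_training_quantization_experimental(
        module, repr_datagen, gptq_config=gptq_conf)
```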