mct-nightly 1.7.1.31122022.post351-py3-none-any.whl → 1.8.0.1042023.post423-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py
DELETED
@@ -1,157 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from typing import Union, List
-from abc import abstractmethod
-import torch
-import numpy as np
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
-from model_compression_toolkit.core.common.logger import Logger
-from model_compression_toolkit.gptq.pytorch.quantizer.gptq_quantizer import BaseWeightQuantizer
-from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
-from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import sample_gumbel
-from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
-from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationMethod
-from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import ste_clip
-
-P_INIT = 0.01
-GR_SHIFT_BASE = 2
-
-
-def init_aux_var(ceil_indicator: torch.Tensor, w_shape: torch.Size, m: int, p: float = P_INIT) -> torch.Tensor:
-    """
-    This function generate a random pi matrix for Gumbel Rounding
-    Args:
-        ceil_indicator: An array of indicator if the value should be ceil or floor.
-        w_shape(torch.Size): A list of integers that represent the shape of the weights tensor to be quantization
-        m(int): An integer that define the number of shift
-        p(float): A floating point number that represent the probability of non-round options of pi matrix.
-
-    Returns: A torch tensor of pi tensor
-
-    """
-
-    if m < 2:
-        Logger.error("m must be larger than two")
-    if m % 2 != 0:
-        Logger.error("m must be module two")
-    m_hat = m // 2 - 1
-    shift = -np.log(-np.log(1 - p))
-    n = np.random.randn(*[m, *w_shape]) * np.sqrt(np.power(np.pi, 2) / 6)
-    n = n.reshape([m, -1]).T
-    ceil_indicator = ceil_indicator.cpu().numpy().flatten()
-    n[np.arange(ceil_indicator.size), ceil_indicator + m_hat] += shift
-    n = n.T.reshape(*[m, *w_shape])
-    return torch.from_numpy(n).float()
-
-
-def init_shift_var(m: int) -> torch.Tensor:
-    """
-    This function generate a tensor of 2*m+1 from -m to m
-    Args:
-        m: An integer value the represent m
-
-    Returns: A tensor of size m
-
-    """
-    m_hat = m // 2
-    aux_index_shift = [-m_hat + i + 1 for i in range(m)]
-    return torch.Tensor(aux_index_shift)
-
-
-class BaseGumbelWeightQuantizer(BaseWeightQuantizer):
-    """
-    Base class that implements a quantizer with trainable parameters to be used for GPTQ training.
-    """
-
-    def __init__(self,
-                 weights_quantization_cfg: NodeWeightsQuantizationConfig,
-                 gptq_config: GradientPTQConfigV2,
-                 weight_shape: torch.Size):
-        """
-        Construct a Pytorch model that utilize a fake weight quantizer of Gumbel rounding
-        Args:
-            weights_quantization_cfg: Configuration of weight quantization
-            gptq_config: GradientPTQConfigV2 object with parameters about the tuning process.
-            weight_shape: weight shape for auxiliary tensor creation.
-        """
-        super().__init__()
-        self.power_of_two = QuantizationMethod.POWER_OF_TWO == weights_quantization_cfg.weights_quantization_method
-        self.reshape_aux_shift = [-1, *[1 for _ in range(len(weight_shape))]]
-        self.num_bits = weights_quantization_cfg.weights_n_bits
-        self.weight_shape = weight_shape
-        self.max_delta_change = gptq_config.lsb_change_per_bit_width.get(self.num_bits)
-        self.quantization_parameter_learning = gptq_config.quantization_parameters_learning
-        self.m = GR_SHIFT_BASE * self.max_delta_change + GR_SHIFT_BASE
-        self.minimal_temp = gptq_config.quantizer_config.minimal_temp
-        self.maximal_temp = gptq_config.quantizer_config.maximal_temp
-        self.temperature_learning = gptq_config.quantizer_config.temperature_learning
-        self.cycle_iterations = max(1, int(gptq_config.n_epochs / gptq_config.quantizer_config.n_cycles))
-        self.shift_tensor = to_torch_tensor(init_shift_var(self.m))
-        self.tau = None
-        self.g_t = 0
-        self.p_t = None
-        self.n_iter = 0
-        self.update_gumbel_param = True
-        scale = self.cycle_iterations / (-2 * np.log(0.001))
-
-        self.gumbel_scale = gptq_config.quantizer_config.gumbel_scale
-        self.gumbel_scale_per_bitwidth = gptq_config.quantizer_config.gumbel_scale_per_bitwidth
-
-        def tau_function(i: int) -> float:
-            """
-            A function that generates the gumbel temperature.
-            Args:
-                i: An int that represents the current iteration number
-
-            Returns: A temperature value.
-
-            """
-            if i < (self.cycle_iterations - 1):
-                index = ((i + 1) % self.cycle_iterations) / scale
-            else:
-                index = (i % self.cycle_iterations) / scale
-
-            x = np.exp(-index)
-            return self.minimal_temp + (self.maximal_temp - self.minimal_temp) * x
-
-        self.tau_function = tau_function
-
-    def get_gumbel_probability(self) -> torch.Tensor:
-        """
-        A function that return the gumbel probability value.
-        Returns: gumbel probability
-        """
-        return self.p_t
-
-    def update_iteration(self, training):
-        """
-        A function that update parameters for GPTQ fine-tuning
-        Args:
-            training: whether in training mode or not
-        """
-        if self.temperature_learning:
-            self.tau = ste_clip(self.temp_tensor, self.minimal_temp, self.maximal_temp)
-        else:
-            self.tau = self.tau_function(self.n_iter)
-        if self.update_gumbel_param and training:
-            self.n_iter += 1
-            self.g_t = sample_gumbel([self.m, *self.weight_shape])
-
-    @abstractmethod
-    def get_temperature_variable(self) -> Union[torch.Tensor, List]:
-        """
-        Returns temperature trainable variables
-        """
-        raise Logger.error(f"{self.__class__.__name__} have to implement this abstract function.")
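The deleted base class anneals the Gumbel temperature exponentially from maximal_temp down to minimal_temp within each training cycle (the tau_function closure above). A minimal standalone sketch of that schedule, with illustrative constants in place of the quantizer-config values:

import numpy as np

minimal_temp, maximal_temp = 0.5, 1.0  # illustrative bounds, not MCT defaults
cycle_iterations = 1000
scale = cycle_iterations / (-2 * np.log(0.001))  # same normalization as the deleted __init__

def tau(i: int) -> float:
    # Exponential decay within a cycle: starts near maximal_temp, ends near minimal_temp.
    index = ((i + 1) % cycle_iterations) / scale if i < cycle_iterations - 1 \
        else (i % cycle_iterations) / scale
    return minimal_temp + (maximal_temp - minimal_temp) * np.exp(-index)

print(tau(0), tau(cycle_iterations - 1))  # ~0.99 and ~0.50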
model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py
DELETED
@@ -1,150 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import torch
-import torch.nn as nn
-from typing import List, Union
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
-from model_compression_toolkit.gptq.pytorch.quantizer.gumbel_rounding.base_gumbel_weights_quantizer import BaseGumbelWeightQuantizer, init_aux_var
-from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
-from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import symmetric_quantizer
-from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import ste_clip, ste_gumbel, gumbel_softmax, power_of_two_max
-from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, THRESHOLD_TENSOR, TEMP, SCALE_TENSOR
-from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
-from model_compression_toolkit.core.common.constants import THRESHOLD
-
-
-class SymmetricGumbelWeightQuantizer(BaseGumbelWeightQuantizer):
-    """
-    Class that implements a quantizer with trainable parameters to be used for GPTQ training.
-    """
-
-    def __init__(self,
-                 weights_quantization_cfg: NodeWeightsQuantizationConfig,
-                 gptq_config: GradientPTQConfig,
-                 weight: torch.Tensor):
-        """
-        Construct a Pytorch model that utilize a fake weight quantizer of Symmetric Gumbel rounding
-        Args:
-            weights_quantization_cfg: Configuration of weight quantization
-            gptq_config: GradientPTQConfig object with parameters about the tuning process.
-            weight: weight for auxiliary tensor creation.
-        """
-        super().__init__(weights_quantization_cfg, gptq_config, weight.shape)
-        self.signed = True
-        self.min_int = -int(self.signed) * (2 ** (self.num_bits - int(self.signed)))
-        self.max_int = (2 ** (self.num_bits - int(self.signed))) - 1
-        self.threshold_tensor = to_torch_tensor(weights_quantization_cfg.weights_quantization_params.get(THRESHOLD))
-        self.scale_tensor = torch.ones(self.weight_shape)
-
-        # Set trainable tensors
-        self.set_trainable_params(weight)
-
-
-    def set_trainable_params(self, weight: torch.nn.Parameter):
-        """
-        A function to set a list of trainable parameters of the quantizer for GPTQ retraining
-        Args:
-            weight: weight for auxiliary tensor creation.
-        """
-        q_error = weight - symmetric_quantizer(weight,
-                                               self.threshold_tensor,
-                                               num_bits=self.num_bits,
-                                               signed=True,
-                                               power_of_two=self.power_of_two)
-        ceil_indicator = (q_error < 0).int()  # Negative error means the choosen point is rounded to ceil.
-        self.aux_tensor = nn.Parameter(to_torch_tensor(init_aux_var(ceil_indicator, self.weight_shape, self.m)), requires_grad=True)
-        self.trainable_params.update({AUXVAR: self.aux_tensor})
-        self.temp_tensor = nn.Parameter(to_torch_tensor(self.maximal_temp*torch.ones([1,*self.weight_shape])), requires_grad=True)
-        self.trainable_params.update({TEMP: self.temp_tensor})
-        if self.quantization_parameter_learning and not self.power_of_two:
-            self.scale_tensor = nn.Parameter(to_torch_tensor(self.scale_tensor), requires_grad=True)
-            self.trainable_params.update({SCALE_TENSOR: self.scale_tensor})
-        elif self.quantization_parameter_learning:
-            self.threshold_tensor = nn.Parameter(self.threshold_tensor, requires_grad=True)
-            self.trainable_params.update({THRESHOLD_TENSOR: self.threshold_tensor})
-        else:
-            self.trainable_params.update({THRESHOLD_TENSOR: self.threshold_tensor})
-
-    def get_aux_variable(self) -> torch.Tensor:
-        """
-        Returns auxiliary trainable variables
-        """
-        return self.trainable_params.get(AUXVAR)
-
-    def get_quantization_variable(self) -> Union[torch.Tensor, List]:
-        """
-        Returns quantization trainable variables
-        """
-        if self.quantization_parameter_learning and not self.power_of_two:
-            return [self.trainable_params.get(SCALE_TENSOR)]
-        else:
-            return [self.trainable_params.get(THRESHOLD_TENSOR)]
-
-    def get_temperature_variable(self) -> Union[torch.Tensor, List]:
-        """
-        Returns temperature trainable variables
-        """
-        return self.trainable_params.get(TEMP)
-
-    def get_weight_quant_params(self) -> dict:
-        """
-        Returns weight quantization dictionary params
-        """
-        threshold_tensor = self.threshold_tensor
-        if self.power_of_two:
-            threshold_tensor = power_of_two_max(threshold_tensor)
-        elif self.quantization_parameter_learning:
-            threshold_tensor = threshold_tensor*self.scale_tensor
-        return {THRESHOLD: torch_tensor_to_numpy(threshold_tensor.detach())}
-
-    def forward(self, w: nn.Parameter, training:bool = True) -> nn.Parameter:
-        """
-        Weight fake quantizer
-        Args:
-            w: weights to quantize.
-            training: whether in training mode or not
-        Returns:
-            quantized weights
-        """
-        self.update_iteration(training)
-
-        #####################################################
-        # Gumbel Softmax
-        #####################################################
-        if training:
-            gumbel_scale = self.gumbel_scale if self.gumbel_scale_per_bitwidth is None \
-                else self.gumbel_scale_per_bitwidth.get(self.num_bits, self.gumbel_scale)
-            self.p_t = gumbel_softmax(self.aux_tensor, self.tau, self.g_t, gumbel_scale=gumbel_scale)
-        else:
-            self.p_t = ste_gumbel(gumbel_softmax(self.aux_tensor, self.minimal_temp, 0))
-
-        auxhat_tensor = torch.sum(self.p_t * self.shift_tensor.reshape(self.reshape_aux_shift), dim=0)
-
-        #####################################################
-        # Quantizer
-        #####################################################
-        threshold_tensor = self.threshold_tensor
-        if self.power_of_two:
-            threshold_tensor = power_of_two_max(threshold_tensor)
-        delta_tensor = threshold_tensor / (2 ** (self.num_bits-int(self.signed)))
-        w0 = torch.floor(w / delta_tensor).detach()
-        w1 = w0 + auxhat_tensor
-        w2 = ste_clip(w1, min_val=self.min_int, max_val=self.max_int)
-        w_q = delta_tensor * w2
-        # Scale
-        if self.quantization_parameter_learning and not self.power_of_two:
-            w_q *= self.scale_tensor
-        return w_q
-
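Stripped of the Gumbel-Softmax machinery, the forward pass above is symmetric fake quantization with a learned additive rounding offset. A sketch of just that arithmetic in plain torch, where aux_hat stands in for the offset the quantizer learns (all names and constants here are illustrative):

import torch

num_bits, signed = 8, True
threshold = torch.tensor(1.0)                        # per-tensor threshold
delta = threshold / (2 ** (num_bits - int(signed)))  # quantization step size
min_int = -int(signed) * (2 ** (num_bits - int(signed)))
max_int = (2 ** (num_bits - int(signed))) - 1

w = torch.randn(4, 3)
aux_hat = torch.zeros_like(w)  # zero offset reproduces plain floor; training learns it per weight

# floor to the grid, add the learned offset, clip to the integer range, rescale
w_q = delta * torch.clamp(torch.floor(w / delta) + aux_hat, min_int, max_int)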
model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py
DELETED
@@ -1,143 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import numpy as np
-import torch
-import torch.nn as nn
-from typing import List, Union, Dict
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
-from model_compression_toolkit.gptq.pytorch.quantizer.gumbel_rounding.base_gumbel_weights_quantizer import BaseGumbelWeightQuantizer, init_aux_var
-from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
-from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import uniform_quantizer, fix_range_to_include_zero
-from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import ste_clip, ste_gumbel, gumbel_softmax
-from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_MAX_RANGE, PTQ_MIN_RANGE, TEMP
-from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
-from model_compression_toolkit.core.common.constants import RANGE_MAX, RANGE_MIN
-
-class UniformGumbelWeightQuantizer(BaseGumbelWeightQuantizer):
-    """
-    Class that implements a quantizer with trainable parameters to be used for GPTQ training.
-    """
-
-    def __init__(self,
-                 weights_quantization_cfg: NodeWeightsQuantizationConfig,
-                 gptq_config: GradientPTQConfig,
-                 weight: torch.nn.Parameter):
-        """
-        Construct a Pytorch model that utilize a fake weight quantizer of Uniform Gumbel rounding
-        Args:
-            weights_quantization_cfg: Configuration of weight quantization
-            gptq_config: GradientPTQConfig object with parameters about the tuning process.
-            weight: weight for auxiliary tensor creation.
-        """
-        super().__init__(weights_quantization_cfg, gptq_config, weight.shape)
-        self.min_int = 0
-        self.max_int = 2**self.num_bits - 1
-        self.max_range_tensor = weights_quantization_cfg.weights_quantization_params.get(RANGE_MAX)
-        self.min_range_tensor = weights_quantization_cfg.weights_quantization_params.get(RANGE_MIN)
-
-        # Set trainable tensors
-        self.set_trainable_params(weight)
-
-
-    def set_trainable_params(self, weight: torch.nn.Parameter):
-        """
-        A function to set a list of trainable parameters of the quantizer for GPTQ retraining
-        """
-        self.temp_tensor = nn.Parameter(to_torch_tensor(self.maximal_temp*torch.ones([1,*self.weight_shape])), requires_grad=True)
-        self.trainable_params.update({TEMP: self.temp_tensor})
-        self.max_range_tensor = nn.Parameter(to_torch_tensor(self.max_range_tensor), requires_grad=self.quantization_parameter_learning)
-        self.trainable_params.update({PTQ_MAX_RANGE: self.max_range_tensor})
-        self.min_range_tensor = nn.Parameter(to_torch_tensor(self.min_range_tensor), requires_grad=self.quantization_parameter_learning)
-        self.trainable_params.update({PTQ_MIN_RANGE: self.min_range_tensor})
-        q_error = weight - uniform_quantizer(weight,
-                                             self.min_range_tensor,
-                                             self.max_range_tensor,
-                                             n_bits=self.num_bits)
-        ceil_indicator = (q_error < 0).int()  # Negative error means the choosen point is rounded to ceil.
-        self.aux_tensor = nn.Parameter(to_torch_tensor(init_aux_var(ceil_indicator, self.weight_shape, self.m)), requires_grad=True)
-        self.trainable_params.update({AUXVAR: self.aux_tensor})
-
-    def get_aux_variable(self) -> torch.Tensor:
-        """
-        Returns auxiliary trainable variables
-        """
-        return self.trainable_params.get(AUXVAR)
-
-    def get_quantization_variable(self) -> Union[torch.Tensor, List]:
-        """
-        Returns quantization trainable variables
-        """
-        return [self.trainable_params.get(PTQ_MAX_RANGE), self.trainable_params.get(PTQ_MIN_RANGE)]
-
-    def get_temperature_variable(self) -> Union[torch.Tensor, List]:
-        """
-        Returns temperature trainable variables
-        """
-        return self.trainable_params.get(TEMP)
-
-    def get_weight_quant_params(self) -> Dict[str, np.ndarray]:
-        """
-        Returns weight quantization dictionary params
-        """
-        max_range_tensor = self.max_range_tensor
-        min_range_tensor = self.min_range_tensor
-        return {PTQ_MAX_RANGE: torch_tensor_to_numpy(max_range_tensor.detach()),
-                PTQ_MIN_RANGE: torch_tensor_to_numpy(min_range_tensor.detach())}
-
-    def forward(self, w: nn.Parameter, training:bool = True) -> nn.Parameter:
-        """
-        Weight fake quantizer
-        Args:
-            w: weights to quantize.
-            training: whether in training mode or not
-        Returns:
-            quantized weights
-        """
-        self.update_iteration(training)
-
-        #####################################################
-        # Gumbel Softmax
-        #####################################################
-        if training:
-            self.p_t = gumbel_softmax(self.aux_tensor, self.tau, self.g_t)
-        else:
-            self.p_t = ste_gumbel(gumbel_softmax(self.aux_tensor, self.minimal_temp, 0))
-
-        auxhat_tensor = torch.sum(self.p_t * self.shift_tensor.reshape(self.reshape_aux_shift), dim=0)
-
-        #####################################################
-        # Quantizer
-        #####################################################
-        max_range_tensor = self.max_range_tensor
-        min_range_tensor = self.min_range_tensor
-
-        # adjusts the quantization rage so the quantization grid include zero.
-        a, b = fix_range_to_include_zero(min_range_tensor, max_range_tensor, self.num_bits)
-
-        # Compute the step size of quantized values.
-        delta_tensor = (b - a) / (2 ** self.num_bits - 1)
-
-        # Apply rounding
-        w0 = torch.floor((w - a) / delta_tensor).detach()  # Apply rounding
-
-        w1 = w0 + auxhat_tensor
-
-        # Clip data in range
-        w2 = ste_clip(w1, min_val=self.min_int, max_val=self.max_int)
-
-        # Quantize the data between min/max of quantization range.
-        w_q = delta_tensor * w2 + a
-        return w_q
-
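The uniform variant above differs mainly in carrying an explicit [min, max] range, which must first be adjusted so that zero falls exactly on the quantization grid. A sketch of that arithmetic with a simple zero-snapping step standing in for fix_range_to_include_zero, and plain rounding in place of the Gumbel offset (illustrative, not the MCT helpers):

import torch

def fake_quant_uniform(w: torch.Tensor, a: float, b: float, n_bits: int) -> torch.Tensor:
    delta = (b - a) / (2 ** n_bits - 1)  # step size across the full range
    a = delta * round(a / delta)         # snap the range start to the grid so 0 is representable
    q = torch.clamp(torch.round((w - a) / delta), 0, 2 ** n_bits - 1)
    return delta * q + a

print(fake_quant_uniform(torch.randn(8), a=-1.2, b=0.7, n_bits=8))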
model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py
DELETED
@@ -1,103 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import torch
-import torch.nn as nn
-from model_compression_toolkit.core.common import BaseNode, Logger
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType
-from model_compression_toolkit.gptq.pytorch.quantizer.gptq_quantizer import BaseWeightQuantizer
-from model_compression_toolkit.gptq.pytorch.quantizer.ste_rounding.ste_weights_quantizer import STEWeightQuantizer
-from model_compression_toolkit.gptq.pytorch.quantizer.gumbel_rounding.sym_gumbel_weights_quantizer import SymmetricGumbelWeightQuantizer
-from model_compression_toolkit.gptq.pytorch.quantizer.gumbel_rounding.uniform_gumbel_weights_quantizer import UniformGumbelWeightQuantizer
-from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
-from model_compression_toolkit.core.pytorch.constants import KERNEL
-from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
-from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationMethod
-
-
-class WeightQuantizerWrapper(nn.Module):
-
-    def __init__(self, node: BaseNode, gptq_config: GradientPTQConfig, weight_quantizer: BaseWeightQuantizer):
-        """
-        Construct a Pytorch model that constitutes as a wrapper for a Pytorch layer, built from a given graph node.
-        Args:
-            node: Node to build its Pytorch quantizer wrapper.
-            gptq_config: GradientPTQConfig object with parameters about the tuning process.
-            weight_quantizer: BaseWeightQuantizer object for gradient based weight quantizer
-        """
-        super().__init__()
-
-        # loading operation
-        self.op = node.type(**node.framework_attr)
-
-        # loading the weights from the graph node (weights of the trained model)
-        self.op.load_state_dict({k: torch.Tensor(v) for k, v in node.weights.items()}, strict=False)
-        self.float_weight = to_torch_tensor(getattr(self.op, KERNEL)).detach()
-
-        # replace non-gradient needed nn.Parameter with gradient needed torch.tensor
-        delattr(self.op, KERNEL)
-        setattr(self.op, KERNEL, self.float_weight)
-        setattr(getattr(self.op, KERNEL), 'requires_grad', True)
-
-        # quantizer
-        self.weight_quantizer = weight_quantizer(node.final_weights_quantization_cfg, gptq_config, self.float_weight)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Weight fake quantizer wrapper
-        Args:
-            x: input to layer.
-        Returns:
-            Output of layer after using operation with fake quantized weights
-        """
-        # Run weight quantizer
-        setattr(self.op, KERNEL, self.weight_quantizer(self.float_weight))
-        # Do computation
-        return self.op(x)
-
-
-def quantizer_wrapper(node: BaseNode, gptq_config: GradientPTQConfig) -> nn.Module:
-    """
-    Construct a Pytorch model that constitutes as a wrapper for a Pytorch layer, built from a given graph node.
-    Args:
-        node: Node to build its Pytorch layer.
-        gptq_config: GradientPTQConfig with parameters about the tuning process.
-    """
-    if node.is_weights_quantization_enabled():
-        quantization_method = node.final_weights_quantization_cfg.weights_quantization_method
-        if quantization_method in [QuantizationMethod.SYMMETRIC, QuantizationMethod.POWER_OF_TWO]:
-            # STE quantizer
-            # ---------------
-            if gptq_config.rounding_type == RoundingType.STE:
-                node_instance = WeightQuantizerWrapper(node, gptq_config, STEWeightQuantizer)
-
-            # Symmetric Gumbel rounding quantizer
-            # ------------------------------------
-            elif gptq_config.rounding_type == RoundingType.GumbelRounding:
-                node_instance = WeightQuantizerWrapper(node, gptq_config, SymmetricGumbelWeightQuantizer)
-
-        elif quantization_method == QuantizationMethod.UNIFORM:
-            # Uniform Gumbel rounding quantizer
-            # ------------------------------------
-            if gptq_config.rounding_type == RoundingType.GumbelRounding:
-                node_instance = WeightQuantizerWrapper(node, gptq_config, UniformGumbelWeightQuantizer)
-            else:
-                Logger.error(f"For quantization method {quantization_method}, GPTQ Rounding type {gptq_config.rounding_type} is not supported")
-        else:
-            Logger.error(f"For quantization method {quantization_method}, GPTQ Rounding type {gptq_config.rounding_type} is not supported")
-    else:
-        # No quantization
-        node_instance = node_builder(node)
-
-    return node_instance
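WeightQuantizerWrapper above uses a reusable pattern: detach the layer's kernel into a float tensor the wrapper owns, then re-inject its fake-quantized version before every forward. A self-contained sketch of the same pattern (the rounding lambda is a trivial stand-in, not MCT code):

import torch
import torch.nn as nn

class FakeQuantWrapper(nn.Module):
    def __init__(self, op: nn.Module, quantizer):
        super().__init__()
        self.op = op
        self.quantizer = quantizer
        # Keep a float copy we control and drop the registered nn.Parameter.
        self.float_weight = op.weight.detach().clone().requires_grad_(True)
        del op.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Re-inject the fake-quantized kernel on every call.
        self.op.weight = self.quantizer(self.float_weight)
        return self.op(x)

layer = nn.Linear(4, 2)
wrapped = FakeQuantWrapper(layer, quantizer=lambda w: torch.round(w * 16) / 16)
print(wrapped(torch.randn(1, 4)))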
model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py
DELETED
@@ -1,103 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import torch
-import torch.nn as nn
-from typing import List, Union
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
-from model_compression_toolkit.gptq.pytorch.quantizer.gptq_quantizer import BaseWeightQuantizer
-from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
-from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import ste_round, ste_clip
-from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR
-from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
-from model_compression_toolkit.core.common.constants import THRESHOLD
-
-
-class STEWeightQuantizer(BaseWeightQuantizer):
-    """
-    Class that implements a quantizer with trainable parameters to be used for GPTQ training.
-    """
-
-    def __init__(self,
-                 weights_quantization_cfg: NodeWeightsQuantizationConfig,
-                 gptq_config: GradientPTQConfig,
-                 weight: torch.nn.Parameter):
-        """
-        Construct a Pytorch model that utilize a fake weight quantizer of STE (Straight Through Estimator) for symmetric quantizer.
-        Args:
-            weights_quantization_cfg: Configuration of weight quantization
-            gptq_config: GradientPTQConfig object with parameters about the tuning process.
-            weight: weight for auxiliary tensor creation.
-        """
-        super().__init__()
-
-        self.signed = True
-        self.num_bits = weights_quantization_cfg.weights_n_bits
-        self.min_int = -int(self.signed) * (2 ** (self.num_bits - int(self.signed)))
-        self.max_int = (2 ** (self.num_bits - int(self.signed))) - 1
-        self.weight_shape = weight.shape
-        self.threshold_values = weights_quantization_cfg.weights_quantization_params.get(THRESHOLD)
-        self.delta_tensor = self.threshold_values / (2 ** (self.num_bits-int(self.signed)))
-        self.max_delta_change = gptq_config.lsb_change_per_bit_width.get(self.num_bits)
-
-        # Set trainable tensors
-        self.set_trainable_params()
-
-        # Create tensors
-        self.delta_tensor = to_torch_tensor(self.delta_tensor)
-        self.max_tensor_change = self.delta_tensor * self.max_delta_change
-
-    def set_trainable_params(self):
-        """
-        A function to set a list of trainable parameters of the quantizer for GPTQ retraining
-        """
-        self.aux_tensor = nn.Parameter(to_torch_tensor(torch.zeros(self.weight_shape)), requires_grad=True)
-        self.trainable_params.update({AUXVAR: self.aux_tensor})
-
-    def get_aux_variable(self) -> torch.Tensor:
-        """
-        Returns auxiliary trainable variables
-        """
-        return self.trainable_params.get(AUXVAR)
-
-    def get_quantization_variable(self) -> Union[torch.Tensor, List]:
-        """
-        Returns quantization trainable variables
-        """
-        return []
-
-    def get_weight_quantization_params(self) -> dict:
-        """
-        Returns weight quantization dictionary params
-        """
-        return {THRESHOLD: self.threshold_values}
-
-    def forward(self, w: nn.Parameter, training: bool = True) -> nn.Parameter:
-        """
-        Weight fake quantizer
-        Args:
-            w: weights to quantize.
-            training: whether in training mode or not
-        Returns:
-            quantized weights
-        """
-        v0 = ste_clip(self.aux_tensor, min_val=-self.max_tensor_change, max_val=self.max_tensor_change)
-        v1 = v0 / self.delta_tensor
-        w0 = torch.round(w / self.delta_tensor).detach()
-        w1 = w0 + v1
-        w2 = ste_round(w1)
-        w3 = ste_clip(w2, min_val=self.min_int, max_val=self.max_int)
-        w_q = self.delta_tensor * w3
-        return w_q
-