mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py
@@ -11,4 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ==============================================================================
+# ==============================================================================
+
+import model_compression_toolkit.gptq.pytorch.quantizer.ste_rounding.symmetric_ste
+import model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.symmetric_soft_quantizer
+import model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.uniform_soft_quantizer
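
Note: the three imports above read as registration imports. The lookup helpers added in this release (get_all_subclasses.py / get_quantizers.py in the file list) can only discover a quantizer class once its defining module has been imported. A minimal sketch of that discovery pattern, assuming subclass scanning is the mechanism (not MCT's verbatim code):

def get_all_subclasses(cls: type) -> list:
    # A class shows up in __subclasses__() only after its module is imported,
    # which is what the package __init__ imports above guarantee.
    direct = cls.__subclasses__()
    return direct + [sub for c in direct for sub in get_all_subclasses(c)]

class BaseQuantizer: ...
class SymmetricSTE(BaseQuantizer): ...  # stands in for a quantizer module's class

assert SymmetricSTE in get_all_subclasses(BaseQuantizer)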
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py
@@ -0,0 +1,92 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from abc import abstractmethod
+from typing import Union, Dict, List
+
+from model_compression_toolkit.core.common.logger import Logger
+from model_compression_toolkit.core.common.constants import FOUND_TORCH
+from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
+
+from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
+    TrainableQuantizerActivationConfig
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import \
+    BaseTrainableQuantizer
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
+    BasePytorchTrainableQuantizer
+
+if FOUND_TORCH:
+    from torch import Tensor
+    from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+
+    class BasePytorchGPTQTrainableQuantizer(BasePytorchTrainableQuantizer):
+        """
+        A base class for trainable Pytorch quantizer for GPTQ.
+        """
+
+        def __init__(self,
+                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
+            """
+            Initializes BasePytorchGPTQTrainableQuantizer object.
+
+            Args:
+                quantization_config: quantizer config class contains all the information about a quantizer configuration.
+            """
+
+            super().__init__(quantization_config)
+
+        def update_layer_quantization_params(self, layer: PytorchQuantizationWrapper
+                                             ) -> (Dict[str, Tensor], Dict[str, Dict], Dict):
+            """
+            A Function to calculate the needed change in attributes in NodeQuantizationConfig after retraining.
+
+            Args:
+                layer: A wrapped Pytorch layer.
+
+            Returns:
+                3 dictionaries describing the change in layer's weights, weights config, activation config
+                that changed during GPTQ retraining.
+                Keys must match NodeQuantizationConfig attributes
+
+            """
+            weights = {}
+            for weight, quantizer_vars, quantizer in layer.get_weights_vars():
+                if not isinstance(quantizer, BaseTrainableQuantizer):
+                    Logger.error(f"Expecting a GPTQ trainable quantizer, "  # pragma: no cover
+                                 f"but got {type(quantizer)} which is not callable.")
+                weights.update({weight: quantizer(training=False, inputs=quantizer_vars)})
+
+            quant_config = {WEIGHTS_QUANTIZATION_PARAMS: self.get_quant_config()}
+
+            return weights, quant_config, {}
+
+        @abstractmethod
+        def get_quant_config(self):
+            """
+            Returns the config used to edit NodeQuantizationConfig after GPTQ retraining.
+
+            Returns:
+                A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
+                Keys must match NodeQuantizationConfig attributes.
+
+            """
+            raise NotImplemented(f'{self.__class__.__name__} have to implement the '  # pragma: no cover
+                                 f'quantizer\'s get_quant_config.')
+
+else:
+    class BasePytorchGPTQTrainableQuantizer:  # pragma: no cover
+        def __init__(self, *args, **kwargs):
+            Logger.critical('Installing Pytorch is mandatory '
+                            'when using BasePytorchGPTQTrainableQuantizer. '
+                            'Could not find torch package.')  # pragma: no cover
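
For orientation, a hedged sketch of how a concrete GPTQ quantizer plugs into the base class above; the attribute name `threshold_tensor` and the `'threshold'` key are invented for illustration, and the shipped subclasses are symmetric_ste.py and the soft_rounding quantizers listed in the file table:

# Illustrative only: a concrete quantizer overrides get_quant_config() to report
# the parameters it learned during GPTQ; keys must match NodeQuantizationConfig.
class MyGPTQWeightQuantizer(BasePytorchGPTQTrainableQuantizer):
    def get_quant_config(self):
        # `threshold_tensor` is a hypothetical trainable variable of this quantizer.
        return {'threshold': self.threshold_tensor.detach().cpu().numpy()}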
|
@@ -30,11 +30,20 @@ def calculate_delta(max_tensor: torch.Tensor,
|
|
|
30
30
|
num_bits: int,
|
|
31
31
|
signed: bool) -> torch.Tensor:
|
|
32
32
|
"""
|
|
33
|
-
Compute the step size for the quantization.
|
|
33
|
+
Compute the step size for the symmetric quantization.
|
|
34
34
|
"""
|
|
35
35
|
return max_tensor / (2 ** (num_bits - int(signed)))
|
|
36
36
|
|
|
37
37
|
|
|
38
|
+
def calculate_delta_uniform(min_tensor: torch.Tensor,
|
|
39
|
+
max_tensor: torch.Tensor,
|
|
40
|
+
num_bits: int) -> torch.Tensor:
|
|
41
|
+
"""
|
|
42
|
+
Compute the step size for the uniform quantization.
|
|
43
|
+
"""
|
|
44
|
+
return (max_tensor-min_tensor) / (2 ** num_bits - 1)
|
|
45
|
+
|
|
46
|
+
|
|
38
47
|
def ste_ceil(x: torch.Tensor) -> torch.Tensor:
|
|
39
48
|
"""
|
|
40
49
|
Return the ceil values of a tensor.
|
|
@@ -66,93 +75,6 @@ def ste_clip(x: torch.Tensor, min_val=-1.0, max_val=1.0) -> torch.Tensor:
     return (torch.clip(x, min=min_val, max=max_val) - x).detach() + x
 
 
-def gumbel_softmax(x: torch.Tensor, tau: Union[torch.Tensor,float], gumbel_tensor: Union[torch.Tensor,float], eps: float = 1e-6, axis=0,
-                   gumbel_scale: float = 1.0) -> torch.Tensor:
-    """
-    A gumbel softmax function.
-    Args:
-        x: A tensor of log probability.
-        tau: A temperature tensor.
-        gumbel_tensor: A tensor of gumbel random variable.
-        eps: A small number for numeric stability.
-        axis: A integer representing the axis of which the gumbel softmax applyed on.
-        gumbel_scale: A normalization factor for the gumbel tensor values
-
-    Returns: A gumbel softmax probability tensor.
-
-    """
-    return softmax((log_softmax(x, dim=axis) + gumbel_tensor * gumbel_scale) / (tau + eps), dim=axis)
-
-
-def select_gumbel(prob: torch.Tensor) -> torch.Tensor:
-    """
-    This function apply ste on the output of the gumbel softmax.
-    Args:
-        prob: A tensor of probability.
-
-    Returns: A Tensor of ohe hot vector
-
-    """
-    max_index = torch.argmax(prob, dim=0)
-    axis_list = [i for i in range(len(max_index.shape))]
-    axis_list.insert(0, len(max_index.shape))
-    one_hot_prob = torch.permute(one_hot(max_index, num_classes=prob.shape[0]), axis_list)
-    return one_hot_prob + 0*prob
-
-
-def ste_gumbel(prob: torch.Tensor) -> torch.Tensor:
-    """
-    This function apply ste on the output of the gumbel softmax.
-    Args:
-        prob:A tensor of probability
-
-    Returns: A Tensor of ohe hot vector with STE.
-
-    """
-    delta = (select_gumbel(prob) - prob).detach()
-    return prob + delta
-
-
-def sample_gumbel(shape, eps=1e-6) -> torch.Tensor:
-    """
-    A function that sample a tensor of i.i.d gumbel random variable.
-    Args:
-        shape: The tensor output shape
-        eps: A small number for numeric stability.
-
-    Returns: A tensor of i.i.d gumbel random variable.
-
-    """
-    u = to_torch_tensor(torch.rand(shape))
-    return -torch.log(-torch.log(u + eps) + eps)
-
-
-def symmetric_quantizer(input_tensor: torch.Tensor,
-                        max_tensor: torch.Tensor,
-                        num_bits: int,
-                        signed: bool,
-                        power_of_two: bool = False) -> torch.Tensor:
-    """
-    Quantize a tensor symmetrically.
-    Args:
-        input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
-        max_tensor: Tensor with max values to compute the threshold.
-        num_bits: Num of bits to use.
-        signed: Signedness of the quantization range.
-        power_of_two: Whether the threshold should be constrained or not.
-    Returns:
-        A quantized tensor.
-    """
-
-    if power_of_two:
-        max_tensor = power_of_two_max(max_tensor)
-    delta_tensor = calculate_delta(max_tensor, num_bits, signed)
-    tensor_q = ste_round(input_tensor / delta_tensor)
-    min_int = -int(signed) * (2 ** (num_bits - int(signed)))
-    max_int = (2 ** (num_bits - int(signed))) - 1
-    return delta_tensor * ste_clip(tensor_q, min_val=min_int, max_val=max_int)
-
-
 def fix_range_to_include_zero(range_min: torch.Tensor,
                               range_max: torch.Tensor,
                               n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -180,34 +102,3 @@ def fix_range_to_include_zero(range_min: torch.Tensor,
     min_range_adj = min_range_adj * mid_range + max_negative * range_min
     max_range_adj = max_range_adj * mid_range + min_positive * range_max
     return min_range_adj, max_range_adj
-
-
-def uniform_quantizer(tensor_data: torch.Tensor,
-                      range_min: torch.Tensor,
-                      range_max: torch.Tensor,
-                      n_bits: int) -> torch.Tensor:
-    """
-    Quantize a tensor according to given range (min, max) and number of bits.
-    Args:
-        tensor_data: Tensor values to quantize.
-        range_min: minimum bound of the range for quantization (or array of min values per channel).
-        range_max: maximum bound of the range for quantization (or array of max values per channel).
-        n_bits: Number of bits to quantize the tensor.
-    Returns:
-        Quantized data.
-    """
-    # adjusts the quantization rage so the quantization grid include zero.
-    a, b = fix_range_to_include_zero(range_min, range_max, n_bits)
-
-    # Compute the step size of quantized values.
-    delta_tensor = (b - a) / (2 ** n_bits - 1)
-
-    # Apply rounding
-    input_tensor_int = ste_round((tensor_data - a) / delta_tensor)
-
-    # Clip data in range
-    clipped_tensor = ste_clip(input_tensor_int, min_val=0, max_val=2 ** n_bits - 1)
-
-    # Quantize the data between min/max of quantization range.
-    q = delta_tensor * clipped_tensor + a
-    return q
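
As a quick sanity check on the two step-size helpers kept in this file, here is a self-contained restatement (mirroring quant_utils.py rather than importing it):

import torch

# Restatement of the helpers above for a standalone check.
def calculate_delta(max_tensor, num_bits, signed):
    return max_tensor / (2 ** (num_bits - int(signed)))

def calculate_delta_uniform(min_tensor, max_tensor, num_bits):
    return (max_tensor - min_tensor) / (2 ** num_bits - 1)

# 8-bit signed symmetric: grid step 1/128 = 0.0078125 for threshold 1.0.
print(calculate_delta(torch.tensor(1.0), num_bits=8, signed=True))
# 8-bit uniform over [-1, 1]: step 2/255 ≈ 0.00784.
print(calculate_delta_uniform(torch.tensor(-1.0), torch.tensor(1.0), 8))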
model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py
@@ -0,0 +1,75 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import List, Dict, Tuple
+
+from model_compression_toolkit.gptq import GradientPTQConfigV2
+from model_compression_toolkit.core import common
+from model_compression_toolkit.core.pytorch.constants import KERNEL
+from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \
+    get_activation_inferable_quantizer_kwargs
+from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
+    BasePytorchGPTQTrainableQuantizer
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
+    get_inferable_quantizer_class
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
+    BasePyTorchInferableQuantizer
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
+    get_trainable_quantizer_weights_config
+from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
+from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
+    get_trainable_quantizer_class
+
+
+def quantization_builder(n: common.BaseNode,
+                         gptq_config: GradientPTQConfigV2,
+                         ) -> Tuple[Dict[str, BasePytorchQATTrainableQuantizer],
+                                    List[BasePyTorchInferableQuantizer]]:
+    """
+    Build quantizers for a node according to its quantization configuration and
+    a global NoOpQuantizeConfig object.
+
+    Args:
+        n: Node to build its QuantizeConfig.
+        gptq_config (GradientPTQConfigV2): GradientPTQConfigV2 configuration.
+
+    Returns:
+        A dictionary which maps the weights kernel attribute to a quantizer for GPTQ training.
+        Note that we return a dictionary although there is only a single attribute that is being mapped to a quantizer,
+        to be compatible with the quantization infrastructure template.
+    """
+
+    weights_quantizers = {}
+    if n.is_weights_quantization_enabled():
+        quant_method = n.final_weights_quantization_cfg.weights_quantization_method
+        quantizer_class = get_trainable_quantizer_class(quant_target=QuantizationTarget.Weights,
+                                                        quantizer_type=gptq_config.rounding_type,
+                                                        quant_method=quant_method,
+                                                        quantizer_base_class=BasePytorchGPTQTrainableQuantizer)
+        weights_quantizers.update({KERNEL: quantizer_class(get_trainable_quantizer_weights_config(n),
+                                                           **gptq_config.gptq_quantizer_params_override)})
+    activation_quantizers = []
+    if n.is_activation_quantization_enabled():
+        quant_method = n.final_activation_quantization_cfg.activation_quantization_method
+
+        quantizer_class = get_inferable_quantizer_class(quant_target=QuantizationTarget.Activation,
+                                                        quant_method=quant_method,
+                                                        quantizer_base_class=BasePyTorchInferableQuantizer)
+
+        kwargs = get_activation_inferable_quantizer_kwargs(n)
+
+        activation_quantizers.append(quantizer_class(**kwargs))
+
+    return weights_quantizers, activation_quantizers
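
A hedged usage sketch of the builder above; `node` and `gptq_config_v2` are assumed to come from MCT's GPTQ flow (the real call site is gptq/pytorch/gptq_training.py, changed above):

# Assumed call pattern, not verbatim MCT code.
weights_quantizers, activation_quantizers = quantization_builder(node, gptq_config_v2)
# weights_quantizers: {KERNEL: <trainable GPTQ quantizer>} - a single-entry dict,
# keyed by the kernel attribute, kept as a dict for infrastructure compatibility.
# activation_quantizers: a list of inferable (non-trainable) activation quantizers.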
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py
@@ -0,0 +1,45 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Callable
+
+from model_compression_toolkit.gptq import RoundingType, GradientPTQConfigV2, GradientPTQConfig
+from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import \
+    SoftQuantizerRegularization
+
+
+def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen: Callable) -> Callable:
+    """
+    Returns a function that computes the regularization term for GPTQ training based on the given
+    rounding type in the GPTQ configuration.
+
+    Args:
+        gptq_config: A GPTQ configuration.
+        representative_data_gen: Dataset used for the GPTQ training.
+
+    Returns: A function for computing the regularization. If there is no regularization function defined for the given
+        rounding type, then it returns a function that just returns 0.
+
+    """
+    if gptq_config.rounding_type == RoundingType.SoftQuantizer:
+        # dry run on the representative dataset to count number of batches
+        num_batches = 0
+        for _ in representative_data_gen():
+            num_batches += 1
+
+        n_epochs = GradientPTQConfigV2.from_v1(n_ptq_iter=num_batches, config_v1=gptq_config).n_epochs if \
+            not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
+        return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
+    else:
+        return lambda m, e_reg: 0
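
A hedged sketch of the assumed call pattern inside a GPTQ training step (the actual wiring lives in gptq/pytorch/gptq_training.py; `compute_task_loss` and the 0.01 weight are illustrative placeholders):

# Assumed wiring, not verbatim MCT code; `model` is the quantizer-wrapped network.
reg_fn = get_regularization(gptq_config, representative_data_gen)
for batch in representative_data_gen():
    task_loss = compute_task_loss(batch)    # hypothetical GPTQ loss helper
    loss = task_loss + reg_fn(model, 0.01)  # entropy_reg scales the soft-rounding term
    loss.backward()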
model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py
@@ -0,0 +1,115 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import List
+
+import torch
+import numpy as np
+from torch import nn
+
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
+from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+
+
+class LinearTempDecay:
+    """
+    Annealing process for the soft quantizer regularization temperature term.
+    """
+
+    def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
+        """
+        Initializes a LinearTempDecay object.
+
+        Args:
+            t_max: maximal time step.
+            rel_start_decay: Decay step size at the beginning of the process.
+            start_b: Starting value of the regularization term.
+            end_b: Target value of the regularization term.
+        """
+
+        self.t_max = t_max
+        self.start_decay = rel_start_decay * t_max
+        self.start_b = start_b
+        self.end_b = end_b
+
+    def __call__(self, t: float) -> float:
+        """
+        Cosine annealing scheduler for soft quantizer regularization temperature term.
+
+        Args:
+            t: The current time step.
+
+        Returns: Scheduled temperature.
+        """
+
+        is_before_start_decay = (t < self.start_decay)
+
+        rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
+
+        return self.start_b * is_before_start_decay + \
+               (1 - is_before_start_decay) * \
+               (self.end_b + (self.start_b - self.end_b) * torch.maximum(to_torch_tensor(np.array([0.0])),
+                                                                         to_torch_tensor(np.array((1 - rel_t)))))
+
+
+class SoftQuantizerRegularization:
+    """
+    A class to handle the computation of soft quantizer regularization for GPTQ training.
+    """
+
+    def __init__(self, total_gradient_steps: int):
+        """
+        Initializes the regularization computation object with a LinearDecay object.
+
+        Args:
+            total_gradient_steps: The number of gradient steps during optimization.
+        """
+
+        # Initializing the temperature decay according to the number of expected gradient steps
+        self.linear_decay = LinearTempDecay(total_gradient_steps)
+
+        self.count_iter = 0
+
+    def __call__(self, model: nn.Module, entropy_reg: float):
+        """
+        Returns the soft quantizer regularization value for SoftRounding.
+
+        Args:
+            model: A model to be quantized with SoftRounding.
+            entropy_reg: Entropy value to scale the quantizer regularization.
+
+        Returns: Regularization value.
+        """
+
+        soft_reg_aux: List[torch.Tensor] = []
+        for layer in model.modules():
+            if isinstance(layer, PytorchQuantizationWrapper):
+                kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                      fw_info=DEFAULT_PYTORCH_INFO)
+
+                st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
+                b = self.linear_decay(self.count_iter)
+
+                soft_reg_aux.append((1 - torch.pow(torch.abs(st - .5) * 2, b)).sum())
+
+        reg = 0
+
+        for sq in soft_reg_aux:
+            reg += sq
+
+        self.count_iter += 1
+
+        return entropy_reg * reg
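
The schedule above is easy to sanity-check numerically, assuming LinearTempDecay is importable from the new module. With the defaults (rel_start_decay=0.2, start_b=20, end_b=2) the temperature holds at 20 for the first 20% of steps, then decays linearly toward 2 (values come back as 1-element tensors once the decay term kicks in):

decay = LinearTempDecay(t_max=1000)  # start_decay = 0.2 * 1000 = 200
print(decay(0))      # 20: before start_decay the start temperature is returned
print(decay(200))    # 20: rel_t = 0.0 -> 2 + 18 * 1.0
print(decay(600))    # 11: rel_t = 0.5 -> 2 + 18 * 0.5
print(decay(1000))   # 2:  rel_t = 1.0 -> 2 + 18 * 0.0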