mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
|
@@ -1,104 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from typing import List, Tuple
|
|
17
|
-
|
|
18
|
-
import tensorflow as tf
|
|
19
|
-
from keras.models import Model
|
|
20
|
-
from tensorflow.python.util.object_identity import Reference as TFReference
|
|
21
|
-
from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
|
|
22
|
-
|
|
23
|
-
from model_compression_toolkit.core import common
|
|
24
|
-
from model_compression_toolkit.core.common import BaseNode
|
|
25
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
26
|
-
from model_compression_toolkit.core.common.user_info import UserInformation
|
|
27
|
-
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder, \
|
|
28
|
-
is_layer_fake_quant, get_node_name_from_layer
|
|
29
|
-
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
30
|
-
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
|
|
31
|
-
from model_compression_toolkit.gptq.keras.quantizer.config_factory import quantization_config_builder_gptq
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
class GPTQKerasModelBuilder(KerasModelBuilder):
|
|
35
|
-
"""
|
|
36
|
-
Builder of GPTQ Keras models.
|
|
37
|
-
"""
|
|
38
|
-
|
|
39
|
-
def __init__(self,
|
|
40
|
-
graph: common.Graph,
|
|
41
|
-
gptq_config: GradientPTQConfig,
|
|
42
|
-
append2output=None,
|
|
43
|
-
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
|
|
44
|
-
return_float_outputs: bool = True):
|
|
45
|
-
"""
|
|
46
|
-
|
|
47
|
-
Args:
|
|
48
|
-
graph: Graph to build the model from.
|
|
49
|
-
gptq_config: Configuration for GPTQ optimization.
|
|
50
|
-
append2output: Nodes to append to model's output.
|
|
51
|
-
fw_info: Information about the specific framework of the model that is built.
|
|
52
|
-
return_float_outputs: Whether the model returns float tensors or not.
|
|
53
|
-
"""
|
|
54
|
-
|
|
55
|
-
super().__init__(graph,
|
|
56
|
-
append2output,
|
|
57
|
-
fw_info,
|
|
58
|
-
return_float_outputs)
|
|
59
|
-
self.gptq_config = gptq_config
|
|
60
|
-
|
|
61
|
-
def _quantize_node_activations(self,
|
|
62
|
-
node: BaseNode,
|
|
63
|
-
input_tensors: List[TFReference]) -> List[TFReference]:
|
|
64
|
-
"""
|
|
65
|
-
Quantize node's activation given input tensors.
|
|
66
|
-
|
|
67
|
-
Args:
|
|
68
|
-
node: Node to quantize its outputs.
|
|
69
|
-
input_tensors: Input tensors of the node.
|
|
70
|
-
|
|
71
|
-
Returns:
|
|
72
|
-
Output of the node.
|
|
73
|
-
|
|
74
|
-
"""
|
|
75
|
-
|
|
76
|
-
return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
|
|
77
|
-
|
|
78
|
-
def build_model(self) -> Tuple[Model, UserInformation]:
|
|
79
|
-
"""
|
|
80
|
-
Build a Keras GPTQ model and return it.
|
|
81
|
-
Returns: GPTQ Keras model.
|
|
82
|
-
|
|
83
|
-
"""
|
|
84
|
-
model, user_info = super().build_model()
|
|
85
|
-
|
|
86
|
-
def _quantize(layer):
|
|
87
|
-
|
|
88
|
-
node = self.oh.layer_to_node_dict.get(layer)
|
|
89
|
-
|
|
90
|
-
if node is not None:
|
|
91
|
-
return QuantizeWrapper(layer, quantization_config_builder_gptq(node, self.fw_info, self.gptq_config))
|
|
92
|
-
|
|
93
|
-
elif is_layer_fake_quant(layer):
|
|
94
|
-
return layer
|
|
95
|
-
|
|
96
|
-
else:
|
|
97
|
-
raise Exception(
|
|
98
|
-
f"Mismatch between keras model and graph can't find node named: "
|
|
99
|
-
f"{get_node_name_from_layer(layer)}")
|
|
100
|
-
|
|
101
|
-
# clone each layer in the model and apply _quantize to the layer.
|
|
102
|
-
model = tf.keras.models.clone_model(model, input_tensors=None, clone_function=_quantize)
|
|
103
|
-
|
|
104
|
-
return model, user_info
|
|
@@ -1,119 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
from typing import List, Callable, Tuple
|
|
16
|
-
|
|
17
|
-
import tensorflow as tf
|
|
18
|
-
from model_compression_toolkit.gptq.keras.quantizer.configs.weight_quantizer_gptq_config import WeightQuantizeConfig
|
|
19
|
-
from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class SAM:
    """
    Implements Sharpness-Aware Minimization gradient computation, from
    "Sharpness-Aware Minimization for Efficiently Improving Generalization"
    (https://arxiv.org/abs/2010.01412).
    """

    def __init__(self, model2quantized,
                 gradient_step: Callable,
                 optimizer_with_param: List[Tuple[List, List[tf.Tensor]]],
                 rho: float = 0.01,
                 eps: float = 1e-12):
        """
        The init function of the Sharpness-Aware Minimization gradient computation class.

        Args:
            model2quantized: Input quantized module.
            gradient_step: A function that returns a loss tensor and a list of gradients tensors.
            optimizer_with_param: A list of optimizer classes to update with the corresponding parameters.
            rho: A floating point number that sets the region of smoothness.
            eps: A floating point number for numeric stability.

        Raises:
            ValueError: If rho is negative.
        """
        # Explicit raise instead of assert: asserts are stripped under `python -O`,
        # which would silently accept an invalid rho.
        if rho < 0.0:
            raise ValueError(f"Invalid rho, should be non-negative: {rho}")
        self.rho = rho
        self.eps = eps
        self.gradient_step = gradient_step

        # One group of trainable variables per (optimizer, params) entry.
        self.trainable_variables = [p for _, p in optimizer_with_param]
        self.m_var = [len(p) for p in self.trainable_variables]  # variables count per group
        self.n_groups = len(self.trainable_variables)
        self.model2quantized = model2quantized
        # Per-group perturbations applied to the weights; filled by _update_w_location
        # and undone by _restore_w_location.
        self.e_ws = [[] for _ in range(len(optimizer_with_param))]

    def _enable_update_step_param(self):
        """
        Enables the parameter update (update iteration index and gumbel random variable)
        on every weight-quantized wrapper layer of the model.

        Returns: None
        """
        for layer in self.model2quantized.layers:
            if isinstance(layer, QuantizeWrapper) and isinstance(
                    layer.quantize_config, WeightQuantizeConfig):
                layer.quantize_config.enable_update()

    def _disable_update_step_param(self):
        """
        Disables the parameter update (update iteration index and gumbel random variable)
        on every weight-quantized wrapper layer of the model.

        Returns: None
        """
        for layer in self.model2quantized.layers:
            if isinstance(layer, QuantizeWrapper) and isinstance(
                    layer.quantize_config, WeightQuantizeConfig):
                layer.quantize_config.disable_update()

    def _update_w_location(self, gradients: List[List[tf.Tensor]]):
        """
        Updates the weights position toward the highest point in the rho-neighborhood:
        w <- w + rho * g / (||g|| + eps), saving each applied perturbation so that
        _restore_w_location can undo it.

        Args:
            gradients: A list (per group) of lists of gradient tensors.

        Returns: None
        """
        for g in range(self.n_groups):
            self.e_ws[g].clear()
            grad_norm = tf.linalg.global_norm(gradients[g])
            ew_multiplier = self.rho / (grad_norm + self.eps)
            for i in range(self.m_var[g]):
                e_w = tf.math.multiply(gradients[g][i], ew_multiplier)
                self.trainable_variables[g][i].assign_add(e_w)
                self.e_ws[g].append(e_w)

    def _restore_w_location(self):
        """
        Restores the weights to their original position by subtracting the saved
        perturbations.

        Returns: None
        """
        for g in range(self.n_groups):
            for i in range(self.m_var[g]):
                self.trainable_variables[g][i].assign_add(-self.e_ws[g][i])

    def compute_gradients(self, *arg, **kwargs) -> (tf.Tensor, List[List[tf.Tensor]]):
        """
        Computes the gradients of the SAM optimizer: a first gradient step determines
        the adversarial perturbation, and a second gradient step at the perturbed
        weights yields the returned loss and gradients.

        Args:
            *arg: args to pass to the gradient step function.
            **kwargs: kwargs to pass to the gradient step function.

        Returns: A tensor of the loss value and a list of gradients tensors.
        """
        self._enable_update_step_param()
        loss, grad = self.gradient_step(*arg, **kwargs)
        self._update_w_location(grad)
        self._disable_update_step_param()
        loss, grad = self.gradient_step(*arg, **kwargs)
        self._restore_w_location()
        return loss, grad
|
@@ -1,62 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_configs import \
|
|
18
|
-
NoOpQuantizeConfig
|
|
19
|
-
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_registry import \
|
|
20
|
-
QuantizeConfig
|
|
21
|
-
|
|
22
|
-
from model_compression_toolkit.core import common
|
|
23
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
24
|
-
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
|
|
25
|
-
from model_compression_toolkit.gptq.keras.quantizer import WeightQuantizeConfig
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def quantization_config_builder_gptq(n: common.BaseNode,
                                     fw_info: FrameworkInfo,
                                     gptq_config: GradientPTQConfig) -> QuantizeConfig:
    """
    Build a QuantizeConfig for a node according to its quantization configuration and
    a global NoOpQuantizeConfig object.

    Args:
        n: Node to build its QuantizeConfig.
        fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).
        gptq_config: GPTQ Configuration class.

    Returns:
        A QuantizeConfig object with the appropriate quantizers (according to the node's
        quantization configuration).
    """

    # Only weights quantization determines the config here: activation quantization is
    # performed using fake-quantization nodes, so activation-only (and no-quantization)
    # nodes both get a NoOpQuantizeConfig. The original 4-way branch over the two
    # boolean flags collapsed to exactly this decision.
    if n.is_weights_quantization_enabled():
        return WeightQuantizeConfig(fw_info.get_kernel_op_attributes(n.type),
                                    n.final_weights_quantization_cfg,
                                    gptq_config)
    return NoOpQuantizeConfig()
|
|
@@ -1,65 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from tensorflow_model_optimization.python.core.quantization.keras.quantize_config import QuantizeConfig
|
|
17
|
-
from typing import Tuple, List, Any, Dict
|
|
18
|
-
from tensorflow import Tensor
|
|
19
|
-
import six, abc
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class BaseQuantizeConfig(QuantizeConfig, metaclass=abc.ABCMeta):
    """
    Base QuantizeConfig to define extra API methods needed by the GPTQ post-processing.

    Uses the native Python 3 ``metaclass`` keyword instead of ``six.add_metaclass``
    (equivalent semantics; drops the Python-2 compatibility shim).
    """

    @abc.abstractmethod
    def get_quantization_variable(self):
        """
        Gets the quantization parameters such as threshold, min, max, etc.

        Returns: A list of trainable variables.
        """

    @abc.abstractmethod
    def update_layer_quantization_params(self, layer) -> Tuple[Dict[str, Any],
                                                               Dict[str, Any],
                                                               Dict[str, Any]]:
        """
        Calculates the needed change in attributes in NodeQuantizationConfig after retraining.
        Usually a function of the config quantizers.

        Args:
            layer: layer being quantized.

        Returns:
            3 dictionaries of attributes the quantize_config retraining has changed during GPTQ retraining.
            Keys must match NodeQuantizationConfig attributes:
            1. layer weights
            2. weight quantization config attributes
            3. activation quantization config attributes
        """

    @abc.abstractmethod
    def get_trainable_quantizer_parameters(self) -> List[Tensor]:
        """
        Gets a list of trainable parameters for GPTQ retraining from config quantizers.

        Returns:
            A list of trainable Tensors.
        """
|
|
@@ -1,269 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from typing import List, Tuple, Any, Dict
|
|
17
|
-
|
|
18
|
-
from tensorflow import Tensor
|
|
19
|
-
import tensorflow as tf
|
|
20
|
-
from packaging import version
|
|
21
|
-
|
|
22
|
-
# As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
if version.parse(tf.__version__) < version.parse("2.6"):
|
|
26
|
-
from tensorflow.python.keras.layers import Layer
|
|
27
|
-
else:
|
|
28
|
-
from keras.engine.base_layer import Layer
|
|
29
|
-
|
|
30
|
-
from tensorflow.python.training.tracking.data_structures import ListWrapper
|
|
31
|
-
from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer
|
|
32
|
-
|
|
33
|
-
from model_compression_toolkit.gptq.keras.quantizer.configs.base_quantizer_gptq_config import BaseQuantizeConfig
|
|
34
|
-
from model_compression_toolkit.core.keras.constants import KERNEL
|
|
35
|
-
|
|
36
|
-
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2, RoundingType
|
|
37
|
-
from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.symmetric_gumbel import SymmetricGumbelRounding
|
|
38
|
-
from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.uniform_gumbel import UniformGumbelRounding
|
|
39
|
-
from model_compression_toolkit.gptq.keras.quantizer.ste_rounding.symmetric_ste import STEWeightQuantizer
|
|
40
|
-
from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationMethod
|
|
41
|
-
from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MAX, RANGE_MIN
|
|
42
|
-
from model_compression_toolkit.core import common
|
|
43
|
-
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
|
|
44
|
-
from model_compression_toolkit.gptq.common import gptq_constants
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
class WeightQuantizeConfig(BaseQuantizeConfig):
    """
    QuantizeConfig to quantize the weights of a layer using a TrainableQuantizer.
    """

    def __init__(self, weight_attrs: List[str],
                 final_weights_quantization_cfg: NodeWeightsQuantizationConfig,
                 gptq_config: GradientPTQConfigV2):
        """
        Initialize a TrainableQuantizer and set as the weights quantizer.

        Args:
            weight_attrs: Attributes of the layer's weights to quantize.
            final_weights_quantization_cfg: quantization config of the current layer.
            gptq_config: A GPTQ configuration class.
        """

        num_bits = final_weights_quantization_cfg.weights_n_bits
        weight_channel_axis = final_weights_quantization_cfg.weights_channels_axis
        max_lsbs_change_map = gptq_config.lsb_change_per_bit_width
        self.weight_attrs = weight_attrs
        self.final_weights_quantization_cfg = final_weights_quantization_cfg
        self.gptq_config = gptq_config

        if final_weights_quantization_cfg.weights_quantization_method in [QuantizationMethod.SYMMETRIC,
                                                                          QuantizationMethod.POWER_OF_TWO]:
            is_power_of_two = QuantizationMethod.POWER_OF_TWO == final_weights_quantization_cfg.weights_quantization_method
            threshold_values = final_weights_quantization_cfg.weights_quantization_params.get(THRESHOLD)
            if gptq_config.rounding_type == RoundingType.STE:
                # Per-axis quantization whenever more than one threshold value exists.
                self.weight_quantizer = STEWeightQuantizer(num_bits=num_bits,
                                                           per_axis=len(
                                                               threshold_values.flatten()) > 1,
                                                           threshold_values=threshold_values,
                                                           signed=True,
                                                           power_of_two=is_power_of_two,
                                                           quantization_axis=weight_channel_axis,
                                                           max_lsbs_change_map=max_lsbs_change_map)
            elif gptq_config.rounding_type == RoundingType.GumbelRounding:
                self.weight_quantizer = SymmetricGumbelRounding(num_bits=num_bits,
                                                                per_axis=len(
                                                                    threshold_values.flatten()) > 1,
                                                                threshold_values=threshold_values,
                                                                signed=True,
                                                                power_of_two=is_power_of_two,
                                                                quantization_parameter_learning=gptq_config.quantization_parameters_learning,
                                                                quantization_axis=weight_channel_axis,
                                                                max_lsbs_change_map=max_lsbs_change_map,
                                                                max_iteration=gptq_config.n_epochs,
                                                                gumbel_config=gptq_config.quantizer_config)
            else:
                common.Logger.error(
                    f"For quantization method {final_weights_quantization_cfg.weights_quantization_method}, GPTQ Rounding type {gptq_config.rounding_type} is not supported")
        elif final_weights_quantization_cfg.weights_quantization_method == QuantizationMethod.UNIFORM:
            if not gptq_config.rounding_type == RoundingType.GumbelRounding:
                common.Logger.error(
                    f"For quantization method {final_weights_quantization_cfg.weights_quantization_method}, GPTQ Rounding type {gptq_config.rounding_type} is not supported")
            range_max = final_weights_quantization_cfg.weights_quantization_params.get(RANGE_MAX)
            range_min = final_weights_quantization_cfg.weights_quantization_params.get(RANGE_MIN)
            self.weight_quantizer = UniformGumbelRounding(num_bits=num_bits,
                                                          per_axis=len(
                                                              range_max.flatten()) > 1,
                                                          min_range=range_min,
                                                          max_range=range_max,
                                                          signed=True,
                                                          quantization_parameter_learning=gptq_config.quantization_parameters_learning,
                                                          quantization_axis=weight_channel_axis,
                                                          max_lsbs_change_map=max_lsbs_change_map,
                                                          max_iteration=gptq_config.n_epochs,
                                                          gumbel_config=gptq_config.quantizer_config)
        # NOTE(review): for any other quantization method self.weight_quantizer is never
        # assigned, so later use raises AttributeError — presumably unreachable after
        # upstream validation; confirm.

    def enable_update(self):
        """
        Enables the parameter update (update iteration index and gumbel random variable).

        Returns: None
        """
        if self.gptq_config.is_gumbel:
            return self.weight_quantizer.enable_update()

    def disable_update(self):
        """
        Disables the parameter update (update iteration index and gumbel random variable).

        Returns: None
        """
        if self.gptq_config.is_gumbel:
            return self.weight_quantizer.disable_update()

    def get_weights_and_quantizers(self, layer: Layer) -> List[Tuple[Tensor, Quantizer]]:
        """
        Get a list of tuples with weights and the weight quantizer.
        The layer's attributes are used to get the weights.

        Args:
            layer: The layer the WeightQuantizeConfig wraps.

        Returns:
            List of tuples of the layer's weights and the weight quantizer.
        """
        return [(getattr(layer, weight_attr), self.weight_quantizer)
                for weight_attr in self.weight_attrs]

    def get_activations_and_quantizers(self, layer: Layer) -> list:
        # No activation quantizers: activations are handled by fake-quantization nodes.
        return []

    def set_quantize_weights(self, layer: Layer, quantize_weights: List[Tensor]):
        """
        Set the layer weights with new passed weights.

        Args:
            layer: Layer to set its attributes.
            quantize_weights: Quantized weights to set as new weights.
        """
        if len(self.weight_attrs) != len(quantize_weights):
            raise ValueError(
                '`set_quantize_weights` called on layer {} with {} '
                'weight parameters, but layer expects {} values.'.format(
                    layer.name, len(quantize_weights), len(self.weight_attrs)))  # pragma: no cover

        for weight_attr, weight in zip(self.weight_attrs, quantize_weights):
            current_weight = getattr(layer, weight_attr)
            if current_weight.shape != weight.shape:
                raise ValueError('Existing layer weight shape {} is incompatible with'
                                 'provided weight shape {}'.format(
                    current_weight.shape, weight.shape))  # pragma: no cover

            setattr(layer, weight_attr, weight)

    def set_quantize_activations(self, layer, quantize_activations: ListWrapper):
        # Nothing to set: no activation quantizers are managed by this config.
        pass

    def get_output_quantizers(self, layer: Layer) -> list:
        return []

    @classmethod
    def from_config(cls, config: dict):
        """
        Instantiates a `WeightQuantizeConfig` from its config.

        Args:
            config: Output of `get_config()`.

        Returns:
            A `WeightQuantizeConfig` instance.
        """

        return cls(**config)

    def get_config(self) -> Dict[str, Any]:
        """
        Returns: The WeightQuantizeConfig configuration.
        """
        return {
            'weight_attrs': self.weight_attrs,
            'final_weights_quantization_cfg': self.final_weights_quantization_cfg,
            'gptq_config': self.gptq_config,
        }

    def update_layer_quantization_params(self, layer: Layer) -> (Dict[str, tf.Tensor], Dict[str, Dict], Dict):
        """
        A Function to calculate the needed change in attributes in NodeQuantizationConfig after retraining.
        Usually a function of the config quantizers.

        Args:
            layer: layer being quantized.

        Returns:
            3 dictionaries describing the change in layer's weights, weights config, activation config
            that changed during GPTQ retraining.
            Keys must match NodeQuantizationConfig attributes
        """
        weights = {}
        for weight, quantizer, quantizer_vars in layer._weight_vars:
            weights.update({KERNEL: quantizer(weight, training=False, weights=quantizer_vars)})

        quant_config = {gptq_constants.WEIGHTS_QUANTIZATION_PARAMS: self.weight_quantizer.get_quant_config(layer)}

        return weights, quant_config, {}

    def get_trainable_quantizer_parameters(self) -> List[tf.Tensor]:
        """
        A function to get a list of trainable parameters for GPTQ retraining from config quantizers.

        Returns:
            A list of trainable Tensors.
        """
        return self.weight_quantizer.get_trainable_parameters()

    def get_aux_variable(self) -> List[tf.Tensor]:
        return [self.weight_quantizer.get_aux_variable()]

    def get_quantization_variable(self) -> List[tf.Tensor]:
        return self.weight_quantizer.get_quantization_variable()

    def get_temperature_variable(self) -> tf.Tensor:
        if self.gptq_config.is_gumbel:
            return self.weight_quantizer.get_temperature_variable()
        else:
            # Unified with __init__'s error-logging style (was common.logger.Logger.error).
            common.Logger.error("Temperature variable only exist when using Gumbel Rounding Quantizer")

    def get_gumbel_probability(self) -> tf.Tensor:
        if self.gptq_config.is_gumbel:
            return self.weight_quantizer.get_gumbel_probability()
        else:
            # Unified with __init__'s error-logging style (was common.logger.Logger.error).
            common.Logger.error("Probability variable only exist when using Gumbel Rounding Quantizer")

    def __eq__(self, other: Any) -> bool:
        """
        Check whether it equals to another object or not.
        """
        if not isinstance(other, WeightQuantizeConfig):
            return False

        return (self.weight_attrs == other.weight_attrs and
                self.weight_quantizer == other.weight_quantizer and
                self.gptq_config == other.gptq_config)

    def __ne__(self, other: Any) -> bool:
        """
        Check whether it differs from another object or not.
        """
        return not self.__eq__(other)
|