mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
|
@@ -33,7 +33,7 @@ if FOUND_TORCH:
|
|
|
33
33
|
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
|
|
34
34
|
from model_compression_toolkit.core.pytorch.constants import DEFAULT_TP_MODEL
|
|
35
35
|
from torch.nn import Module
|
|
36
|
-
from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import
|
|
36
|
+
from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
|
|
37
37
|
from model_compression_toolkit import get_target_platform_capabilities
|
|
38
38
|
|
|
39
39
|
DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
|
|
@@ -62,7 +62,7 @@ if FOUND_TORCH:
|
|
|
62
62
|
representative_data_gen (Callable): Dataset used for calibration.
|
|
63
63
|
target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
|
|
64
64
|
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
|
|
65
|
-
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
|
|
65
|
+
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
|
|
66
66
|
new_experimental_exporter (bool): Whether to export the quantized model using the new exporter (in progress; avoiding it for now is recommended).
|
|
67
67
|
|
|
68
68
|
Returns:
|
|
@@ -95,8 +95,9 @@ if FOUND_TORCH:
|
|
|
95
95
|
if core_config.mixed_precision_enable:
|
|
96
96
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
|
|
97
97
|
common.Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
98
|
-
"MixedPrecisionQuantizationConfigV2. Please use
|
|
99
|
-
"or pass a valid mixed precision
|
|
98
|
+
"MixedPrecisionQuantizationConfigV2. Please use "
|
|
99
|
+
"pytorch_post_training_quantization API, or pass a valid mixed precision "
|
|
100
|
+
"configuration.") # pragma: no cover
|
|
100
101
|
|
|
101
102
|
common.Logger.info("Using experimental mixed-precision quantization. "
|
|
102
103
|
"If you encounter an issue please file a bug.")
|
|
@@ -127,7 +128,7 @@ if FOUND_TORCH:
|
|
|
127
128
|
Logger.warning('Using new experimental exported models. '
|
|
128
129
|
'Please do not use unless you are familiar with what you are doing')
|
|
129
130
|
|
|
130
|
-
return
|
|
131
|
+
return get_exportable_pytorch_model(tg)
|
|
131
132
|
|
|
132
133
|
quantized_model, user_info = export_model(tg,
|
|
133
134
|
DEFAULT_PYTORCH_INFO,
|
|
@@ -143,4 +144,4 @@ else:
|
|
|
143
144
|
def pytorch_post_training_quantization_experimental(*args, **kwargs):
|
|
144
145
|
Logger.critical('Installing Pytorch is mandatory '
|
|
145
146
|
'when using pytorch_post_training_quantization_experimental. '
|
|
146
|
-
'Could not find the torch package.')
|
|
147
|
+
'Could not find the torch package.') # pragma: no cover
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Dict
|
|
17
|
+
from enum import Enum
|
|
18
|
+
from model_compression_toolkit.core import common
|
|
19
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
20
|
+
|
|
21
|
+
def _is_qat_applicable(node: common.BaseNode,
|
|
22
|
+
fw_info: FrameworkInfo) -> bool:
|
|
23
|
+
"""
|
|
24
|
+
A function for deciding if a layer should be fine-tuned during QAT
|
|
25
|
+
Args:
|
|
26
|
+
node (BaseNode): Node for quantization decision
|
|
27
|
+
fw_info (FrameworkInfo): Framework quantization information
|
|
28
|
+
|
|
29
|
+
Returns:
|
|
30
|
+
A boolean indicating whether the layer is to be wrapped with a QuantizeWrapper
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
if node.is_weights_quantization_enabled() and not fw_info.is_kernel_op(node.type):
|
|
34
|
+
common.Logger.error("QAT Error: Quantizing a node without a kernel isn't supported")
|
|
35
|
+
return node.is_weights_quantization_enabled() or node.is_activation_quantization_enabled()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class TrainingMethod(Enum):
|
|
39
|
+
"""
|
|
40
|
+
An enum for selecting a QAT training method
|
|
41
|
+
|
|
42
|
+
STE - Standard straight-through estimator. Includes PowerOfTwo, symmetric & uniform quantizers
|
|
43
|
+
"""
|
|
44
|
+
STE = "STE",
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class QATConfig:
|
|
48
|
+
"""
|
|
49
|
+
QAT configuration class.
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
def __init__(self, weight_training_method: TrainingMethod = TrainingMethod.STE,
|
|
53
|
+
activation_training_method: TrainingMethod = TrainingMethod.STE,
|
|
54
|
+
weight_quantizer_params_override: Dict = None,
|
|
55
|
+
activation_quantizer_params_override: Dict = None,
|
|
56
|
+
):
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
weight_training_method (TrainingMethod): Training method for weight quantizers
|
|
61
|
+
activation_training_method (TrainingMethod): Training method for activation quantizers
|
|
62
|
+
weight_quantizer_params_override: A dictionary of parameters to override in weight quantization quantizer instantiation. Defaults to None (no parameters)
|
|
63
|
+
activation_quantizer_params_override: A dictionary of parameters to override in activation quantization quantizer instantiation. Defaults to None (no parameters)
|
|
64
|
+
"""
|
|
65
|
+
self.weight_training_method = weight_training_method
|
|
66
|
+
self.activation_training_method = activation_training_method
|
|
67
|
+
self.weight_quantizer_params_override = {} if weight_quantizer_params_override is None else weight_quantizer_params_override
|
|
68
|
+
self.activation_quantizer_params_override = {} if activation_quantizer_params_override is None else activation_quantizer_params_override
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
from typing import Callable
|
|
17
|
+
from functools import partial
|
|
17
18
|
|
|
18
19
|
from model_compression_toolkit import CoreConfig
|
|
19
20
|
from model_compression_toolkit.core import common
|
|
@@ -29,25 +30,56 @@ from model_compression_toolkit.ptq.runner import ptq_runner
|
|
|
29
30
|
|
|
30
31
|
if FOUND_TF:
|
|
31
32
|
import tensorflow as tf
|
|
33
|
+
from tensorflow.keras.layers import Layer
|
|
34
|
+
from tensorflow.keras.models import Model
|
|
32
35
|
|
|
33
36
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
34
37
|
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
|
|
35
38
|
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
|
|
36
|
-
from tensorflow.keras.models import Model
|
|
37
39
|
from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
|
|
38
40
|
|
|
39
|
-
from model_compression_toolkit.
|
|
41
|
+
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
|
|
42
|
+
|
|
43
|
+
from model_compression_toolkit import get_target_platform_capabilities
|
|
44
|
+
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
40
45
|
|
|
41
46
|
from model_compression_toolkit import get_target_platform_capabilities
|
|
42
|
-
from model_compression_toolkit import
|
|
47
|
+
from model_compression_toolkit.core import common
|
|
48
|
+
from model_compression_toolkit.core.common import BaseNode
|
|
49
|
+
from model_compression_toolkit.core.common.constants import TENSORFLOW
|
|
50
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
51
|
+
from model_compression_toolkit.qat.common.qat_config import _is_qat_applicable
|
|
52
|
+
from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
|
|
53
|
+
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
54
|
+
from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder
|
|
55
|
+
from model_compression_toolkit.qat.common.qat_config import QATConfig
|
|
56
|
+
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
43
57
|
|
|
44
58
|
DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
|
|
45
59
|
|
|
46
60
|
|
|
61
|
+
def qat_wrapper(n: common.BaseNode, layer: Layer, qat_config):
    """
    Wrap a Keras layer with trainable quantizers when QAT applies to its graph node.

    Args:
        n: A node of mct graph.
        layer: A keras layer
        qat_config: QAT configuration, forwarded as-is to quantization_builder.

    Returns:
        A qi.KerasQuantizationWrapper holding the layer and the quantizers built for the node
        when _is_qat_applicable accepts the node; otherwise the given layer, unchanged.

    """
    if not _is_qat_applicable(n, DEFAULT_KERAS_INFO):
        # QAT does not apply to this node: hand the layer back untouched.
        return layer

    w_quantizers, act_quantizers = quantization_builder(n, qat_config, DEFAULT_KERAS_INFO)
    return qi.KerasQuantizationWrapper(layer, w_quantizers, act_quantizers)
|
|
76
|
+
|
|
77
|
+
|
|
47
78
|
def keras_quantization_aware_training_init(in_model: Model,
|
|
48
79
|
representative_data_gen: Callable,
|
|
49
80
|
target_kpi: KPI = None,
|
|
50
81
|
core_config: CoreConfig = CoreConfig(),
|
|
82
|
+
qat_config: QATConfig = QATConfig(),
|
|
51
83
|
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
|
|
52
84
|
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
|
|
53
85
|
"""
|
|
@@ -70,6 +102,7 @@ if FOUND_TF:
|
|
|
70
102
|
representative_data_gen (Callable): Dataset used for initial calibration.
|
|
71
103
|
target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
|
|
72
104
|
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
|
|
105
|
+
qat_config (QATConfig): QAT configuration
|
|
73
106
|
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
|
|
74
107
|
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
|
|
75
108
|
|
|
@@ -90,14 +123,14 @@ if FOUND_TF:
|
|
|
90
123
|
>>> from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
|
|
91
124
|
>>> model = MobileNetV2()
|
|
92
125
|
|
|
93
|
-
|
|
94
|
-
|
|
126
|
+
Create a random dataset generator, for required number of calibration iterations (num_calibration_batches):
|
|
127
|
+
In this example a random dataset of 10 batches each containing 4 images is used.
|
|
95
128
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
129
|
+
>>> import numpy as np
|
|
130
|
+
>>> num_calibration_batches = 10
|
|
131
|
+
>>> def repr_datagen():
|
|
132
|
+
>>> for _ in range(num_calibration_batches):
|
|
133
|
+
>>> yield [np.random.random((4, 224, 224, 3))]
|
|
101
134
|
|
|
102
135
|
Create a MCT core config, containing the quantization configuration:
|
|
103
136
|
|
|
@@ -154,24 +187,23 @@ if FOUND_TF:
|
|
|
154
187
|
|
|
155
188
|
tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
|
|
156
189
|
|
|
157
|
-
|
|
190
|
+
_qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
|
|
191
|
+
qat_model, user_info = KerasModelBuilder(graph=tg, fw_info=fw_info, wrapper=_qat_wrapper).build_model()
|
|
158
192
|
|
|
159
193
|
user_info.mixed_precision_cfg = bit_widths_config
|
|
160
194
|
#TODO: remove the last output after updating documentation.
|
|
161
195
|
return qat_model, user_info, {}
|
|
162
196
|
|
|
163
197
|
|
|
164
|
-
def keras_quantization_aware_training_finalize(in_model: Model):
|
|
198
|
+
def keras_quantization_aware_training_finalize(in_model: Model) -> Model:
|
|
165
199
|
"""
|
|
166
|
-
Convert a model fine-tuned by the user to a
|
|
167
|
-
model contains float (fake-quantized) parameters and fake-quantiztion layers for quantizing
|
|
168
|
-
the activations
|
|
200
|
+
Convert a model fine-tuned by the user (Trainable quantizers) to a model with Inferable quantizers.
|
|
169
201
|
|
|
170
202
|
Args:
|
|
171
|
-
in_model (Model): Keras model to
|
|
203
|
+
in_model (Model): Keras model to replace TrainableQuantizer with InferableQuantizer
|
|
172
204
|
|
|
173
205
|
Returns:
|
|
174
|
-
A quantized model
|
|
206
|
+
A quantized model with Inferable quantizers
|
|
175
207
|
|
|
176
208
|
Examples:
|
|
177
209
|
|
|
@@ -216,37 +248,12 @@ if FOUND_TF:
|
|
|
216
248
|
>>> quantized_model = mct.keras_quantization_aware_training_finalize(quantized_model)
|
|
217
249
|
|
|
218
250
|
"""
|
|
219
|
-
|
|
220
251
|
def _export(layer):
|
|
221
252
|
if isinstance(layer, qi.KerasQuantizationWrapper):
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
weights_list = []
|
|
227
|
-
for w in new_layer.weights:
|
|
228
|
-
val = None
|
|
229
|
-
for qw in layer.weights:
|
|
230
|
-
if w.name in qw.name:
|
|
231
|
-
attribute_name = w.name.split('/')[-1].split(':')[0]
|
|
232
|
-
if attribute_name in layer.dispatcher.weight_quantizers.keys():
|
|
233
|
-
quantizer = layer.dispatcher.weight_quantizers.get(attribute_name)
|
|
234
|
-
val = quantizer(qw, False)
|
|
235
|
-
else:
|
|
236
|
-
val = qw
|
|
237
|
-
val = val.numpy()
|
|
238
|
-
if val is None:
|
|
239
|
-
Logger.error(f'Could not match weight name: {w.name}')
|
|
240
|
-
weights_list.append(val)
|
|
241
|
-
new_layer.set_weights(weights_list)
|
|
242
|
-
new_layer.trainable = False
|
|
243
|
-
return new_layer
|
|
244
|
-
else:
|
|
245
|
-
Logger.error(f'Undefined quantize_config')
|
|
246
|
-
else:
|
|
247
|
-
return layer
|
|
248
|
-
|
|
249
|
-
# clone each layer in the model and apply _export to layers wrapped with a QuantizeWrapper.
|
|
253
|
+
layer.convert_to_inferable_quantizers()
|
|
254
|
+
return layer
|
|
255
|
+
|
|
256
|
+
# clone each layer in the model and apply _export to layers with TrainableQuantizeWrappers
|
|
250
257
|
exported_model = tf.keras.models.clone_model(in_model, input_tensors=None, clone_function=_export)
|
|
251
258
|
|
|
252
259
|
return exported_model
|
|
@@ -257,10 +264,10 @@ else:
|
|
|
257
264
|
def keras_quantization_aware_training_init(*args, **kwargs):
|
|
258
265
|
Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
|
|
259
266
|
'when using keras_quantization_aware_training_init. '
|
|
260
|
-
'Could not find Tensorflow package.')
|
|
267
|
+
'Could not find Tensorflow package.') # pragma: no cover
|
|
261
268
|
|
|
262
269
|
|
|
263
270
|
def keras_quantization_aware_training_finalize(*args, **kwargs):
|
|
264
271
|
Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
|
|
265
272
|
'when using keras_quantization_aware_training_finalize. '
|
|
266
|
-
'Could not find Tensorflow package.')
|
|
273
|
+
'Could not find Tensorflow package.') # pragma: no cover
|
|
@@ -12,3 +12,6 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import model_compression_toolkit.qat.keras.quantizer.ste_rounding.symmetric_ste
|
|
17
|
+
import model_compression_toolkit.qat.keras.quantizer.ste_rounding.uniform_ste
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import Union
|
|
16
|
+
|
|
17
|
+
from model_compression_toolkit.core.common import Logger
|
|
18
|
+
from model_compression_toolkit.core.common.constants import FOUND_TF
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
|
|
21
|
+
TrainableQuantizerActivationConfig, BaseKerasTrainableQuantizer
|
|
22
|
+
|
|
23
|
+
if not FOUND_TF:
    class BaseKerasQATTrainableQuantizer(BaseKerasTrainableQuantizer):
        """
        Stand-in defined when FOUND_TF is false; constructing it logs a critical message.
        """

        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            """
            Initializes the base quantizer, then reports that Tensorflow is missing.

            Args:
                quantization_config: quantizer config class contains all the information about a quantizer configuration.
            """
            super().__init__(quantization_config)
            Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                            'when using BaseKerasQATTrainableQuantizer. '
                            'Could not find Tensorflow package.')  # pragma: no cover

else:
    class BaseKerasQATTrainableQuantizer(BaseKerasTrainableQuantizer):
        """
        A base class for trainable Keras quantizer for QAT.
        """

        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            """
            Initializes BaseKerasQATTrainableQuantizer object.

            Args:
                quantization_config: quantizer config class contains all the information about a quantizer configuration.
            """
            # No extra state here: construction is delegated entirely to the base quantizer.
            super().__init__(quantization_config)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import tensorflow as tf
|
|
17
|
+
from typing import Tuple
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def adjust_range_to_include_zero(range_min: tf.Tensor,
                                 range_max: tf.Tensor,
                                 n_bits: int) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Adjusting the quantization range to include representation of 0.0 in the quantization grid.
    For per_channel quantization range_min/range_max should be tensors in the specific shape that allows
    quantization along the channel_axis.

    Args:
        range_min: min bound of the quantization range (before adjustment).
        range_max: max bound of the quantization range (before adjustment).
        n_bits: Number of bits to quantize the tensor.

    Returns: adjusted quantization range
    """
    # Step size of a grid with 2 ** n_bits - 1 intervals spanning [range_min, range_max].
    # NOTE(review): scale is 0 when range_max == range_min, which makes the division below
    # undefined - confirm callers never pass a zero-width range.
    scale = (range_max - range_min) / (2 ** n_bits - 1)

    # Snap the lower bound to a whole multiple of the step (so 0.0 lands on the grid) and
    # shift the upper bound by the same amount, keeping the width of the range unchanged.
    min_range_adj = scale * tf.round(range_min / scale)
    max_range_adj = range_max - range_min + min_range_adj

    # Classify each range: entirely positive, entirely negative, or containing zero.
    min_positive = range_min > 0
    max_negative = range_max < 0
    mid_range = tf.logical_and(tf.logical_not(min_positive), tf.logical_not(max_negative))

    # Turn the boolean masks into 0/1 multipliers so the selection below is plain arithmetic.
    min_positive = tf.cast(min_positive, tf.float32)
    max_negative = tf.cast(max_negative, tf.float32)
    mid_range = tf.cast(mid_range, tf.float32)

    # Range containing zero -> snapped bounds; all-positive range -> [0, range_max];
    # all-negative range -> [range_min, 0].
    min_range_adj = min_range_adj * mid_range + max_negative * range_min
    max_range_adj = max_range_adj * mid_range + min_positive * range_max

    return min_range_adj, max_range_adj
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import Tuple, Dict, List
|
|
16
|
+
|
|
17
|
+
from model_compression_toolkit.core import common
|
|
18
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
19
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizer_config import \
|
|
20
|
+
get_trainable_quantizer_weights_config, get_trainable_quantizer_activation_config, \
|
|
21
|
+
get_trainable_quantizer_quantization_candidates
|
|
22
|
+
from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
|
|
23
|
+
from model_compression_toolkit.qat.common.qat_config import QATConfig
|
|
24
|
+
from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
|
|
25
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.get_quantizers import \
|
|
26
|
+
get_trainable_quantizer_class
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def quantization_builder(n: common.BaseNode,
                         qat_config: QATConfig,
                         fw_info: FrameworkInfo,
                         ) -> Tuple[Dict[str, BaseKerasQATTrainableQuantizer], List[BaseKerasQATTrainableQuantizer]]:
    """
    Build quantizers for a node according to its quantization configuration.

    Args:
        n: Node to build its QuantizeConfig.
        qat_config (QATConfig): QAT configuration
        fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).

    Returns:
        weights_quantizers: A dictionary between a weight's name to its quantizer.
        activation_quantizers: A list of activation quantizers, one for each layer output.
    """
    # Quantization candidates are only collected when the node holds more than one candidate.
    if len(n.candidates_quantization_cfg) > 1:
        wq_cand, aq_cand = get_trainable_quantizer_quantization_candidates(n)
    else:
        wq_cand, aq_cand = None, None

    weight_quantizers = {}
    if n.is_weights_quantization_enabled():
        quant_method = n.final_weights_quantization_cfg.weights_quantization_method

        quantizer_class = get_trainable_quantizer_class(QuantizationTarget.Weights,
                                                        qat_config.weight_training_method,
                                                        quant_method,
                                                        BaseKerasQATTrainableQuantizer)
        # One quantizer instance per kernel attribute of the node's layer type.
        for attr in fw_info.get_kernel_op_attributes(n.type):
            weight_quantizers[attr] = quantizer_class(get_trainable_quantizer_weights_config(n, wq_cand),
                                                      **qat_config.weight_quantizer_params_override)

    activation_quantizers = []
    if n.is_activation_quantization_enabled():
        quant_method = n.final_activation_quantization_cfg.activation_quantization_method
        # single output -> normalize to list of output_shapes
        output_shapes = n.output_shape if isinstance(n.output_shape[0], (list, tuple)) else [n.output_shape]

        quantizer_class = get_trainable_quantizer_class(QuantizationTarget.Activation,
                                                        qat_config.activation_training_method,
                                                        quant_method,
                                                        BaseKerasQATTrainableQuantizer)

        # Build a separate quantizer per output. Repeating a one-element list with `*` would
        # place the very same instance in every slot, so all outputs would share one quantizer.
        activation_quantizers = [quantizer_class(get_trainable_quantizer_activation_config(n, aq_cand),
                                                 **qat_config.activation_quantizer_params_override)
                                 for _ in output_shapes]

    return weight_quantizers, activation_quantizers
|