mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
|
@@ -1,105 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from typing import List
|
|
17
|
-
from tensorflow.keras.layers import Layer
|
|
18
|
-
from tensorflow.python.util.object_identity import Reference as TFReference
|
|
19
|
-
|
|
20
|
-
from model_compression_toolkit import get_target_platform_capabilities
|
|
21
|
-
from model_compression_toolkit.core import common
|
|
22
|
-
from model_compression_toolkit.core.common import BaseNode
|
|
23
|
-
from model_compression_toolkit.core.common.constants import TENSORFLOW
|
|
24
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
25
|
-
|
|
26
|
-
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
|
|
27
|
-
from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
|
|
28
|
-
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
29
|
-
from model_compression_toolkit.qat.keras.quantizer.quantization_dispatcher_builder import \
|
|
30
|
-
quantization_dispatcher_builder
|
|
31
|
-
from model_compression_toolkit import qunatizers_infrastructure as qi
|
|
32
|
-
|
|
33
|
-
DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def _is_qat_applicable(node: common.BaseNode,
|
|
37
|
-
fw_info: FrameworkInfo) -> bool:
|
|
38
|
-
"""
|
|
39
|
-
A function for deciding if a layer should be fine-tuned during QAT
|
|
40
|
-
Args:
|
|
41
|
-
node (BaseNode): Node for quantization decision
|
|
42
|
-
fw_info (FrameworkInfo): Keras quantization information
|
|
43
|
-
|
|
44
|
-
Returns:
|
|
45
|
-
A boolean whether the layer is to be wrapped with a QuantizeWrapper
|
|
46
|
-
"""
|
|
47
|
-
|
|
48
|
-
return fw_info.is_kernel_op(node.type) and node.is_weights_quantization_enabled()
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
def qat_wrapper(n: common.BaseNode, layer: Layer):
|
|
52
|
-
"""
|
|
53
|
-
A function which takes a computational graph node and a keras layer and perform the quantization wrapping
|
|
54
|
-
Args:
|
|
55
|
-
n: A node of mct graph.
|
|
56
|
-
layer: A keras layer
|
|
57
|
-
|
|
58
|
-
Returns: Wrapped layer
|
|
59
|
-
|
|
60
|
-
"""
|
|
61
|
-
if _is_qat_applicable(n, DEFAULT_KERAS_INFO):
|
|
62
|
-
return qi.KerasQuantizationWrapper(layer, quantization_dispatcher_builder(n, DEFAULT_KERAS_INFO))
|
|
63
|
-
else:
|
|
64
|
-
return layer
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
class QATKerasModelBuilder(KerasModelBuilder):
|
|
68
|
-
"""
|
|
69
|
-
Builder of QAT Keras models.
|
|
70
|
-
"""
|
|
71
|
-
|
|
72
|
-
def __init__(self,
|
|
73
|
-
graph: common.Graph,
|
|
74
|
-
append2output=None,
|
|
75
|
-
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
|
|
76
|
-
return_float_outputs: bool = False):
|
|
77
|
-
"""
|
|
78
|
-
|
|
79
|
-
Args:
|
|
80
|
-
graph: Graph to build the model from.
|
|
81
|
-
append2output: Nodes to append to model's output.
|
|
82
|
-
fw_info: Information about the specific framework of the model that is built.
|
|
83
|
-
return_float_outputs: Whether the model returns float tensors or not.
|
|
84
|
-
"""
|
|
85
|
-
super().__init__(graph,
|
|
86
|
-
append2output,
|
|
87
|
-
fw_info,
|
|
88
|
-
return_float_outputs,
|
|
89
|
-
wrapper=qat_wrapper)
|
|
90
|
-
|
|
91
|
-
def _quantize_node_activations(self,
|
|
92
|
-
node: BaseNode,
|
|
93
|
-
input_tensors: List[TFReference]) -> List[TFReference]:
|
|
94
|
-
"""
|
|
95
|
-
Quantize node's activation given input tensors.
|
|
96
|
-
|
|
97
|
-
Args:
|
|
98
|
-
node: Node to quantize its outputs.
|
|
99
|
-
input_tensors: Input tensors of the node.
|
|
100
|
-
|
|
101
|
-
Returns:
|
|
102
|
-
Output of the node.
|
|
103
|
-
|
|
104
|
-
"""
|
|
105
|
-
return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
|
|
@@ -1,56 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
from typing import Dict
|
|
18
|
-
from model_compression_toolkit.core import common
|
|
19
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
20
|
-
from model_compression_toolkit import qunatizers_infrastructure as qi
|
|
21
|
-
from model_compression_toolkit.qat.keras.quantizer.ste_rounding.symmetirc_ste import STEWeightQuantizer
|
|
22
|
-
from model_compression_toolkit.qat.keras.quantizer.ste_rounding.uniform_ste import STEUniformWeightQuantizer
|
|
23
|
-
|
|
24
|
-
METHOD2QUANTIZER = {qi.QuantizationMethod.SYMMETRIC: STEWeightQuantizer,
|
|
25
|
-
qi.QuantizationMethod.POWER_OF_TWO: STEWeightQuantizer,
|
|
26
|
-
qi.QuantizationMethod.UNIFORM: STEUniformWeightQuantizer}
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def quantization_dispatcher_builder(n: common.BaseNode,
|
|
30
|
-
fw_info: FrameworkInfo,
|
|
31
|
-
method2quantizer: Dict[
|
|
32
|
-
qi.QuantizationMethod, qi.BaseKerasQuantizer] = METHOD2QUANTIZER) -> qi.KerasNodeQuantizationDispatcher:
|
|
33
|
-
"""
|
|
34
|
-
Build a NodeQuantizationDispatcher for a node according to its quantization configuration and
|
|
35
|
-
a global NoOpQuantizeConfig object.
|
|
36
|
-
|
|
37
|
-
Args:
|
|
38
|
-
n: Node to build its QuantizeConfig.
|
|
39
|
-
fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).
|
|
40
|
-
method2quantizer: A mapping between quantization method to quantizer.
|
|
41
|
-
|
|
42
|
-
Returns:
|
|
43
|
-
A QuantizeConfig object with the appropriate quantizers (according to the node's
|
|
44
|
-
quantization configuration).
|
|
45
|
-
"""
|
|
46
|
-
nqd = qi.KerasNodeQuantizationDispatcher()
|
|
47
|
-
if n.is_weights_quantization_enabled():
|
|
48
|
-
attributes = fw_info.get_kernel_op_attributes(n.type)
|
|
49
|
-
for attr in attributes:
|
|
50
|
-
qunatizer_class = method2quantizer.get(n.final_weights_quantization_cfg.weights_quantization_method)
|
|
51
|
-
if qunatizer_class is None:
|
|
52
|
-
common.Logger.error(
|
|
53
|
-
f'Unknown Quantiztion method: {n.final_weights_quantization_cfg.weights_quantization_method}')
|
|
54
|
-
nqd.add_weight_quantizer(attr, qunatizer_class(n.final_weights_quantization_cfg))
|
|
55
|
-
|
|
56
|
-
return nqd
|
|
@@ -1,145 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from typing import Dict
|
|
17
|
-
|
|
18
|
-
import numpy as np
|
|
19
|
-
import tensorflow as tf
|
|
20
|
-
from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
|
|
21
|
-
from tensorflow.python.framework.tensor_shape import TensorShape
|
|
22
|
-
|
|
23
|
-
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
|
|
24
|
-
|
|
25
|
-
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
26
|
-
from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
|
|
27
|
-
from model_compression_toolkit.qat.common.constants import FQ_MIN, FQ_MAX
|
|
28
|
-
from model_compression_toolkit import qunatizers_infrastructure as qi
|
|
29
|
-
from model_compression_toolkit.core.common import constants as C
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class STEWeightQuantizer(qi.BaseKerasQuantizer):
    """
    Weights quantizer for the SYMMETRIC and POWER_OF_TWO quantization methods.

    The threshold stored in the node's weights quantization config is converted to a
    [min, max] range which is kept in non-trainable layer variables and applied with
    TensorFlow's fake-quantization ops.
    """

    def __init__(self, quantization_config: NodeWeightsQuantizationConfig):
        """
        Initialize the quantizer from a node's weights quantization configuration.

        Args:
            quantization_config: node quantization config class
        """
        super().__init__(quantization_config,
                         qi.QuantizationTarget.Weights,
                         [qi.QuantizationMethod.POWER_OF_TWO, qi.QuantizationMethod.SYMMETRIC])
        self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
        self.threshold_values = quantization_config.weights_quantization_params[C.THRESHOLD]
        self.threshold_shape = np.asarray(self.threshold_values).shape

        # Per-channel handling is decided by the explicit per-channel flag. Testing the
        # truthiness of the channel axis would treat axis 0 as "no axis" and would treat a
        # per-tensor threshold as per-channel whenever an axis is configured.
        self.per_channel = bool(self.quantization_config.weights_per_channel_threshold)
        if self.per_channel:
            self.np_threshold_values = np.reshape(np.asarray(self.threshold_values), [-1])
        else:
            self.np_threshold_values = float(self.threshold_values)

        channel_axis = self.quantization_config.weights_channels_axis
        last_axis = len(self.threshold_shape) - 1
        if self.per_channel and channel_axis not in [-1, last_axis]:
            # Tensorflow's fake_quant_with_min_max_vars_per_channel only works on last axis, so
            # need to move the quantization axis to the last axis.
            # The permutation swaps two axes, so applying it twice restores the original layout.
            self.perm_vec = list(np.arange(len(self.threshold_shape)))
            self.perm_vec[channel_axis] = last_axis
            self.perm_vec[last_axis] = channel_axis
        else:
            self.perm_vec = None

        if self.power_of_two:
            # Round the threshold up to the nearest power of two (bounded below by MIN_THRESHOLD).
            self.np_threshold_values = np.power(2.0,
                                                np.ceil(np.log2(np.maximum(self.np_threshold_values, C.MIN_THRESHOLD))))

        # Translate the threshold into the [min, max] range expected by the fake-quant ops.
        num_bits = self.quantization_config.weights_n_bits
        signed = int(C.WEIGHTS_SIGNED)
        delta = self.np_threshold_values / np.power(2.0, num_bits - signed)
        min_int = -signed * (2 ** (num_bits - signed))
        max_int = (2 ** (num_bits - signed)) - 1
        self.min = delta * min_int
        self.max = delta * max_int
        self.quantizer_parameters = {}

    def initialize_quantization(self,
                                tensor_shape: TensorShape,
                                name: str,
                                layer: QuantizeWrapper) -> Dict[str, tf.Variable]:
        """
        Add the threshold, min and max variables to the layer.

        Args:
            tensor_shape: Tensor shape the quantizer quantize.
            name: Prefix of variables names.
            layer: Layer to add the variables to. The variables are saved
                in the layer's scope.

        Returns:
            Dictionary of new variables.
        """
        # Each variable is created with a constant placeholder and then assigned its real value.
        ptq_threshold_tensor = layer.add_weight(
            name + THRESHOLD_TENSOR,
            shape=len(self.np_threshold_values) if self.per_channel else (),
            initializer=tf.keras.initializers.Constant(1.0),
            trainable=False)
        ptq_threshold_tensor.assign(self.np_threshold_values)

        fq_min = layer.add_weight(
            name + FQ_MIN,
            shape=len(self.min) if self.per_channel else (),
            initializer=tf.keras.initializers.Constant(-1.0),
            trainable=False)
        fq_min.assign(self.min)

        fq_max = layer.add_weight(
            name + FQ_MAX,
            shape=len(self.max) if self.per_channel else (),
            initializer=tf.keras.initializers.Constant(1.0),
            trainable=False)
        fq_max.assign(self.max)

        # save the quantizer added parameters for later calculations
        self.quantizer_parameters = {THRESHOLD_TENSOR: ptq_threshold_tensor,
                                     FQ_MIN: fq_min, FQ_MAX: fq_max}
        return self.quantizer_parameters

    def __call__(self,
                 inputs: tf.Tensor,
                 training: bool):
        """
        Quantize a tensor.

        Args:
            inputs: Input tensor to quantize.
            training: Whether the graph is in training mode (not used by this quantizer).

        Returns:
            The quantized tensor.
        """
        _min = self.quantizer_parameters[FQ_MIN]
        _max = self.quantizer_parameters[FQ_MAX]
        num_bits = self.quantization_config.weights_n_bits

        if not self.per_channel:
            return tf.quantization.fake_quant_with_min_max_vars(inputs, _min, _max,
                                                                num_bits=num_bits)

        if self.perm_vec is not None:
            # Bring the channel axis to the last position for the per-channel op.
            inputs = tf.transpose(inputs, perm=self.perm_vec)
        q_tensor = tf.quantization.fake_quant_with_min_max_vars_per_channel(inputs, _min, _max,
                                                                            num_bits=num_bits)
        if self.perm_vec is not None:
            # Swap the axes back to the original layout.
            q_tensor = tf.transpose(q_tensor, perm=self.perm_vec)

        return q_tensor
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
from model_compression_toolkit.qunatizers_infrastructure.common.base_quantizer import QuantizationTarget, \
|
|
2
|
-
QuantizationMethod
|
|
3
|
-
|
|
4
|
-
from model_compression_toolkit.qunatizers_infrastructure.keras.quantize_wrapper import KerasQuantizationWrapper
|
|
5
|
-
from model_compression_toolkit.qunatizers_infrastructure.keras.base_keras_quantizer import BaseKerasQuantizer
|
|
6
|
-
from model_compression_toolkit.qunatizers_infrastructure.keras.load_model import keras_load_quantized_model
|
|
7
|
-
from model_compression_toolkit.qunatizers_infrastructure.keras.keras_node_quantization_dispatcher import \
|
|
8
|
-
KerasNodeQuantizationDispatcher
|
|
@@ -1,14 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
@@ -1,123 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from typing import List
|
|
17
|
-
from enum import Enum
|
|
18
|
-
|
|
19
|
-
from model_compression_toolkit.core import common
|
|
20
|
-
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
|
|
21
|
-
NodeActivationQuantizationConfig, BaseNodeQuantizationConfig
|
|
22
|
-
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class QuantizationTarget(Enum):
    """Kind of tensor a quantizer is applied to: activations or weights."""

    Activation = 0
    Weights = 1
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class BaseQuantizer:
    def __init__(self,
                 quantization_config: BaseNodeQuantizationConfig,
                 quantization_target: QuantizationTarget,
                 quantization_method: List[QuantizationMethod]):
        """
        Base quantizer which validates the provided quantization config and declares the
        methods every quantizer needs to implement.

        Args:
            quantization_config: Node quantization config class containing the information about the quantizer.
            quantization_target: An enum which states the quantized tensor type: activation or weights.
            quantization_method: A list of "QuantizationMethod" enums which represent the quantizer supported methods.
        """
        self.quantization_config = quantization_config
        self.quantization_target = quantization_target
        self.quantization_method = quantization_method
        if self.quantization_target == QuantizationTarget.Weights:
            self.validate_weights()
            if self.quantization_config.weights_quantization_method not in quantization_method:
                common.Logger.error(
                    f'Quantization method mismatch expected:{quantization_method} and got {self.quantization_config.weights_quantization_method}')
        elif self.quantization_target == QuantizationTarget.Activation:
            self.validate_activation()
            if self.quantization_config.activation_quantization_method not in quantization_method:
                common.Logger.error(
                    f'Quantization method mismatch expected:{quantization_method} and got {self.quantization_config.activation_quantization_method}')
        else:
            common.Logger.error(
                f'Unknown Quantization Part:{quantization_target}')

    def initialize_quantization(self,
                                tensor_shape,
                                name: str,
                                layer):
        """
        Initialize the quantizer parameters given the parameter name and shape.
        Must be implemented by subclasses.

        Args:
            tensor_shape: tensor shape
            name: tensor name
            layer: layer to quantize

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # NotImplemented is a comparison sentinel, not an exception; raising it fails with a
        # TypeError, so the proper exception class is raised instead.
        raise NotImplementedError

    def __call__(self,
                 input2quantize,
                 training: bool):
        """
        Quantize a tensor. Must be implemented by subclasses.

        Args:
            input2quantize: Input tensor to quantize.
            training: Whether the graph is in training mode.

        Returns:
            The quantized tensor (in implementing subclasses).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def activation_quantization(self) -> bool:
        """
        Returns: A boolean stating whether the quantization config is an activation quantization config.
        """
        return isinstance(self.quantization_config, NodeActivationQuantizationConfig)

    def weights_quantization(self) -> bool:
        """
        Returns: A boolean stating whether the quantization config is a weights quantization config.
        """
        return isinstance(self.quantization_config, NodeWeightsQuantizationConfig)

    def validate_weights(self) -> None:
        """
        Validate that the quantization config is a weights config (and not an activation config).
        Reports an error through the logger otherwise.
        """
        if self.activation_quantization() or not self.weights_quantization():
            common.Logger.error('Expect weight quantization got activation')

    def validate_activation(self) -> None:
        """
        Validate that the quantization config is an activation config (and not a weights config).
        Reports an error through the logger otherwise.
        """
        if not self.activation_quantization() or self.weights_quantization():
            common.Logger.error('Expect activation quantization got weight')
|
|
@@ -1,65 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from typing import Dict, List
|
|
17
|
-
|
|
18
|
-
from model_compression_toolkit.qunatizers_infrastructure.common.base_quantizer import BaseQuantizer
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class NodeQuantizationDispatcher:
    def __init__(self,
                 weight_quantizers: Dict[str, BaseQuantizer] = None,
                 activation_quantizers: List[BaseQuantizer] = None):
        """
        Container for all the quantizers of a single layer.

        Args:
            weight_quantizers: Mapping from a weight name to its quantizer. A new empty dict is used when None.
            activation_quantizers: List of activation quantizers, one per layer output. A new empty list is used when None.
        """
        if weight_quantizers is None:
            weight_quantizers = dict()
        if activation_quantizers is None:
            activation_quantizers = list()
        self.weight_quantizers = weight_quantizers
        self.activation_quantizers = activation_quantizers

    def add_weight_quantizer(self, param_name: str, quantizer: BaseQuantizer):
        """
        Register a weight quantizer in this dispatcher. A quantizer already stored
        under the same name is replaced.

        Args:
            param_name: The name of the parameter to quantize
            quantizer: A quantizer.

        Returns: None
        """
        self.weight_quantizers[param_name] = quantizer

    @property
    def is_activation_quantization(self) -> bool:
        """
        Returns: True if the dispatcher holds at least one activation quantizer.
        """
        return len(self.activation_quantizers) != 0

    @property
    def is_weights_quantization(self) -> bool:
        """
        Returns: True if the dispatcher holds at least one weight quantizer.
        """
        return len(self.weight_quantizers) != 0
|
|
@@ -1,14 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
@@ -1,75 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
from typing import Dict, Any, List
|
|
16
|
-
|
|
17
|
-
from model_compression_toolkit.core.common import Logger
|
|
18
|
-
from model_compression_toolkit.core.common.constants import FOUND_TF
|
|
19
|
-
from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig
|
|
20
|
-
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
21
|
-
|
|
22
|
-
from model_compression_toolkit.qunatizers_infrastructure.common.base_quantizer import BaseQuantizer, QuantizationTarget
|
|
23
|
-
|
|
24
|
-
if FOUND_TF:
    # Key under which the serialized quantization config is stored in the quantizer's config dict.
    QUANTIZATION_CONFIG = 'qunatization_config'
    from model_compression_toolkit.qunatizers_infrastructure.keras.config_serialization import \
        config_serialization, config_deserialization


    class BaseKerasQuantizer(BaseQuantizer):
        def __init__(self,
                     quantization_config: BaseNodeQuantizationConfig,
                     quantization_target: QuantizationTarget,
                     quantization_method: List[QuantizationMethod]):
            """
            Base class for Keras quantizers. On top of BaseQuantizer it provides
            get_config and from_config so a quantizer can be saved and loaded with a Keras model.

            Args:
                quantization_config: Node quantization config class with the information about the quantizer.
                quantization_target: An enum stating the quantized tensor type: activation or weights.
                quantization_method: A list of enums which represent the quantizer supported methods.
            """
            super().__init__(quantization_config, quantization_target, quantization_method)

        def get_config(self) -> Dict[str, Any]:
            """
            Returns: Configuration of BaseKerasQuantizer: the serialized quantization config
            stored under QUANTIZATION_CONFIG.
            """
            serialized_cfg = config_serialization(self.quantization_config)
            return {QUANTIZATION_CONFIG: serialized_cfg}

        @classmethod
        def from_config(cls, config: dict):
            """
            Args:
                config(dict): Dictionary of BaseKerasQuantizer configuration.

            Returns: A quantizer instance of type cls.
            """
            config = config.copy()
            serialized_cfg = config[QUANTIZATION_CONFIG]
            # Only the quantization config is passed to the constructor: a concrete quantizer is
            # expected to hardcode the remaining arguments (target and supported methods).
            return cls(quantization_config=config_deserialization(serialized_cfg))

else:
    class BaseKerasQuantizer(BaseQuantizer):
        def __init__(self,
                     quantization_config: BaseNodeQuantizationConfig,
                     quantization_target: QuantizationTarget,
                     quantization_method: List[QuantizationMethod]):
            # Placeholder used when FOUND_TF is false: after the base initialization it
            # reports a critical error through the logger.
            super().__init__(quantization_config, quantization_target, quantization_method)
            Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                            'when using BaseKerasQuantizer. '
                            'Could not find Tensorflow package.')
|