mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 from abc import abstractmethod
-from typing import Tuple, Any, Dict, List, Union
+from typing import Tuple, Any, Dict, List, Union, Callable

 import torch
 from networkx import topological_sort
@@ -25,7 +25,7 @@ from model_compression_toolkit.core.common.back2framework.base_model_builder imp
 from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.common.user_info import UserInformation
-from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
+from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder, identity_wrapper
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder, BufferHolder
 from model_compression_toolkit.core.pytorch.utils import get_working_device
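The new `wrapper` hook imported here comes from `instance_builder.py` (changed +14 -1 in the file list above), whose body is not shown in this diff. A minimal sketch of callables compatible with the call sites in the hunks below (`self.wrapper(n, node_builder(n))` for modules, `self.wrapper(n, n.type)` for functional nodes); everything here is illustrative, not the package's actual implementation:

    import torch

    def identity_wrapper(node, module):
        # Sketch: the default wrapper leaves the built layer untouched.
        return module

    def my_fake_quant(node, out):
        # Placeholder for an activation quantizer; a real wrapper would apply
        # the node's quantization function here. Assumed, not part of this diff.
        return out

    def example_quantize_wrapper(node, module):
        # Sketch of a non-identity wrapper: wrap the layer so its outputs are
        # quantized by the wrapper itself, letting _run_operation skip its own
        # activation fake-quant (is_wrapped=True in the hunks below).
        class Wrapped(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.inner = module

            def forward(self, *args, **kwargs):
                out = self.inner(*args, **kwargs)
                return my_fake_quant(node, out)

        return Wrapped()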
@@ -65,7 +65,8 @@ def _build_input_tensors_list(node: BaseNode,
 def _run_operation(n: BaseNode,
                    input_tensors: List,
                    op_func: Any,
-                   quantize_node_activation_fn
+                   quantize_node_activation_fn,
+                   is_wrapped: bool) -> Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]]:
     """
     Applying the layer (op_func) to the input tensors (input_tensors).
     If quantized is set to True, and the layer's corresponding node (n) has quantization
@@ -76,6 +77,7 @@ def _run_operation(n: BaseNode,
         input_tensors: List of Pytorch tensors that are the layer's inputs.
         op_func: Module/functional to apply to the input tensors.
         quantize_node_activation_fn: quantization function
+        is_wrapped: Flag to indicate if the layer is already quantization-wrapped, so no activation quantization is needed.
     Returns:
         A tuple of Pytorch tensors. The Module/functional output tensors after applying the
         Module/functional to the input tensors.
@@ -90,7 +92,7 @@ def _run_operation(n: BaseNode,

     # Add a fake quant node if the node has an activation threshold.
     out_tensors_of_n = out_tensors_of_n_float
-    if n.is_activation_quantization_enabled():
+    if n.is_activation_quantization_enabled() and not is_wrapped:
         if isinstance(out_tensors_of_n_float, list):
             out_tensors_of_n_float = torch.cat(out_tensors_of_n_float, dim=0)
         out_tensors_of_n = quantize_node_activation_fn(n, out_tensors_of_n_float)
@@ -142,7 +144,8 @@ class PytorchModel(torch.nn.Module):
                  graph: Graph,
                  append2output: List[Any] = None,
                  fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
-                 return_float_outputs: bool = False
+                 return_float_outputs: bool = False,
+                 wrapper: Callable = identity_wrapper):
         """
         Construct a Pytorch model.

@@ -151,6 +154,7 @@ class PytorchModel(torch.nn.Module):
             append2output: List of nodes or OutTensor objects.
             fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).
             return_float_outputs: Whether the model returns float tensors or not.
+            wrapper: A function that wraps Pytorch layers.
         """
         super(PytorchModel, self).__init__()
         self.graph = graph
@@ -159,6 +163,7 @@ class PytorchModel(torch.nn.Module):
         self.append2output = append2output
         self.return_float_outputs = return_float_outputs
         self.fw_info = fw_info
+        self.wrapper = wrapper
         self._add_modules()

     @abstractmethod
@@ -176,17 +181,21 @@ class PytorchModel(torch.nn.Module):
             Output of the node.

         """
-        raise NotImplemented(f'{self.__class__.__name__}
+        raise NotImplemented(f'{self.__class__.__name__} '
+                             f'have to implement a method for quantization activation nodes.')  # pragma: no cover

     def _add_modules(self):
         for n in self.node_sort:
-            if
+            if isinstance(n, FunctionalNode):
+                # for functional layers
+                setattr(self, n.name, self.wrapper(n, n.type))
+            else:
                 if n.type == BufferHolder:
                     self.add_module(n.name, node_builder(n))
                     self.get_submodule(n.name). \
                         register_buffer(n.name, torch.Tensor(n.get_weights_by_keys(BUFFER)).to(get_working_device()))
                 else:
-                    self.add_module(n.name, node_builder(n))
+                    self.add_module(n.name, self.wrapper(n, node_builder(n)))

     def forward(self,
                 *args: Any) -> Any:
@@ -211,7 +220,8 @@ class PytorchModel(torch.nn.Module):
                 out_tensors_of_n, out_tensors_of_n_float = _run_operation(n,
                                                                           input_tensors,
                                                                           op_func=op_func,
-                                                                          quantize_node_activation_fn=self._quantize_node_activations
+                                                                          quantize_node_activation_fn=self._quantize_node_activations,
+                                                                          is_wrapped=self.wrapper is not identity_wrapper)

                 if isinstance(out_tensors_of_n, list):
                     node_to_output_tensors_dict.update({n: out_tensors_of_n})
@@ -244,7 +254,7 @@ class PytorchModel(torch.nn.Module):
         Returns: Module/functional to apply to the input tensors.

         """
-        return
+        return getattr(self, node.name)


 class PyTorchModelBuilder(BaseModelBuilder):
@@ -256,7 +266,8 @@ class PyTorchModelBuilder(BaseModelBuilder):
                  graph: common.Graph,
                  append2output=None,
                  fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
-                 return_float_outputs: bool = False
+                 return_float_outputs: bool = False,
+                 wrapper: Callable = identity_wrapper):
         """

         Args:
@@ -264,6 +275,7 @@ class PyTorchModelBuilder(BaseModelBuilder):
             append2output: Nodes to append to model's output.
             fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
+            wrapper: A function that wraps Pytorch layers.
         """

         super().__init__(graph,
@@ -271,6 +283,8 @@ class PyTorchModelBuilder(BaseModelBuilder):
                          fw_info,
                          return_float_outputs)

+        self.wrapper = wrapper
+
     def build_model(self) -> Tuple[PytorchModel, UserInformation]:
         """
         Build a PyTorch model and return it.
@@ -279,4 +293,5 @@ class PyTorchModelBuilder(BaseModelBuilder):
         """
         return PytorchModel(self.graph,
                             self.append2output,
-                            return_float_outputs=self.return_float_outputs
+                            return_float_outputs=self.return_float_outputs,
+                            wrapper=self.wrapper), self.graph.user_info
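Taken together, these hunks let a caller inject layer wrapping at build time, which the new exporter and GPTQ paths in this release appear to rely on. A hedged usage sketch (`graph` is assumed to be an MCT Graph; `example_quantize_wrapper` is the illustrative wrapper sketched earlier):

    # Build a model whose layers are wrapped by a custom callable.
    builder = PyTorchModelBuilder(graph,
                                  return_float_outputs=True,
                                  wrapper=example_quantize_wrapper)
    model, user_info = builder.build_model()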
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py

@@ -38,8 +38,7 @@ class WrapperQuantizeConfig:
         Returns: A List of quantizers for weights quantization.

         """
-        raise NotImplemented
-
+        raise NotImplemented  # pragma: no cover

     def get_activation_quantizers(self) -> list:
         """
@@ -47,7 +46,7 @@ class WrapperQuantizeConfig:
         Returns: A List of quantizers for activation quantization.

         """
-        raise NotImplemented
+        raise NotImplemented  # pragma: no cover


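An aside on the lines above (the `raise NotImplemented` pattern predates this diff): `NotImplemented` is the sentinel meant to be returned from binary special methods, not an exception class, so raising it fails at runtime with `TypeError: exceptions must derive from BaseException`. The idiomatic placeholder is `NotImplementedError`:

    class WrapperQuantizeConfigSketch:
        def get_weights_quantizers(self) -> list:
            # raise NotImplemented     # TypeError at runtime: not an exception
            raise NotImplementedError  # idiomatic abstract-method placeholder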
model_compression_toolkit/core/pytorch/constants.py

@@ -71,6 +71,7 @@ RELU_POT_BOUND = 8.0

 # Supported TP models names for Pytorch:
 DEFAULT_TP_MODEL = 'default'
+IMX500_TP_MODEL = 'imx500'
 TFLITE_TP_MODEL = 'tflite'
 QNNPACK_TP_MODEL = 'qnnpack'

@@ -91,3 +92,7 @@ IN_PROJ_WEIGHT = 'in_proj_weight'
 IN_PROJ_BIAS = 'in_proj_bias'
 BIAS_K = 'bias_k'
 BIAS_V = 'bias_v'
+
+# Batch size value for 'reshape' and 'view' operators,
+# the value is -1 so the batch size is inferred from the length of the array and remaining dimensions.
+BATCH_DIM_VALUE = -1
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py
(some removed lines in the second hunk were not captured by the registry's diff rendering and are shown as bare `-` markers)

@@ -20,6 +20,7 @@ import torch.nn as nn
 import operator
 from typing import List

+from model_compression_toolkit.core.common.logger import Logger
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
@@ -46,32 +47,26 @@ class MHAParams:
         # Only batch first network is supported
         if BATCH_FIRST in mha_node.framework_attr.keys():
             if mha_node.framework_attr[BATCH_FIRST] is not True:
-
+                Logger.error('Only batch first network is supported')  # pragma: no cover
         else:
-
+            Logger.error('Only batch first network is supported')  # pragma: no cover

         # Add Zero Attn feature is Not Implemented
         if ADD_ZERO_ATTN in mha_node.framework_attr.keys():
             if mha_node.framework_attr[ADD_ZERO_ATTN] is not False:
-
+                Logger.error('Add Zero Attn feature is Not Implemented')  # pragma: no cover

         # Check if Add Bias KV feature is Active
         if BIAS_K and BIAS_V in mha_node.weights.keys():
-            if mha_node.weights[BIAS_K] and mha_node.weights[BIAS_V] is not None:
-
+            if mha_node.weights[BIAS_K] is not None and mha_node.weights[BIAS_V] is not None:
+                Logger.error('Add BIAS_KV feature is Not Implemented')  # pragma: no cover

         self.embed_dim = mha_node.framework_attr[EMBED_DIM]
         self.num_heads = mha_node.framework_attr[NUM_HEADS]

-
-            self.kdim = mha_node.framework_attr[KEY_DIM]
-        else:
-            self.kdim = False
+        self.kdim = mha_node.framework_attr[KEY_DIM]

-
-            self.vdim = mha_node.framework_attr[VALUE_DIM]
-        else:
-            self.vdim = False
+        self.vdim = mha_node.framework_attr[VALUE_DIM]

         self.qdim = int(self.embed_dim / self.num_heads)
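The BIAS_K/BIAS_V change above fixes a genuine precedence bug: `x and y is not None` parses as `x and (y is not None)`, so the left operand is truth-tested instead of compared to None, which either short-circuits silently or, for array weights, raises the ambiguous-truth-value error. (The similar `if BIAS_K and BIAS_V in mha_node.weights.keys():` line keeps that shape.) A quick illustration:

    import numpy as np

    bias_k = np.zeros(4)  # a weight tensor
    bias_v = np.zeros(4)

    # Old form: bool(bias_k) is evaluated first and raises
    # "The truth value of an array with more than one element is ambiguous".
    # if bias_k and bias_v is not None: ...

    # Fixed form: each operand is compared to None explicitly.
    if bias_k is not None and bias_v is not None:
        print('both biases are present')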
@@ -707,7 +702,7 @@ class MultiHeadAttentionDecomposition(common.BaseSubstitution):
         """

         if mha_node.reuse:
-            raise Exception("MCT doesn't support reuse of MultiHeadAttention layer")
+            raise Exception("MCT doesn't support reuse of MultiHeadAttention layer")  # pragma: no cover
         params = MHAParams(mha_node)

         # project
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py

@@ -14,10 +14,13 @@
 # ==============================================================================
 from torch import reshape
 import torch
+
+from model_compression_toolkit.core.common import Logger
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
+from model_compression_toolkit.core.pytorch.constants import BATCH_DIM_VALUE


 class ReshapeWithStaticShapes(common.BaseSubstitution):
@@ -47,14 +50,25 @@ class ReshapeWithStaticShapes(common.BaseSubstitution):
         Returns:
             Graph after applying the substitution.
         """
+        # we want the batch size value to be inferred from the length of the array and the remaining dimensions
+        if len(node.output_shape) == 1:
+            node.output_shape[0][0] = BATCH_DIM_VALUE
+        else:
+            Logger.error('Reshape or view nodes should have a single output shape')  # pragma: no cover
+
         # configure the new static output shape attribute
         node.op_call_args = node.output_shape

         # modify the node input info
         node.input_shape = [node.input_shape[0]]
+
+        # the first input is the tensor to be reshaped; we want its batch size value to be inferred
+        # from the length of the array and the remaining dimensions
+        node.input_shape[0][0] = BATCH_DIM_VALUE
+
         nodes_to_check = []
         for in_edge in graph.incoming_edges(node):
-            if in_edge.sink_index > 0:
+            if in_edge.sink_index > 0:  # the first input is the tensor to be reshaped
                 nodes_to_check.append(in_edge.source_node)
                 graph.remove_edge(in_edge.source_node, node)
         for n in nodes_to_check:
@@ -80,4 +94,4 @@ def clean_graph_from_nodes_without_out_edges(graph: Graph,
             graph.remove_edge(in_edge.source_node, node)
         graph.remove_node(node)
     for n in nodes_to_check:
-        clean_graph_from_nodes_without_out_edges(graph, n)
+        clean_graph_from_nodes_without_out_edges(graph, n)
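The substitution turns dynamic `reshape`/`view` shape arguments into static ones, keeping only the batch dimension flexible via `BATCH_DIM_VALUE = -1`. PyTorch infers a -1 dimension from the element count, so the rewritten node still accepts any batch size:

    import torch

    x = torch.randn(8, 3, 4)          # batch of 8
    y = torch.reshape(x, (-1, 12))    # -1 is inferred: shape (8, 12)

    x2 = torch.randn(2, 3, 4)         # different batch, same static args
    y2 = torch.reshape(x2, (-1, 12))  # shape (2, 12)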
model_compression_toolkit/core/pytorch/kpi_data_facade.py

@@ -154,9 +154,9 @@ else:
     # we raise an exception when trying to use this function.
     def pytorch_kpi_data(*args, **kwargs):
         Logger.critical('Installing torch is mandatory when using pytorch_kpi_data. '
-                        'Could not find Tensorflow package.')
+                        'Could not find Tensorflow package.')  # pragma: no cover


     def pytorch_kpi_data_experimental(*args, **kwargs):
         Logger.critical('Installing torch is mandatory when using pytorch_kpi_data. '
-                        'Could not find Tensorflow package.')
+                        'Could not find Tensorflow package.')  # pragma: no cover

model_compression_toolkit/core/pytorch/quantization_facade.py

@@ -269,9 +269,9 @@ else:
     def pytorch_post_training_quantization(*args, **kwargs):
         Logger.critical('Installing Pytorch is mandatory '
                         'when using pytorch_post_training_quantization. '
-                        'Could not find the torch package.')
+                        'Could not find the torch package.')  # pragma: no cover

     def pytorch_post_training_quantization_mixed_precision(*args, **kwargs):
         Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                         'when using pytorch_post_training_quantization_mixed_precision. '
-                        'Could not find Tensorflow package.')
+                        'Could not find Tensorflow package.')  # pragma: no cover
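The `# pragma: no cover` markers added throughout this release are coverage.py exclusion pragmas: marked lines (or whole branches, when placed on a branch header) are omitted from coverage reports, the usual treatment for guards that only run when a dependency is missing (incidentally, the 'Could not find Tensorflow package.' wording in the torch stubs looks like a copy-paste slip, since the missing package is torch). For example:

    def pytorch_kpi_data(*args, **kwargs):
        # Reachable only when torch is absent, so it is excluded from
        # coverage measurement instead of being artificially exercised.
        raise ImportError('torch is required')  # pragma: no cover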
model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py

@@ -17,6 +17,7 @@ import torch

 from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
 from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import threshold_is_power_of_two
+from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero


 def get_symmetric_quantization_range_and_scale(activation_is_signed: bool,
@@ -62,9 +63,9 @@ def power_of_two_quantization(activation_n_bits: int,
     activation_is_signed = quantization_params.get(SIGNED)

     if activation_threshold is None or activation_is_signed is None:
-        return None
+        return None  # pragma: no cover
     if not threshold_is_power_of_two(activation_threshold, per_channel=False):
-        return None
+        return None  # pragma: no cover

     min_value, max_value, scale = get_symmetric_quantization_range_and_scale(activation_is_signed,
                                                                              activation_n_bits,
@@ -90,7 +91,7 @@ def symmetric_quantization(activation_n_bits: int,
     activation_is_signed = quantization_params.get(SIGNED)

     if activation_threshold is None or activation_is_signed is None:
-        return None
+        return None  # pragma: no cover

     min_value, max_value, scale = get_symmetric_quantization_range_and_scale(activation_is_signed,
                                                                              activation_n_bits,
@@ -115,16 +116,17 @@ def uniform_quantization(activation_n_bits: int,
     a, b = quantization_params.get(RANGE_MIN), quantization_params.get(RANGE_MAX)

     if a is None or b is None:
-        return None
+        return None  # pragma: no cover

     # fixing quantization range to include 0
     a = 0 if a > 0 else a
     b = 0 if b < 0 else b
+    a, b = fix_range_to_include_zero(a, b, activation_n_bits)

     min_value = 0
     max_value = 2 ** activation_n_bits - 1
     scale = (b - a) / ((2 ** activation_n_bits) - 1)
-    zero_point = -
+    zero_point = -round(a / scale)  # zp has to be positive, and a <= 0, so we multiply by -1

     return lambda x: q(x, min_value, max_value, scale, zero_point)
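To make the new `zero_point` line concrete, a worked example with assumed values (the helper `fix_range_to_include_zero`, imported above, adjusts (a, b) so that zero lands exactly on the quantization grid, as its name suggests):

    a, b, n_bits = -0.5, 1.5, 8          # assumed activation range

    scale = (b - a) / (2 ** n_bits - 1)  # 2.0 / 255 ~= 0.00784
    zero_point = -round(a / scale)       # -round(-63.75) = 64
    assert zero_point >= 0               # holds because a <= 0 after the clamping above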
model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py

@@ -57,7 +57,7 @@ class PytorchLUTFakeQuant(torch.nn.Module):
             Quantized torch Tensor.
         """
         if self.activation_is_signed is None or self.cluster_centers is None or self.threshold is None:
-            return None
+            return None  # pragma: no cover

         _quant_output = self.lut_kmeans_quantizer(x)
         return _quant_output
model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py

@@ -17,14 +17,18 @@ from model_compression_toolkit.core.common.target_platform import TargetPlatform

 from model_compression_toolkit.core.tpc_models.default_tpc.target_platform_capabilities import \
     tpc_dict as default_tpc_dict
+from model_compression_toolkit.core.tpc_models.imx500_tpc.target_platform_capabilities import \
+    tpc_dict as imx500_tpc_dict
 from model_compression_toolkit.core.tpc_models.tflite_tpc.target_platform_capabilities import \
     tpc_dict as tflite_tpc_dict
 from model_compression_toolkit.core.tpc_models.qnnpack_tpc.target_platform_capabilities import \
     tpc_dict as qnnpack_tpc_dict
-from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL, TFLITE_TP_MODEL,
+from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, TFLITE_TP_MODEL, \
+    QNNPACK_TP_MODEL
 from model_compression_toolkit.core.common.constants import LATEST

 tpc_dict = {DEFAULT_TP_MODEL: default_tpc_dict,
+            IMX500_TP_MODEL: imx500_tpc_dict,
             TFLITE_TP_MODEL: tflite_tpc_dict,
             QNNPACK_TP_MODEL: qnnpack_tpc_dict}

@@ -35,7 +39,7 @@ def get_target_platform_capabilities(fw_name: str,
     """
     Get a TargetPlatformCapabilities by the target platform model name and the framework name.
     For now, it supports frameworks 'tensorflow' and 'pytorch'. For both of them
-    the target platform model can be 'default','tflite', or 'qnnpack'.
+    the target platform model can be 'default', 'imx500', 'tflite', or 'qnnpack'.

     Args:
         fw_name: Framework name of the TargetPlatformCapabilities.
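With this registration in place, the new TP model is reachable through the public facade, assuming the same entry point as previous releases:

    import model_compression_toolkit as mct

    # Fetch the IMX500 capabilities description (latest version by default).
    tpc = mct.get_target_platform_capabilities('tensorflow', 'imx500')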
model_compression_toolkit/core/tpc_models/imx500_tpc/__init__.py (renamed from model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/__init__.py)

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py (new file)

@@ -0,0 +1,24 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORCH
+from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tp_model import get_tp_model, generate_tp_model, \
+    get_op_quantization_configs
+if FOUND_TF:
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_latest
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tpc_keras import generate_keras_tpc
+if FOUND_TORCH:
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tpc_pytorch import get_pytorch_tpc as \
+        get_pytorch_tpc_latest
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tpc_pytorch import generate_pytorch_tpc
model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py (new file)

@@ -0,0 +1,45 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from model_compression_toolkit.core.common.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH, LATEST
+
+###############################
+# Build Tensorflow TPC models
+###############################
+keras_tpc_models_dict = None
+if FOUND_TF:
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.latest import get_keras_tpc_latest
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
+
+    # Keras: TPC versioning
+    keras_tpc_models_dict = {'v1': get_keras_tpc_v1(),
+                             LATEST: get_keras_tpc_latest()}
+
+###############################
+# Build Pytorch TPC models
+###############################
+pytorch_tpc_models_dict = None
+if FOUND_TORCH:
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.latest import get_pytorch_tpc_latest
+    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tpc_pytorch import \
+        get_pytorch_tpc as get_pytorch_tpc_v1
+
+    # Pytorch: TPC versioning
+    pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1(),
+                               LATEST: get_pytorch_tpc_latest()}
+
+tpc_dict = {TENSORFLOW: keras_tpc_models_dict,
+            PYTORCH: pytorch_tpc_models_dict}
+
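A sketch of how this nested dict is meant to be indexed (framework key, then TPC version), mirroring the lookup that `get_target_platform_capabilities` performs:

    from model_compression_toolkit.core.common.constants import TENSORFLOW, LATEST
    from model_compression_toolkit.core.tpc_models.imx500_tpc.target_platform_capabilities import tpc_dict

    keras_versions = tpc_dict[TENSORFLOW]  # None when TensorFlow is not installed
    imx500_tpc = keras_versions[LATEST]    # or keras_versions['v1']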
model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py (new file)

@@ -0,0 +1,16 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+__version__ = 'v1'
model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py (new file)

@@ -0,0 +1,156 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import List, Tuple
+
+import model_compression_toolkit as mct
+from model_compression_toolkit.core.common.target_platform import OpQuantizationConfig, TargetPlatformModel
+
+tp = mct.target_platform
+
+
+def get_tp_model() -> TargetPlatformModel:
+    """
+    A method that generates a default target platform model, with base 8-bit quantization configuration and 8, 4, 2
+    bits configuration list for mixed-precision quantization.
+    NOTE: in order to generate a target platform model with different configurations but with the same Operators Sets
+    (for tests, experiments, etc.), use this method implementation as a test-case, i.e., override the
+    'get_op_quantization_configs' method and use its output to call 'generate_tp_model' with your configurations.
+
+    Returns: A TargetPlatformModel object.
+
+    """
+    base_config, mixed_precision_cfg_list = get_op_quantization_configs()
+    return generate_tp_model(default_config=base_config,
+                             base_config=base_config,
+                             mixed_precision_cfg_list=mixed_precision_cfg_list,
+                             name='imx500_tp_model')
+
+
+def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
+    """
+    Creates a default configuration object for 8-bit quantization, to be used to set a default TargetPlatformModel.
+    In addition, creates a default configuration objects list (with 8, 4 and 2 bit quantization) to be used as
+    default configuration for mixed-precision quantization.
+
+    Returns: An OpQuantizationConfig config object and a list of OpQuantizationConfig objects.
+
+    """
+    # Create a quantization config.
+    # A quantization configuration defines how an operator
+    # should be quantized on the modeled hardware:
+    eight_bits = tp.OpQuantizationConfig(
+        activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
+        weights_quantization_method=tp.QuantizationMethod.SYMMETRIC,
+        activation_n_bits=8,
+        weights_n_bits=8,
+        weights_per_channel_threshold=True,
+        enable_weights_quantization=True,
+        enable_activation_quantization=True,
+        quantization_preserving=False,
+        fixed_scale=None,
+        fixed_zero_point=None,
+        weights_multiplier_nbits=None)
+
+    # To quantize a model using mixed-precision, create
+    # a list with more than one OpQuantizationConfig.
+    # In this example, we quantize some operations' weights
+    # using 2, 4 or 8 bits, and when using 2 or 4 bits, it's possible
+    # to quantize the operations' activations using LUT.
+    four_bits = eight_bits.clone_and_edit(weights_n_bits=4)
+    two_bits = eight_bits.clone_and_edit(weights_n_bits=2)
+    mixed_precision_cfg_list = [eight_bits, four_bits, two_bits]
+
+    return eight_bits, mixed_precision_cfg_list
+
+
+def generate_tp_model(default_config: OpQuantizationConfig,
+                      base_config: OpQuantizationConfig,
+                      mixed_precision_cfg_list: List[OpQuantizationConfig],
+                      name: str) -> TargetPlatformModel:
+    """
+    Generates TargetPlatformModel with default defined Operators Sets, based on the given base configuration and
+    mixed-precision configurations options list.
+
+    Args:
+        default_config: A default OpQuantizationConfig to set as the TP model default configuration.
+        base_config: An OpQuantizationConfig to set as the TargetPlatformModel base configuration for mixed-precision purposes only.
+        mixed_precision_cfg_list: A list of OpQuantizationConfig to be used as the TP model mixed-precision
+            quantization configuration options.
+        name: The name of the TargetPlatformModel.
+
+    Returns: A TargetPlatformModel object.
+
+    """
+    # Create a QuantizationConfigOptions, which defines a set
+    # of possible configurations to consider when quantizing a set of operations (in mixed-precision, for example).
+    # If the QuantizationConfigOptions contains only one configuration,
+    # this configuration will be used for the operation quantization:
+    default_configuration_options = tp.QuantizationConfigOptions([default_config])
+
+    # Create a TargetPlatformModel and set its default quantization config.
+    # This default configuration will be used for all operations
+    # unless specified otherwise (see OperatorsSet, for example):
+    generated_tpc = tp.TargetPlatformModel(default_configuration_options, name=name)
+
+    # To start defining the model's components (such as operator sets, and fusing patterns),
+    # use 'with' the TargetPlatformModel instance, and create them as below:
+    with generated_tpc:
+        # Create an OperatorsSet to represent a set of operations.
+        # Each OperatorsSet has a unique label.
+        # If a quantization configuration options is passed, these options will
+        # be used for operations that will be attached to this set's label.
+        # Otherwise, it will be a configure-less set (used in fusing):
+
+        # May suit for operations like: Dropout, Reshape, etc.
+        tp.OperatorsSet("NoQuantization",
+                        tp.get_default_quantization_config_options().clone_and_edit(
+                            enable_weights_quantization=False,
+                            enable_activation_quantization=False))
+
+        # Create Mixed-Precision quantization configuration options from the given list of OpQuantizationConfig objects
+        mixed_precision_configuration_options = tp.QuantizationConfigOptions(mixed_precision_cfg_list,
+                                                                             base_config=base_config)
+
+        # Define operator sets that use mixed_precision_configuration_options:
+        conv = tp.OperatorsSet("Conv", mixed_precision_configuration_options)
+        fc = tp.OperatorsSet("FullyConnected", mixed_precision_configuration_options)
+
+        # Define operations sets without quantization configuration
+        # options (useful for creating fusing patterns, for example):
+        any_relu = tp.OperatorsSet("AnyReLU")
+        add = tp.OperatorsSet("Add")
+        sub = tp.OperatorsSet("Sub")
+        mul = tp.OperatorsSet("Mul")
+        div = tp.OperatorsSet("Div")
+        prelu = tp.OperatorsSet("PReLU")
+        swish = tp.OperatorsSet("Swish")
+        sigmoid = tp.OperatorsSet("Sigmoid")
+        tanh = tp.OperatorsSet("Tanh")
+
+        # Combine multiple operators into a single operator to avoid quantization between
+        # them. To do this we define fusing patterns using the OperatorsSets that were created.
+        # To group multiple sets with regard to fusing, an OperatorSetConcat can be created
+        activations_after_conv_to_fuse = tp.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
+        activations_after_fc_to_fuse = tp.OperatorSetConcat(any_relu, swish, sigmoid)
+        any_binary = tp.OperatorSetConcat(add, sub, mul, div)
+
+        # ------------------- #
+        # Fusions
+        # ------------------- #
+        tp.Fusing([conv, activations_after_conv_to_fuse])
+        tp.Fusing([fc, activations_after_fc_to_fuse])
+        tp.Fusing([any_binary, any_relu])
+
+    return generated_tpc
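Following the docstring's suggestion, a hedged sketch of building a variant TP model by editing the returned configs and calling `generate_tp_model` (the 16-bit activation width is an arbitrary choice for illustration, and `clone_and_edit` is assumed to accept any OpQuantizationConfig attribute, as the 4- and 2-bit clones above do for weights_n_bits):

    from model_compression_toolkit.core.tpc_models.imx500_tpc.v1.tp_model import (
        generate_tp_model, get_op_quantization_configs)

    base_config, _ = get_op_quantization_configs()

    # Experiment: widen activations to 16 bits, keep the same operator sets.
    wide = base_config.clone_and_edit(activation_n_bits=16)
    custom_tpm = generate_tp_model(default_config=wide,
                                   base_config=wide,
                                   mixed_precision_cfg_list=[wide,
                                                             wide.clone_and_edit(weights_n_bits=4),
                                                             wide.clone_and_edit(weights_n_bits=2)],
                                   name='imx500_tp_model_16a')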