mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
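One rename in the listing above is worth calling out: the misspelled top-level package qunatizers_infrastructure (whose modules are deleted near the bottom of the list) is replaced by the correctly spelled quantizers_infrastructure package. The hunks below import it under the alias qi; a one-line sketch of the import that code written against 1.8.0 would use:

from model_compression_toolkit import quantizers_infrastructure as qi  # replaces the misspelled 'qunatizers_infrastructure'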
model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py
ADDED
@@ -0,0 +1,180 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import copy
+from typing import Callable
+
+import keras.models
+import numpy as np
+import tensorflow as tf
+from keras import Sequential
+from keras.layers import Dense, Conv2D, Reshape
+from keras.models import clone_model
+
+from model_compression_toolkit import quantizers_infrastructure as qi
+from model_compression_toolkit.core.common import Logger
+from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
+    constants as keras_inferable_constants
+
+BIAS_INITIALIZER = 'bias_initializer'
+BIAS_REGULARIZER = 'bias_regularizer'
+BIAS_CONSTRAINT = 'bias_constraint'
+ACTIVITY_REGULARIZER = 'activity_regularizer'
+KERNEL_INITIALIZER = 'kernel_initializer'
+KERNEL_REGULARIZER = 'kernel_regularizer'
+KERNEL_CONSTRAINT = 'kernel_constraint'
+KERNEL_SIZE = 'kernel_size'
+PADDING = 'padding'
+STRIDES = 'strides'
+LAYER_NAME = 'name'
+TRAINABLE = 'trainable'
+ACTIVATION = 'activation'
+USE_BIAS = 'use_bias'
+FILTERS = 'filters'
+UNITS = 'units'
+PAD_VALID = 'valid'
+KERNEL = 'kernel'
+
+CONV_KERNEL_CHANNEL_AXIS = 3
+CONV_INPUT_CHANNELS_DIM = 4
+
+class INT8TFLiteExporter(FakelyQuantKerasExporter):
+    """
+    Exporter for INT8 TFLite models.
+    The exporter expects to receive an exportable model (where each layer's full quantization parameters
+    can be retrieved), and convert it into a quantized model where weights and activations are represented
+    as integer data type.
+    """
+
+    def __init__(self,
+                 model: keras.models.Model,
+                 is_layer_exportable_fn: Callable,
+                 save_model_path: str):
+        """
+
+        Args:
+            model: Model to export.
+            is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
+            save_model_path: Path to save the exported model.
+        """
+        super().__init__(model,
+                         is_layer_exportable_fn,
+                         save_model_path)
+
+        self.exported_model = None
+
+    def _get_pointwise_layer_to_replace_dense(self, wrapped_layer: qi.KerasQuantizationWrapper) -> keras.layers.Layer:
+        # First we create a pointwise configuration based on the Dense layer's configuration
+        dense_cfg = wrapped_layer.layer.get_config()
+
+        # List of pw attributes that should be taken from the dense layer as they are.
+        pw_attr_list = [LAYER_NAME, ACTIVATION, USE_BIAS, BIAS_CONSTRAINT,
+                        BIAS_INITIALIZER, BIAS_REGULARIZER, TRAINABLE, ACTIVITY_REGULARIZER,
+                        KERNEL_INITIALIZER, KERNEL_REGULARIZER, KERNEL_CONSTRAINT]
+
+        pw_cfg = {attr: dense_cfg[attr] for attr in pw_attr_list}
+
+        # Use more attributes that are not taken as they are
+        pw_cfg.update({KERNEL_SIZE: (1, 1),
+                       STRIDES: (1, 1),
+                       PADDING: PAD_VALID,
+                       FILTERS: dense_cfg[UNITS]})
+
+        # Create the point-wise layer
+        pw_layer = Conv2D(**pw_cfg)
+        pw_layer.build(wrapped_layer.layer.input_shape)
+
+        # Create and set the point-wise weights to assign
+        dense_kernel = wrapped_layer.layer.kernel
+        pw_weights = []
+        pw_kernel = np.reshape(wrapped_layer.get_weights()[0],
+                               (1, 1, dense_kernel.get_shape()[0], dense_cfg[UNITS]))
+
+        pw_weights.append(pw_kernel)
+        if wrapped_layer.layer.use_bias:
+            pw_bias = wrapped_layer.get_weights()[2]
+            pw_weights.append(pw_bias)
+
+        pw_layer.set_weights(pw_weights)
+
+        # Now that we have the point-wise to replace the dense layer,
+        # we need to wrap it using qi.KerasQuantizationWrapper with a new
+        # relevant quantizers.
+        # Create new kernel quantizer
+        pw_kernel_quantizer_cfg = wrapped_layer.weights_quantizers[KERNEL].get_config()
+
+        # In Conv2D channel axis is 3 and not 1 as in Dense
+        pw_kernel_quantizer_cfg[keras_inferable_constants.CHANNEL_AXIS] = CONV_KERNEL_CHANNEL_AXIS
+
+        # Unquantized weight to conv layer has 4 dimensions (unlike dense which varies)
+        pw_kernel_quantizer_cfg[keras_inferable_constants.INPUT_RANK] = CONV_INPUT_CHANNELS_DIM
+
+        assert isinstance(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD], np.ndarray), f'Expected to find threshold which is a numpy array, but found: {type(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD])}'
+        pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD] = list(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD])
+
+        # Now that we have the point-wise quantizer we can instantiate it
+        quantizer_class = type(wrapped_layer.weights_quantizers[KERNEL])
+        pw_quantizer = quantizer_class(**pw_kernel_quantizer_cfg)
+        pw_weights_quantizers = copy.deepcopy(wrapped_layer.weights_quantizers)
+        pw_weights_quantizers[KERNEL] = pw_quantizer
+
+        # Wrap pw with the new quantizers (the activation is not affected thus we take the Dense quantizers)
+        wrapped_pw = qi.KerasQuantizationWrapper(pw_layer,
+                                                 pw_weights_quantizers,
+                                                 wrapped_layer.activation_quantizers)
+
+        # Compute the shape that the input to the new layer should be reshaped into
+        # Example: Dense kernel with the following shape (3, 20) expects to have input with the
+        # next dimensions (BATCH_SIZE, x0, x1, ..., xn, 20).
+        # Conv layer expects 4-rank input. Thus, the input is reshaped to (BATCH_SIZE, 1, x0*x1*...*xn, 20)
+        dim = wrapped_layer.layer.input_shape[1:-1]
+        target_shape = (1, int(np.prod(dim))) + (dense_kernel.get_shape()[0],)
+
+        return Sequential([
+            Reshape(target_shape=target_shape),
+            wrapped_pw,
+            Reshape(wrapped_layer.layer.output_shape[1:])
+        ])
+
+    def export(self) -> None:
+        """
+        Export a fully quantized model to its int8 tflite model.
+        """
+
+        def _substitute_model(wrapped_layer: qi.KerasQuantizationWrapper) -> keras.layers.Layer:
+            assert self.is_layer_exportable_fn(
+                wrapped_layer), f'Layer {wrapped_layer.get_config()} did not pass validation'
+
+            # In order to support dense quantization using per-channel quantization (which is
+            # unsupported in TFLITE int models) we substitute each dense layer to its equivalent
+            # point-wise convolution.
+            if isinstance(wrapped_layer.layer, Dense):
+                return self._get_pointwise_layer_to_replace_dense(wrapped_layer)
+
+            return wrapped_layer
+
+        # Transform the model to a new model that can be converted to int8 models.
+        # For example: replace dense layers with point-wise layers (to support per-channel quantization)
+        self.transformed_model = clone_model(self.model,
+                                             clone_function=_substitute_model)
+
+        # Convert model to int8 representation
+        converter = tf.lite.TFLiteConverter.from_keras_model(self.transformed_model)
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        self.exported_model = converter.convert()
+
+        Logger.info(f'Exporting INT8 tflite model to: {self.save_model_path}')
+        with open(self.save_model_path, 'wb') as f:
+            f.write(self.exported_model)
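The Dense-to-pointwise substitution above works because a 1x1 convolution is the same matmul applied at every spatial position. A minimal numpy sketch (not part of the diff; shapes follow the comment in _get_pointwise_layer_to_replace_dense, and all names here are illustrative) checking the equivalence:

import numpy as np

batch, x0, d_in, units = 2, 5, 3, 20
x = np.random.randn(batch, x0, d_in).astype(np.float32)
dense_kernel = np.random.randn(d_in, units).astype(np.float32)

# Dense: a matmul over the last axis.
dense_out = x @ dense_kernel

# Point-wise conv: reshape the input to rank 4 and the kernel to (1, 1, d_in, units).
x4 = x.reshape(batch, 1, x0, d_in)
pw_kernel = dense_kernel.reshape(1, 1, d_in, units)
conv_out = np.einsum('bhwc,ijcu->bhwu', x4, pw_kernel)  # a 1x1 conv, written as einsum

# Reshaping back recovers the Dense output exactly.
assert np.allclose(dense_out, conv_out.reshape(batch, x0, units), atol=1e-5)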
model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py
CHANGED
@@ -15,41 +15,59 @@
 from enum import Enum
 from typing import Callable
 
-import keras
-
 from model_compression_toolkit.core.common import Logger
-from model_compression_toolkit.
-    FakelyQuantTFLiteExporter
+from model_compression_toolkit.core.common.constants import FOUND_TF
 
 
 class TFLiteExportMode(Enum):
     FAKELY_QUANT = 0
+    INT8 = 1
+
+
+if FOUND_TF:
+    import keras
+    from model_compression_toolkit.exporter.model_exporter.tflite.fakely_quant_tflite_exporter import FakelyQuantTFLiteExporter
+    from model_compression_toolkit.exporter.model_exporter.tflite.int8_tflite_exporter import INT8TFLiteExporter
+    from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
 
+    def tflite_export_model(model: keras.models.Model,
+                            save_model_path: str,
+                            mode: TFLiteExportMode = TFLiteExportMode.FAKELY_QUANT,
+                            is_layer_exportable_fn: Callable = is_keras_layer_exportable
+                            ):
+        """
+        Export a Keras quantized model to a tflite model.
+        The model will be saved to the path in save_model_path.
+        Mode can be used for different exported files. Currently, tflite_export_model
+        supports TFLiteExportMode.FAKELY_QUANT (where weights and activations are
+        float fakely-quantized values), and TFLiteExportMode.INT8 (where weights
+        and activations are represented using 8bits integers).
 
-        Prepare and return fully quantized model for export. Save exported model to
-        a path if passed.
-
-        model: Model to export.
-        is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
-        mode: Mode to export the model according to.
-        save_model_path: Path to save the model.
+        Args:
+            model: Model to export.
+            is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
+            mode: Mode to export the model according to.
+            save_model_path: Path to save the model.
 
+        """
 
-        Logger.critical(
-            f'Unsupported mode was used {mode.name} to export TFLite model.'
-            f' Please see API for supported modes.')
+        if mode == TFLiteExportMode.FAKELY_QUANT:
+            exporter = FakelyQuantTFLiteExporter(model,
+                                                 is_layer_exportable_fn,
+                                                 save_model_path)
+        elif mode == TFLiteExportMode.INT8:
+            exporter = INT8TFLiteExporter(model,
+                                          is_layer_exportable_fn,
+                                          save_model_path)
+        else:
+            Logger.critical(
+                f'Unsupported mode was used {mode.name} to export TFLite model.'
+                f' Please see API for supported modes.')  # pragma: no cover
 
+        exporter.export()
+
+else:
+    def tflite_export_model(*args, **kwargs):
+        Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                     'when using tflite_export_model. '
+                     'Could not find some or all of TensorFlow packages.')  # pragma: no cover
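A usage sketch of the new INT8 path, based only on the signature above; exportable_model is assumed to be the output of get_exportable_keras_model (see the fully_quantized_model_builder.py hunk below):

from model_compression_toolkit.exporter.model_exporter.tflite.tflite_export_facade import \
    TFLiteExportMode, tflite_export_model

tflite_export_model(model=exportable_model,  # assumed: an exportable (wrapped) Keras model
                    save_model_path='./model_int8.tflite',
                    mode=TFLiteExportMode.INT8)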
model_compression_toolkit/exporter/model_wrapper/__init__.py
CHANGED
@@ -13,8 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 
-from model_compression_toolkit.
+from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
+from model_compression_toolkit.exporter.model_wrapper.keras.builder.fully_quantized_model_builder import get_exportable_keras_model
 
-
-
-from model_compression_toolkit.exporter.model_wrapper.keras.builder.fully_quantized_model_builder import get_exportable_keras_model
+from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
+from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
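With these re-exports in place, callers can presumably import the builders and validators from the subpackage root rather than the deep module paths, e.g.:

from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model, is_keras_layer_exportable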
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
CHANGED
@@ -12,157 +12,54 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from typing import Tuple
 
-import tensorflow as tf
-import tensorflow_model_optimization.quantization.keras.graph_transformations.model_transformer as mt
-from keras.layers import TFOpLambda
-from keras.models import Model
-from tensorflow.python.util.object_identity import Reference as TFReference
-from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_configs import \
-    NoOpQuantizeConfig
-from typing import List, Tuple, Dict, Any
-
-from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
 
+from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common import
+from model_compression_toolkit.core.common import Graph, Logger
+from model_compression_toolkit.core.common.constants import FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
-from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder, \
-    is_layer_fake_quant, get_node_name_from_layer
-from model_compression_toolkit.core.keras.quantizer.input_layer_quantize_transform import InputLayerWrapperTransform
-
-from model_compression_toolkit.exporter.model_wrapper.keras.builder.quantize_config_to_node import \
-    get_quantization_config
-from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.activation_quantize_config import \
-    ActivationQuantizeConfig
-from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.weights_activation_quantize_config \
-    import \
-    WeightsActivationQuantizeConfig
-from model_compression_toolkit.exporter.model_wrapper.keras.quantize_configs.weights_quantize_config import \
-    WeightsQuantizeConfig
-from model_compression_toolkit.exporter.model_wrapper.keras.extended_quantize_wrapper import ExtendedQuantizeWrapper
-from model_compression_toolkit.exporter.model_wrapper.keras.quantizers.fq_quantizer import FakeQuantQuantizer
-from model_compression_toolkit.exporter.model_wrapper.keras.quantizers.weights_uniform_quantizer import \
-    WeightsUniformQuantizer
-
-
-def get_exportable_keras_model(graph: Graph) -> tf.keras.models.Model:
-    """
-    Convert graph to an exportable Keras model (model with all quantization parameters).
-    An exportable model can then be exported using model_exporter, to retrieve the
-    final exported model.
-
-    Args:
-        graph: Graph to convert to an exportable Keras model.
 
+if FOUND_TF:
+    import tensorflow as tf
+    from tensorflow.keras.layers import Layer
+    from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
+    from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizers import get_quantization_quantizers
 
-class FullyQuantizedKerasModelBuilder(KerasModelBuilder):
-    """
-    Builder of exportable Keras models (fully quantized).
-    """
-
-    def __init__(self,
-                 graph: common.Graph):
+    def _get_wrapper(node: common.BaseNode,
+                     layer: Layer) -> qi.KerasQuantizationWrapper:
         """
+        A function which takes a computational graph node and a keras layer and perform the quantization wrapping
         Args:
+            n: A node of mct graph.
+            layer: A keras layer
 
+        Returns: Wrapped layer with weights quantizers and activation quantizers
 
-    def _quantize_node_activations(self,
-                                   node: BaseNode,
-                                   input_tensors: List[TFReference]) -> List[TFReference]:
         """
+        weights_quantizers, activation_quantizers = get_quantization_quantizers(node)
+        return qi.KerasQuantizationWrapper(layer, weights_quantizers, activation_quantizers)
 
-        Args:
-            node: Node to quantize its outputs.
-            input_tensors: Input tensors of the node.
-
-        Returns:
-            Output of the node.
-
-        """
-        return input_tensors
 
-    def
+    def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model, UserInformation]:
         """
-        model, user_info = super().build_model()
-
-        def _wrap_layer_with_quantize_config(layer):
-
-            node = self.oh.layer_to_node_dict.get(layer)
-
-            layer_output_shape = layer.output_shape if (node.reuse_group is None) else None
-            # For now, we do not support reused TFOpLambda layers.
-            if isinstance(layer, TFOpLambda) and layer_output_shape is None:
-                Logger.error(
-                    f'Output shape must be inferred to use ExtendedQuantizeWrapper, but it seems that TFOpLambda '
-                    f'layer {layer.name} has no output shape. If it is a reused layer, MCT does not support '
-                    f'reused TFOpLambda layers for now.')
-                return ExtendedQuantizeWrapper(layer, get_quantization_config(node), layer_output_shape)
-
-            elif is_layer_fake_quant(layer):
-                return layer
-
-            else:
-                raise Exception(
-                    f'Mismatch between keras model and graph cant find node named: '
-                    f'{get_node_name_from_layer(layer)}')
-
-        # clone each layer in the model and apply _wrap_layer_with_quantize_config to the layer.
-        model = tf.keras.models.clone_model(model,
-                                            input_tensors=None,
-                                            clone_function=_wrap_layer_with_quantize_config)
-
-        # We use a model transformer to wrap the input layer with QuantizeWrapper.
-        # A model transformer allows to modify a layer in an existing model, by applying the given list of
-        # transformers on the model (in this case,
-        # we only apply single transformer - InputLayerQuantizeTransform)
-        model_inputs = self.graph.get_inputs()
-
-        input_transformer = mt.ModelTransformer(model, [InputLayerWrapperTransform(inp,
-                                                                                   get_quantization_config(inp),
-                                                                                   self.get_custom_objects(),
-                                                                                   ExtendedQuantizeWrapper)
-                                                        for inp in model_inputs])
-
-        model = input_transformer.transform()[0]
-
-        return model, user_info
-
-    @staticmethod
-    def get_custom_objects() -> Dict[str, Any]:
-        """
-        Returns: Dictionary of custom objects needed to load this model builder's output.
+        Convert graph to an exportable Keras model (model with all quantization parameters).
+        An exportable model can then be exported using model_exporter, to retrieve the
+        final exported model.
 
+        Args:
+            graph: Graph to convert to an exportable Keras model.
+
+        Returns:
+            Exportable Keras model and user information.
         """
+        exportable_model, user_info = KerasModelBuilder(graph=graph,
+                                                        wrapper=_get_wrapper).build_model()
+        exportable_model.trainable = False
+        return exportable_model, user_info
+
+else:
+    def get_exportable_keras_model(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                     'when using get_exportable_keras_model. '
+                     'Could not find Tensorflow package.')
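The FOUND_TF guard introduced here (and in the facade above) follows one idiom: define the real function when TensorFlow is importable, otherwise bind a stub that only reports an error. A self-contained sketch of the pattern, with illustrative names rather than MCT's:

import importlib.util

# What a FOUND_TF-style constant boils down to: is the package importable?
FOUND_TF = importlib.util.find_spec('tensorflow') is not None

if FOUND_TF:
    def get_exportable_model(graph):
        # The real implementation would build and wrap the model here.
        raise NotImplementedError
else:
    def get_exportable_model(*args, **kwargs):
        # The stub keeps the symbol importable, so a torch-only install
        # fails with a clear message instead of an ImportError at import time.
        raise RuntimeError('TensorFlow is required for get_exportable_model.')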
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py
ADDED
@@ -0,0 +1,143 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Dict, Any
+
+from model_compression_toolkit.core.common import BaseNode, Logger
+from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
+from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import get_inferable_quantizer_class
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import constants as qi_keras_consts
+
+def get_inferable_quantizer_kwargs(node: BaseNode,
+                                   quantization_target: QuantizationTarget) -> Dict[str, Any]:
+    """
+    Get the quantization parameters for an inferable quantizer.
+    Args:
+        node: The node for which the quantizer is being created.
+        quantization_target: The target of the quantization (weights or activations).
+    Returns:
+        The quantization parameters as a dictionary.
+    """
+
+    if quantization_target == QuantizationTarget.Weights:
+        # Get the weights quantization configuration for the node
+        node_w_qc = node.final_weights_quantization_cfg
+        quantization_method = node_w_qc.weights_quantization_method
+
+        # Return the appropriate quantization parameters based on the quantization method
+        if quantization_method in [QuantizationMethod.POWER_OF_TWO,
+                                   QuantizationMethod.SYMMETRIC]:
+            return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                    qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[THRESHOLD].flatten()),
+                    qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                    qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                    qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[THRESHOLD].shape)}
+
+        elif quantization_method in [QuantizationMethod.UNIFORM]:
+            return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                    qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                    qi_keras_consts.MIN_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MIN].flatten()),
+                    qi_keras_consts.MAX_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MAX].flatten()),
+                    qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                    qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[RANGE_MIN].shape)}
+
+        elif quantization_method in [QuantizationMethod.LUT_SYM_QUANTIZER, QuantizationMethod.LUT_POT_QUANTIZER]:
+            return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
+                    qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                    qi_keras_consts.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS],
+                    qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten()),
+                    qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
+                    # TODO: how to pass multiplier nbits and eps for a specific node?
+                    qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)}
+
+        else:
+            Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
+
+    elif quantization_target == QuantizationTarget.Activation:
+        # Get the activation quantization configuration for the node
+        node_qc = node.final_activation_quantization_cfg
+        quantization_method = node_qc.activation_quantization_method
+
+        # Return the appropriate quantization parameters based on the quantization method
+        if quantization_method in [QuantizationMethod.POWER_OF_TWO,
+                                   QuantizationMethod.SYMMETRIC]:
+            return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
+                    # In activation quantization is per-tensor only - thus we hold the threshold as a list with a len of 1
+                    qi_keras_consts.THRESHOLD: [node_qc.activation_quantization_params[THRESHOLD]],
+                    qi_keras_consts.SIGNED: node_qc.activation_quantization_params[SIGNED]}
+
+        elif quantization_method in [QuantizationMethod.UNIFORM]:
+            return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
+                    # In activation quantization is per-tensor only - thus we hold the min/max as a list with a len of 1
+                    qi_keras_consts.MIN_RANGE: [node_qc.activation_quantization_params[RANGE_MIN]],
+                    qi_keras_consts.MAX_RANGE: [node_qc.activation_quantization_params[RANGE_MAX]]}
+
+        elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER]:
+            return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
+                    qi_keras_consts.SIGNED: node_qc.activation_quantization_params[SIGNED],
+                    qi_keras_consts.CLUSTER_CENTERS: node_qc.activation_quantization_params[CLUSTER_CENTERS],
+                    qi_keras_consts.THRESHOLD: [node_qc.activation_quantization_params[THRESHOLD]]
+                    # TODO: how to pass multiplier nbits and eps for a specific node?
+                    }
+        else:
+            Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
+    else:
+        Logger.critical(f'{quantization_target} is not supported')  # pragma: no cover
+
+
+def get_weights_quantizer_for_node(node: BaseNode) -> BaseKerasInferableQuantizer:
+    """
+    Get weights quantizer for a node.
+    Args:
+        node: Node to create a weight quantizer for.
+    Returns:
+        Quantizer for the node's weights.
+    """
+    if node.final_weights_quantization_cfg is None:
+        Logger.critical(f'Can not set quantizer for a node with no final weights quantization configuration')  # pragma: no cover
+    node_w_qc = node.final_weights_quantization_cfg
+    weights_quantization_method = node_w_qc.weights_quantization_method
+
+    quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights,
+                                                      weights_quantization_method,
+                                                      BaseKerasInferableQuantizer)
+    kwargs = get_inferable_quantizer_kwargs(node, QuantizationTarget.Weights)
+
+    return quantier_for_node(**kwargs)
+
+
+def get_activations_quantizer_for_node(node: BaseNode) -> BaseKerasInferableQuantizer:
+    """
+    Get activation quantizer for a node.
+    Args:
+        node: Node to create an activation quantizer for.
+    Returns:
+        Quantizer for the node's activations.
+    """
+    if node.final_activation_quantization_cfg is None:
+        Logger.critical(f'Can not set quantizer for a node with no final activation quantization configuration')  # pragma: no cover
+    node_act_qc = node.final_activation_quantization_cfg
+    activation_quantization_method = node_act_qc.activation_quantization_method
+
+    quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
+                                                      activation_quantization_method,
+                                                      BaseKerasInferableQuantizer)
+    kwargs = get_inferable_quantizer_kwargs(node, QuantizationTarget.Activation)
+
+    return quantier_for_node(**kwargs)
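For orientation, a sketch of the kwargs dictionary that the symmetric weights branch above would produce for a per-channel-quantized Conv2D kernel. This is an assumption for illustration only: the literal key names below are what the qi_keras_consts constants presumably resolve to, and the values are stand-ins:

import numpy as np

thresholds = np.full((1, 1, 1, 16), 2.0)  # stand-in per-output-channel thresholds
kwargs = {
    'num_bits': 8,                            # node_w_qc.weights_n_bits
    'threshold': list(thresholds.flatten()),  # flattened, as in the hunk above
    'per_channel': True,                      # node_w_qc.weights_per_channel_threshold
    'channel_axis': 3,                        # Conv2D kernel output-channel axis
    'input_rank': len(thresholds.shape),      # rank of the threshold tensor, here 4
}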
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py
ADDED
@@ -0,0 +1,46 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Dict, List, Tuple
+from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
+    get_weights_quantizer_for_node, get_activations_quantizer_for_node
+
+
+def get_quantization_quantizers(node: BaseNode) -> Tuple[Dict, List]:
+    """
+    Create quantizers to wrap a layer for its corresponding node.
+
+    Args:
+        node: Node to create quantizers for.
+
+    Returns:
+        weight_quantizers: A dictionary between a weight's name to its quantizer.
+        activation_quantizers: A list of activations quantization, one for each layer output.
+    """
+    weight_quantizers = {}
+    activation_quantizers = []
+
+    if node.is_weights_quantization_enabled():
+        weight_attrs = DEFAULT_KERAS_INFO.get_kernel_op_attributes(node.type)
+        weight_quantizer = get_weights_quantizer_for_node(node)
+        for attr in weight_attrs:
+            weight_quantizers[attr] = weight_quantizer
+
+    if node.is_activation_quantization_enabled():
+        num_of_outputs = len(node.output_shape) if isinstance(node.output_shape, list) else 1
+        activation_quantizers = [get_activations_quantizer_for_node(node)] * num_of_outputs
+
+    return weight_quantizers, activation_quantizers
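How the two return values are consumed is shown by _get_wrapper in the fully_quantized_model_builder.py hunk above; repeated here as a two-line sketch (node, layer and the qi alias are assumed to come from a built graph/model pair):

weight_quantizers, activation_quantizers = get_quantization_quantizers(node)
wrapped = qi.KerasQuantizationWrapper(layer, weight_quantizers, activation_quantizers)

Note that [quantizer] * num_of_outputs reuses a single quantizer instance across all outputs, which is presumably fine for inferable quantizers since they hold fixed, pre-computed parameters.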