mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import List
|
|
16
|
+
from model_compression_toolkit.core.common import BaseNode, Logger
|
|
17
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
|
|
18
|
+
TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig, TrainableQuantizerCandidateConfig
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_trainable_quantizer_weights_config(
|
|
22
|
+
n: BaseNode,
|
|
23
|
+
weights_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None
|
|
24
|
+
) -> TrainableQuantizerWeightsConfig:
|
|
25
|
+
"""
|
|
26
|
+
Returns the relevant configuration for weights trainable quantizer
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
n: BaseNode - the node to build a trainable quantizer from.
|
|
30
|
+
weights_quantization_candidates: A list of weights quantizer config candidates.
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
TrainableQuantizerWeightsConfig: an object that contains the quantizer configuration
|
|
34
|
+
"""
|
|
35
|
+
if n.final_weights_quantization_cfg is None:
|
|
36
|
+
Logger.error(f'Node must have final_weights_quantization_cfg in order to build quantizer configuration') # pragma: no cover
|
|
37
|
+
|
|
38
|
+
final_cfg = n.final_weights_quantization_cfg
|
|
39
|
+
return TrainableQuantizerWeightsConfig(final_cfg.weights_quantization_method,
|
|
40
|
+
final_cfg.weights_n_bits,
|
|
41
|
+
final_cfg.weights_quantization_params,
|
|
42
|
+
final_cfg.enable_weights_quantization,
|
|
43
|
+
final_cfg.weights_channels_axis,
|
|
44
|
+
final_cfg.weights_per_channel_threshold,
|
|
45
|
+
final_cfg.min_threshold,
|
|
46
|
+
weights_quantization_candidates)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def get_trainable_quantizer_activation_config(
|
|
50
|
+
n: BaseNode,
|
|
51
|
+
activation_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None
|
|
52
|
+
) -> TrainableQuantizerActivationConfig:
|
|
53
|
+
"""
|
|
54
|
+
Returns configuration for activation trainable quantizer
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
n: BaseNode - the node to build a trainable quantizer from.
|
|
58
|
+
activation_quantization_candidates: A list of activation quantizer candidates config.
|
|
59
|
+
|
|
60
|
+
Returns:
|
|
61
|
+
TrainableQuantizerActivationConfig - an object that contains the quantizer configuration
|
|
62
|
+
"""
|
|
63
|
+
if n.final_activation_quantization_cfg is None:
|
|
64
|
+
Logger.error(f'Node must have final_activation_quantization_cfg in order to build quantizer configuration') # pragma: no cover
|
|
65
|
+
|
|
66
|
+
final_cfg = n.final_activation_quantization_cfg
|
|
67
|
+
return TrainableQuantizerActivationConfig(final_cfg.activation_quantization_method,
|
|
68
|
+
final_cfg.activation_n_bits,
|
|
69
|
+
final_cfg.activation_quantization_params,
|
|
70
|
+
final_cfg.enable_activation_quantization,
|
|
71
|
+
final_cfg.min_threshold,
|
|
72
|
+
activation_quantization_candidates)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def get_trainable_quantizer_quantization_candidates(n: BaseNode):
|
|
76
|
+
"""
|
|
77
|
+
Returns quantization configuration candidates for activation and weights trainable quantizer.
|
|
78
|
+
Checks that the candidates are compatible with trainable quantizer
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
n: BaseNode - the node to build a trainable quantizer from
|
|
82
|
+
|
|
83
|
+
Returns:
|
|
84
|
+
weights_quantization_candidates - A list of configuration candidates for weights
|
|
85
|
+
activation_quantization_candidates - A list of configuration candidates for activation
|
|
86
|
+
"""
|
|
87
|
+
# all candidates must have the same weights quantization method
|
|
88
|
+
weights_quantization_methods = set([cfg.weights_quantization_cfg.weights_quantization_method for cfg in n.candidates_quantization_cfg])
|
|
89
|
+
if len(weights_quantization_methods) > 1:
|
|
90
|
+
Logger.error(f'Unsupported candidates_quantization_cfg with different weights quantization methods: {weights_quantization_methods}') # pragma: no cover
|
|
91
|
+
|
|
92
|
+
# all candidates must have the same activation quantization method
|
|
93
|
+
activation_quantization_methods = set([cfg.activation_quantization_cfg.activation_quantization_method for cfg in n.candidates_quantization_cfg])
|
|
94
|
+
if len(activation_quantization_methods) > 1:
|
|
95
|
+
Logger.error(f'Unsupported candidates_quantization_cfg with different activation quantization methods: {activation_quantization_methods}') # pragma: no cover
|
|
96
|
+
|
|
97
|
+
# get unique lists of candidates
|
|
98
|
+
unique_weights_candidates = n.get_unique_weights_candidates()
|
|
99
|
+
unique_activation_candidates = n.get_unique_activation_candidates()
|
|
100
|
+
|
|
101
|
+
# verify all the combinations of weights_n_bits and activation_n_bits are allowed
|
|
102
|
+
if len(n.candidates_quantization_cfg) != len(unique_weights_candidates) * len(unique_activation_candidates):
|
|
103
|
+
Logger.error(f'Unsupported candidates_quantization_cfg for a trainable quantizer,'
|
|
104
|
+
f'it must contain all the combinations of (weights_n_bits X activations_n_bits)') # pragma: no cover
|
|
105
|
+
|
|
106
|
+
# generate list of weights quantizer candidates
|
|
107
|
+
weights_cfg_candidates = [TrainableQuantizerCandidateConfig(
|
|
108
|
+
cfg.weights_quantization_cfg.weights_n_bits,
|
|
109
|
+
cfg.weights_quantization_cfg.weights_quantization_params) for cfg in unique_weights_candidates]
|
|
110
|
+
|
|
111
|
+
# generate list of activation quantizer candidates
|
|
112
|
+
activation_cfg_candidates = [TrainableQuantizerCandidateConfig(
|
|
113
|
+
cfg.activation_quantization_cfg.activation_n_bits,
|
|
114
|
+
cfg.activation_quantization_cfg.activation_quantization_params) for cfg in unique_activation_candidates]
|
|
115
|
+
|
|
116
|
+
return weights_cfg_candidates, activation_cfg_candidates
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import Union
|
|
16
|
+
|
|
17
|
+
from model_compression_toolkit.gptq import RoundingType
|
|
18
|
+
from model_compression_toolkit import TrainingMethod
|
|
19
|
+
from model_compression_toolkit.core.common import Logger
|
|
20
|
+
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
21
|
+
from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
|
|
22
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
|
|
23
|
+
import QUANTIZATION_TARGET, QUANTIZATION_METHOD, QUANTIZER_TYPE
|
|
24
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_all_subclasses \
|
|
25
|
+
import get_all_subclasses
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_trainable_quantizer_class(quant_target: QuantizationTarget,
                                  quantizer_type: Union[TrainingMethod, RoundingType],
                                  quant_method: QuantizationMethod,
                                  quantizer_base_class: type) -> type:
    """
    Searches for a trainable quantizer class that matches the requested QuantizationTarget and
    QuantizationMethod and a task-dedicated quantizer type. Exactly one class should be found.

    Args:
        quant_target: QuantizationTarget value which indicates what is the target for quantization to
            use the quantizer for.
        quantizer_type: The type of the quantizer (quantization technique).
            This can differ, depending on the purpose the quantizer is for.
        quant_method: A single QuantizationMethod value that the quantizer must support
            (a quantizer class may declare support for several methods).
        quantizer_base_class: A type of quantizer that the requested quantizer should inherit from.

    Returns: A class of a trainable quantizer that inherits from quantizer_base_class.

    """
    qat_quantizer_classes = get_all_subclasses(quantizer_base_class)
    if len(qat_quantizer_classes) == 0:
        Logger.error(f"No quantizers were found that inherit from {quantizer_base_class}.")  # pragma: no cover

    def _matches(q_class: type) -> bool:
        # A class matches only if it explicitly declares all three attributes and
        # they agree with the requested target, method and quantizer type.
        if getattr(q_class, QUANTIZATION_TARGET, None) is None or \
                getattr(q_class, QUANTIZATION_TARGET) != quant_target:
            return False
        supported_methods = getattr(q_class, QUANTIZATION_METHOD, None)
        if supported_methods is None or quant_method not in supported_methods:
            return False
        return getattr(q_class, QUANTIZER_TYPE, None) == quantizer_type

    filtered_quantizers = list(filter(_matches, qat_quantizer_classes))

    if len(filtered_quantizers) != 1:
        Logger.error(f"Found {len(filtered_quantizers)} quantizer for target {quant_target.value} "  # pragma: no cover
                     f"that matches the requested quantization method {quant_method.name} and "
                     f"quantizer type {quantizer_type.value} but there should be exactly one."
                     f"The possible quantizers that were found are {filtered_quantizers}.")

    return filtered_quantizers[0]
|
model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import Tuple, List
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_threshold_reshape_shape(tensor_shape: Tuple, quant_axis: int, quant_axis_dim: int) -> List[int]:
    """
    Gets a shape that contains 1 in all axes except the quantization axis, to adjust the
    threshold tensor for per-channel quantization.

    Args:
        tensor_shape: The shape of the tensor to be quantized.
        quant_axis: The axis along which the quantization happens (usually the tensor's channel axis).
            A negative value is interpreted as counting from the last axis (numpy-style indexing).
        quant_axis_dim: The dimension of the quantization axis.

    Returns: A shape to reshape the threshold tensor according to.

    """
    n_axis = len(tensor_shape)
    # Normalize a negative axis index to its non-negative equivalent.
    quantization_axis = n_axis + quant_axis if quant_axis < 0 else quant_axis

    return [quant_axis_dim if i == quantization_axis else 1 for i in range(n_axis)]
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from abc import ABC
|
|
16
|
+
from typing import Dict, List
|
|
17
|
+
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TrainableQuantizerCandidateConfig:

    def __init__(self,
                 n_bits: int,
                 quantization_params: Dict,
                 ):
        """
        Holds a single candidate quantization configuration for a trainable quantizer.
        The same structure is used for both weights and activation candidates.

        Args:
            n_bits (int): Number of bits to use for quantization.
            quantization_params (Dict): Dictionary that contains quantization params.
        """
        # Plain value-holder: stores the candidate bit-width and its parameters as-is.
        self.quantization_params = quantization_params
        self.n_bits = n_bits
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class TrainableQuantizerActivationConfig:

    def __init__(self,
                 activation_quantization_method: QuantizationMethod,
                 activation_n_bits: int,
                 activation_quantization_params: Dict,
                 enable_activation_quantization: bool,
                 min_threshold: float,
                 activation_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None,
                 ):
        """
        Attributes for configuring activations trainable quantizer.

        Args:
            activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization.
            activation_n_bits (int): Number of bits to quantize the activations.
            activation_quantization_params (Dict): Dictionary that contains activation quantization params.
            enable_activation_quantization (bool): Whether to quantize the layer's activations or not.
            min_threshold (float): Minimum threshold to use during thresholds selection.
            activation_quantization_candidates (List[TrainableQuantizerCandidateConfig]): Optional list of
                candidate quantization configurations for the activations. Defaults to None.
        """
        self.activation_quantization_method = activation_quantization_method
        self.activation_n_bits = activation_n_bits
        self.activation_quantization_params = activation_quantization_params
        self.enable_activation_quantization = enable_activation_quantization
        self.min_threshold = min_threshold
        # NOTE: stored under a different attribute name than the constructor parameter.
        self.activation_bits_candidates = activation_quantization_candidates
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class TrainableQuantizerWeightsConfig:
    def __init__(self,
                 weights_quantization_method: QuantizationMethod,
                 weights_n_bits: int,
                 weights_quantization_params: Dict,
                 enable_weights_quantization: bool,
                 weights_channels_axis: int,
                 weights_per_channel_threshold: bool,
                 min_threshold: float,
                 weights_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None,
                 ):
        """
        Attributes for configuring weights trainable quantizer.

        Args:
            weights_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for weights quantization.
            weights_n_bits (int): Number of bits to quantize the coefficients.
            weights_quantization_params (Dict): Dictionary that contains weights quantization params.
            enable_weights_quantization (bool): Whether to quantize the layer's weights or not.
            weights_channels_axis (int): Axis to quantize a node's kernel when quantizing per-channel.
            weights_per_channel_threshold (bool): Whether to quantize the weights per-channel or not (per-tensor).
            min_threshold (float): Minimum threshold to use during thresholds selection.
            weights_quantization_candidates (List[TrainableQuantizerCandidateConfig]): Optional list of
                candidate quantization configurations for the weights. Defaults to None.
        """
        self.weights_quantization_method = weights_quantization_method
        self.weights_n_bits = weights_n_bits
        self.weights_quantization_params = weights_quantization_params
        self.enable_weights_quantization = enable_weights_quantization
        self.weights_channels_axis = weights_channels_axis
        self.weights_per_channel_threshold = weights_per_channel_threshold
        self.min_threshold = min_threshold
        # NOTE: stored under a different attribute name than the constructor parameter.
        self.weights_bits_candidates = weights_quantization_candidates
|
model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import Dict, Any, Union, List
|
|
16
|
+
|
|
17
|
+
from model_compression_toolkit.core.common import Logger
|
|
18
|
+
from model_compression_toolkit.core.common.constants import FOUND_TF
|
|
19
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
20
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
|
|
21
|
+
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
|
|
22
|
+
TrainableQuantizerActivationConfig
|
|
23
|
+
|
|
24
|
+
if FOUND_TF:
    # Key under which the serialized quantizer config is stored in the Keras layer config.
    QUANTIZATION_CONFIG = 'quantization_config'
    from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.keras.config_serialization import config_serialization, \
        config_deserialization
    import tensorflow as tf

    class BaseKerasTrainableQuantizer(BaseTrainableQuantizer):
        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            """
            This class is a base quantizer which validates provided quantization config and defines an abstract function which any quantizer needs to implement.
            This class adds to the base quantizer a get_config and from_config functions to enable loading and saving the keras model.

            Args:
                quantization_config: quantizer config class contains all the information about a quantizer configuration.
            """
            super().__init__(quantization_config)

        def get_config(self) -> Dict[str, Any]:
            """
            Returns: Configuration of BaseKerasQuantizer as a serializable dictionary.

            """
            return {QUANTIZATION_CONFIG: config_serialization(self.quantization_config)}

        @classmethod
        def from_config(cls, config: dict):
            """
            Builds a quantizer from its serialized Keras configuration.

            Args:
                config(dict): dictionary of BaseKerasQuantizer configuration.

            Returns: A BaseKerasQuantizer

            """
            config = config.copy()
            quantization_config = config_deserialization(config[QUANTIZATION_CONFIG])
            # Note: a quantizer receives only the quantization config; any other settings
            # are hardcoded inside the specific quantizer implementation.
            return cls(quantization_config=quantization_config)

        def get_trainable_variables(self, group: VariableGroup) -> List[tf.Tensor]:
            """
            Get trainable parameters with specific group from quantizer

            Args:
                group: Enum of variable group

            Returns:
                List of trainable variables
            """
            quantizer_trainable = []
            for name, parameter_dict in self.quantizer_parameters.items():
                quantizer_parameter, parameter_group = parameter_dict[VAR], parameter_dict[GROUP]
                # Keep only parameters that are trainable AND belong to the requested group.
                if quantizer_parameter.trainable and parameter_group == group:
                    quantizer_trainable.append(quantizer_parameter)
            return quantizer_trainable


else:
    # Stub definition used when TensorFlow is not installed: instantiating it fails loudly.
    class BaseKerasTrainableQuantizer(BaseTrainableQuantizer):
        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):

            super().__init__(quantization_config)
            Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                            'when using BaseKerasQuantizer. '
                            'Could not find Tensorflow package.')  # pragma: no cover
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import copy
|
|
16
|
+
|
|
17
|
+
from typing import Any, Union
|
|
18
|
+
from enum import Enum
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
21
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
|
|
22
|
+
TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig
|
|
23
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common import constants as C
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def transform_enum(v: Any):
    """
    Returns the value of an Enum member; any other input is returned unchanged.

    Args:
        v: Any type

    Returns: The enum's underlying value if v is an Enum member, otherwise v itself.

    """
    if isinstance(v, Enum):
        return v.value
    return v
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def config_serialization(quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
    """
    Converts a trainable quantizer config object into a plain dictionary.
    Enum attribute values are replaced by their underlying values, and the result
    is tagged with flags marking whether it is a weights or an activations config.

    Args:
        quantization_config: A TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig for serialization

    Returns: A config dictionary of quantizer config

    """
    serialized = {attr_name: transform_enum(attr_value)
                  for attr_name, attr_value in vars(quantization_config).items()}
    serialized[C.IS_WEIGHTS] = isinstance(quantization_config, TrainableQuantizerWeightsConfig)
    serialized[C.IS_ACTIVATIONS] = isinstance(quantization_config, TrainableQuantizerActivationConfig)
    return serialized
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def config_deserialization(in_config: dict) -> Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]:
    """
    Converts a config dictionary back into a trainable quantizer config object.

    Args:
        in_config: A config dictionary of trainable quantizer config.

    Returns: Trainable quantizer configuration object - TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig

    Raises:
        NotImplementedError: If the dictionary is tagged as neither a weights nor an activations config.

    """
    # Deep-copy so the caller's dictionary is never mutated.
    in_config = copy.deepcopy(in_config)
    if in_config[C.IS_WEIGHTS]:
        return TrainableQuantizerWeightsConfig(weights_quantization_method=QuantizationMethod(in_config[C.WEIGHTS_QUANTIZATION_METHOD]),
                                               weights_n_bits=in_config[C.WEIGHTS_N_BITS],
                                               weights_quantization_params=in_config[C.WEIGHTS_QUANTIZATION_PARAMS],
                                               enable_weights_quantization=in_config[C.ENABLE_WEIGHTS_QUANTIZATION],
                                               weights_channels_axis=in_config[C.WEIGHTS_CHANNELS_AXIS],
                                               weights_per_channel_threshold=in_config[C.WEIGHTS_PER_CHANNEL_THRESHOLD],
                                               min_threshold=in_config[C.MIN_THRESHOLD])
    elif in_config[C.IS_ACTIVATIONS]:
        return TrainableQuantizerActivationConfig(activation_quantization_method=QuantizationMethod(in_config[C.ACTIVATION_QUANTIZATION_METHOD]),
                                                  activation_n_bits=in_config[C.ACTIVATION_N_BITS],
                                                  activation_quantization_params=in_config[C.ACTIVATION_QUANTIZATION_PARAMS],
                                                  enable_activation_quantization=in_config[C.ENABLE_ACTIVATION_QUANTIZATION],
                                                  min_threshold=in_config[C.MIN_THRESHOLD])
    else:
        # Fix: the original `raise NotImplemented` raises a TypeError at runtime, because
        # NotImplemented is a sentinel value, not an exception class.
        raise NotImplementedError  # pragma: no cover
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import tensorflow as tf
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def int_quantization_with_threshold(data: tf.Tensor,
                                    n_bits: int,
                                    signed: bool,
                                    threshold: np.ndarray,
                                    eps: float) -> tf.Tensor:
    """
    Divides data by threshold and maps it into the quantization range (which depends
    on whether the quantization is signed).

    NOTE(review): no rounding is applied here — presumably the caller performs the
    rounding step (e.g. a straight-through round); confirm before relying on
    integer-valued outputs.

    Args:
        data: tensor data.
        n_bits: number of bits that determines the quantization range.
        signed: Whether the quantization is signed or not.
        threshold: threshold for quantization.
        eps: Small value for numerical stability in division.

    Returns:
        Uniform Quantized tensor.

    """
    if signed:
        half_range = 2 ** (n_bits - 1)
        clip_min, clip_max = -half_range, half_range - 1
    else:
        clip_min, clip_max = 0, 2 ** n_bits - 1

    # Scale into the integer range; eps guards against division by a zero threshold.
    scaled = (data / (threshold + eps)) * (2 ** (n_bits - int(signed)))
    return tf.clip_by_value(scaled, clip_value_max=clip_max, clip_value_min=clip_min)
|
model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import Union, List
|
|
16
|
+
|
|
17
|
+
from model_compression_toolkit.core.common.logger import Logger
|
|
18
|
+
from model_compression_toolkit.core.common.constants import FOUND_TORCH
|
|
19
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
20
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
|
|
21
|
+
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
|
|
22
|
+
TrainableQuantizerActivationConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
if FOUND_TORCH:

    import torch

    class BasePytorchTrainableQuantizer(BaseTrainableQuantizer):
        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            """
            Base Pytorch trainable quantizer: validates the provided quantization config
            and declares the abstract interface every concrete quantizer implements.

            Args:
                quantization_config: quantizer config class contains all the information about the quantizer configuration.
            """
            super().__init__(quantization_config)

        def get_trainable_variables(self, group: VariableGroup) -> List[torch.Tensor]:
            """
            Get trainable parameters with specific group from quantizer

            Args:
                group: Enum of variable group

            Returns:
                List of trainable variables
            """
            # Keep only parameters that require gradients AND belong to the requested group.
            return [param_dict[VAR]
                    for param_dict in self.quantizer_parameters.values()
                    if param_dict[VAR].requires_grad and param_dict[GROUP] == group]

else:
    # Stub definition used when torch is not installed: instantiating it fails loudly.
    class BasePytorchTrainableQuantizer(BaseTrainableQuantizer):
        def __init__(self,
                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
            super().__init__(quantization_config)
            Logger.critical('Installing Pytorch is mandatory '
                            'when using BasePytorchTrainableQuantizer. '
                            'Could not find torch package.')  # pragma: no cover
|