mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
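The summary above covers renamed and modified files; the hunks below reproduce the seven files deleted outright in this release, all of them part of the old Gumbel-rounding GPTQ implementation that 1.8.0 replaces with the new soft-rounding quantizers and the quantizers_infrastructure package. The extracted hunks carry no file headers, so the per-hunk paths below are inferred: each deletion matches exactly one "+0 -N" entry in the summary by line count and content (the 247-line hunk, for example, can only be gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py).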
--- a/model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py
+++ /dev/null
@@ -1,247 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import tensorflow as tf
-import numpy as np
-
-from model_compression_toolkit import GumbelConfig
-from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
-from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.base_gumbel_rounding import GumbelRoundingBase, \
-    init_aux_var
-from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
-from tensorflow.python.framework.tensor_shape import TensorShape
-from model_compression_toolkit.core.common.defaultdict import DefaultDict
-from typing import Dict, Any, List
-from model_compression_toolkit.gptq.keras.quantizer.gumbel_rounding.gumbel_softmax import gumbel_softmax, ste_gumbel
-from model_compression_toolkit.core.common.constants import RANGE_MIN, RANGE_MAX
-from model_compression_toolkit.gptq.common import gptq_constants
-from model_compression_toolkit.gptq.keras.quantizer.ste_rounding.uniform_ste import rounding_uniform_quantizer
-
-
-def gumbel_rounding_uniform_quantizer(tensor_data: tf.Tensor,
-                                      auxvar_tensor: tf.Variable,
-                                      range_min: tf.Tensor,
-                                      range_max: tf.Tensor,
-                                      n_bits: int) -> tf.Tensor:
-    """
-    Quantize a tensor according to given range (min, max) and number of bits.
-
-    Args:
-        tensor_data: Tensor values to quantize.
-        auxvar_tensor: Tensor that manifests the bit shift the weight due to gptq.
-        range_min: minimum bound of the range for quantization (or array of min values per channel).
-        range_max: maximum bound of the range for quantization (or array of max values per channel).
-        n_bits: Number of bits to quantize the tensor.
-
-    Returns:
-        Quantized data.
-    """
-
-    # adjusts the quantization rage so the quantization grid include zero.
-    a, b = qutils.fix_range_to_include_zero(range_min, range_max, n_bits)
-
-    # Compute the step size of quantized values.
-    delta = (b - a) / (2 ** n_bits - 1)
-
-    input_tensor_int = tf.stop_gradient(tf.floor((tensor_data - a) / delta))  # Apply rounding
-    tensor_q = input_tensor_int + auxvar_tensor
-
-    # Clip data in range
-    clipped_tensor = qutils.ste_clip(tensor_q, min_val=0, max_val=2 ** n_bits - 1)
-
-    # Quantize the data between min/max of quantization range.
-    q = delta * clipped_tensor + a
-    return q
-
-
-class UniformGumbelRounding(GumbelRoundingBase):
-    """
-    Trainable constrained quantizer to quantize a layer inputs.
-    """
-    PTQ_MIN_RANGE = "_min_range"
-    PTQ_MAX_RANGE = "_max_range"
-
-    def __init__(self, num_bits: int, per_axis: bool, signed: bool, quantization_parameter_learning: bool,
-                 min_range: np.ndarray, max_range: np.ndarray, gumbel_config: GumbelConfig,
-                 quantization_axis: int = -1, max_lsbs_change_map: dict = DefaultDict({}, lambda: 1),
-                 max_iteration: int = 10000):
-        """
-        Initialize a TrainableWeightQuantizer object with parameters to use
-        for the quantization.
-
-        Args:
-            num_bits: Number of bits to use for the quantization.
-            per_axis: Whether to quantize per-channel or per-tensor.
-            signed: Signedness to use for the quantization range.
-            quantization_parameter_learning: Threshold to use for the quantization.
-            min_range: a numpy array of the min range.
-            max_range: a numpy array of the max range.
-            gumbel_config: A class with the gumbel rounding configurations.
-            quantization_axis: Axis of tensor to use for the quantization.
-            max_lsbs_change_map: a mapping between number of bits to max lsb change.
-            max_iteration: The number of iteration of gptq.
-        """
-        super().__init__(num_bits, per_axis, signed, False, False, quantization_parameter_learning,
-                         quantization_axis, gumbel_config,
-                         max_lsbs_change_map,
-                         max_iteration)
-        self.threshold_shape = np.asarray(min_range).shape
-        self.min_range = np.reshape(np.asarray(min_range), [-1]) if self.per_axis else float(
-            min_range)
-        self.max_range = np.reshape(np.asarray(max_range), [-1]) if self.per_axis else float(
-            max_range)
-        self.k_threshold = len(self.max_range) if self.per_axis else 1
-
-    def build(self,
-              tensor_shape: TensorShape,
-              name: str,
-              layer: QuantizeWrapper) -> Dict[str, tf.Variable]:
-        """
-        Add min and max variables to layer.
-        Args:
-            tensor_shape: Tensor shape the quantizer quantize.
-            name: Prefix of variables names.
-            layer: Layer to add the variables to. The variables are saved
-                in the layer's scope.
-
-        Returns:
-            Dictionary of new variables.
-        """
-        super().build(tensor_shape, name, layer)
-
-        if self.per_axis:
-            input_shape = tensor_shape
-            n_axis = len(input_shape)
-            quantization_axis = n_axis + self.quantization_axis if self.quantization_axis < 0 else \
-                self.quantization_axis
-            reshape_shape = [self.k_threshold if i == quantization_axis else 1 for i in range(n_axis)]
-        else:
-            reshape_shape = [self.k_threshold]
-
-        max_range = layer.add_weight(
-            name + self.PTQ_MAX_RANGE,
-            shape=reshape_shape,
-            initializer=tf.keras.initializers.Constant(1.0),
-            trainable=self.quantization_parameter_learning)
-        max_range.assign(self.max_range.reshape(reshape_shape))
-
-        min_range = layer.add_weight(
-            name + self.PTQ_MIN_RANGE,
-            shape=reshape_shape,
-            initializer=tf.keras.initializers.Constant(1.0),
-            trainable=self.quantization_parameter_learning)
-        min_range.assign(self.min_range.reshape(reshape_shape))
-
-        auxvar_tensor = layer.add_weight(
-            name + gptq_constants.AUXVAR,
-            shape=[self.m, *self.w_shape],
-            initializer=tf.keras.initializers.Constant(0.0),
-            trainable=True)
-        w = getattr(layer.layer, name)
-
-        q_error = w - rounding_uniform_quantizer(w, min_range, max_range,
-                                                 n_bits=self.num_bits)
-        ceil_indicator = (q_error < 0).numpy().astype("int")  # Negative error means the choose point is rounded to ceil.
-        auxvar_tensor.assign(init_aux_var(ceil_indicator, self.w_shape, self.m))
-
-        self.quantizer_parameters.update({gptq_constants.AUXVAR: auxvar_tensor,
-                                          self.PTQ_MAX_RANGE: max_range,
-                                          self.PTQ_MIN_RANGE: min_range})
-        return self.quantizer_parameters
-
-    def __call__(self, inputs: tf.Tensor,
-                 training: bool,
-                 weights: Dict[str, tf.Variable],
-                 **kwargs: Dict[str, Any]):
-        """
-        Quantize a tensor.
-        Args:
-            inputs: Input tensor to quantize.
-            training: Whether the graph is in training mode.
-            weights: Dictionary of weights the quantizer can use to quantize the tensor.
-            **kwargs: Additional variables the quantizer may receive.
-
-        Returns:
-            The quantized tensor.
-        """
-
-        auxvar = weights[gptq_constants.AUXVAR]
-        ar_iter = weights[gptq_constants.GPTQ_ITER]
-        ptq_min_range = weights[self.PTQ_MIN_RANGE]
-        ptq_max_range = weights[self.PTQ_MAX_RANGE]
-        aux_index_shift = weights[gptq_constants.AUXSHIFT]
-        self.update_iteration(training, ar_iter)
-        if self.per_axis:
-            input_shape = inputs.shape
-            n_axis = len(input_shape)
-            quantization_axis = n_axis + self.quantization_axis if self.quantization_axis < 0 else \
-                self.quantization_axis
-            reshape_shape = [-1 if i == quantization_axis else 1 for i in range(n_axis)]
-
-            reshape_shape_aux_ind = [-1, *[1 for _ in range(n_axis)]]
-            #####################################################
-            # Gumbel Softmax
-            #####################################################
-            if training:
-                p_t = gumbel_softmax(auxvar, self.tau, self.g_t)
-            else:
-                p_t = gumbel_softmax(auxvar, self.minimal_temp, 0)
-                p_t = ste_gumbel(p_t)
-            self.p_t = p_t
-            #####################################################
-            # Calculate v hat and threshold hat
-            #####################################################
-            ptq_min_range = tf.reshape(ptq_min_range, reshape_shape)
-            ptq_max_range = tf.reshape(ptq_max_range, reshape_shape)
-
-            auxvar_hat = tf.reduce_sum(p_t * tf.reshape(aux_index_shift, reshape_shape_aux_ind), axis=0)
-            #####################################################
-            # Quantized Input
-            #####################################################
-            q_tensor = gumbel_rounding_uniform_quantizer(inputs, auxvar_hat,
-                                                         ptq_min_range,
-                                                         ptq_max_range,
-                                                         self.num_bits)
-            return q_tensor
-        else:
-            raise NotImplemented
-            return gumbel_rounding_uniform_quantizer(inputs, auxvar_hat,
-                                                     ptq_max_range,
-                                                     ptq_min_range,
-                                                     self.num_bits)
-
-    def get_quant_config(self, layer) -> Dict[str, np.ndarray]:
-        """
-        Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
-
-        Args:
-            layer: quantized layer
-
-        Returns:
-            A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
-            Keys must match NodeQuantizationConfig attributes
-
-        """
-        min_range = self.quantizer_parameters[self.PTQ_MIN_RANGE]
-        max_range = self.quantizer_parameters[self.PTQ_MAX_RANGE]
-        return {RANGE_MIN: min_range.numpy().reshape(self.threshold_shape),
-                RANGE_MAX: max_range.numpy().reshape(self.threshold_shape)}
-
-    def get_quantization_variable(self) -> List[tf.Tensor]:
-        """
-        This function return a list of quantizer parameters.
-        Returns: A list of the quantizer parameters
-
-        """
-        return [self.quantizer_parameters[self.PTQ_MIN_RANGE], self.quantizer_parameters[self.PTQ_MAX_RANGE]]
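For reference, the heart of the removed scheme is a soft selection among candidate rounding shifts: gumbel_softmax converts the auxiliary logits into per-shift probabilities, and the expected shift auxvar_hat is added to the floored integer tensor before clipping. A minimal NumPy sketch of that selection step (illustrative only, not MCT code; the three-shift setup, function name, and shapes are assumptions for the example):

import numpy as np

def gumbel_softmax_probs(logits, tau, rng):
    # Gumbel noise turns a hard argmax over the candidate-shift axis (axis 0)
    # into a differentiable soft selection; tau controls how peaked it is.
    u = rng.uniform(low=1e-9, high=1.0, size=logits.shape)
    g = -np.log(-np.log(u))
    z = (logits + g) / tau
    e = np.exp(z - z.max(axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.normal(size=(3, 4))                   # 3 candidate shifts for each of 4 weights
shifts = np.array([-1.0, 0.0, 1.0]).reshape(3, 1)  # illustrative shift values
p = gumbel_softmax_probs(logits, tau=0.5, rng=rng)
auxvar_hat = (p * shifts).sum(axis=0)              # expected rounding shift per weight
print(auxvar_hat)

As tau decreases toward the minimal temperature, p approaches a one-hot choice, which is what the ste_gumbel straight-through pass enforces at inference in the removed quantizer.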
--- a/model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import tensorflow as tf
-from model_compression_toolkit.core.keras.constants import KERNEL
-
-
-def get_kernel(weights_list: list) -> tf.Tensor:
-    """
-    This function a list of weights and return the kernel
-    Args:
-        weights_list: A list of Tensors
-
-    Returns: The kernel tensor.
-
-    """
-    for w in weights_list:
-        if KERNEL in w.name:
-            return w
-    raise Exception("Can't find kernel variable")
-
-
-def threshold_reshape(threshold_tensor: tf.Tensor, input_w: tf.Tensor, in_quantization_axis: int) -> tf.Tensor:
-    """
-    This function take a threshold tensor and re-aline it to the weight tensor channel axis.
-    Args:
-        threshold_tensor: A tensor of threshold
-        input_w: A weight tensor
-        in_quantization_axis: A int value that represent the channel axis.
-
-    Returns: A reshape tensor of threshold.
-
-    """
-    input_shape = input_w.shape
-    n_axis = len(input_shape)
-    quantization_axis = n_axis + in_quantization_axis if in_quantization_axis < 0 else in_quantization_axis
-    reshape_shape = [-1 if i == quantization_axis else 1 for i in range(n_axis)]
-    ptq_threshold_tensor = tf.reshape(threshold_tensor, reshape_shape)
-    return ptq_threshold_tensor
--- a/model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import tensorflow as tf
-
-from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
-
-
-def rounding_uniform_quantizer(tensor_data: tf.Tensor,
-                               range_min: tf.Tensor,
-                               range_max: tf.Tensor,
-                               n_bits: int) -> tf.Tensor:
-    """
-    Quantize a tensor according to given range (min, max) and number of bits.
-
-    Args:
-        tensor_data: Tensor values to quantize.
-        range_min: minimum bound of the range for quantization (or array of min values per channel).
-        range_max: maximum bound of the range for quantization (or array of max values per channel).
-        n_bits: Number of bits to quantize the tensor.
-
-    Returns:
-        Quantized data.
-    """
-    # adjusts the quantization rage so the quantization grid include zero.
-    a, b = qutils.fix_range_to_include_zero(range_min, range_max, n_bits)
-
-    # Compute the step size of quantized values.
-    delta = (b - a) / (2 ** n_bits - 1)
-
-    input_tensor_int = qutils.ste_round((tensor_data - a) / delta)  # Apply rounding
-
-    # Clip data in range
-    clipped_tensor = qutils.ste_clip(input_tensor_int, min_val=0, max_val=2 ** n_bits - 1)
-
-    # Quantize the data between min/max of quantization range.
-    q = delta * clipped_tensor + a
-    return q
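The removed rounding_uniform_quantizer is standard uniform fake quantization: compute a step size delta over the (zero-adjusted) range, round and clip the integer grid index, and map back to the float scale. A minimal NumPy sketch of the same arithmetic, omitting the fix_range_to_include_zero adjustment and the straight-through estimators (illustrative, not MCT code):

import numpy as np

def uniform_fake_quant(x, range_min, range_max, n_bits):
    # Step size of the quantization grid over [range_min, range_max].
    delta = (range_max - range_min) / (2 ** n_bits - 1)
    # Round to the nearest grid index and clip to the representable range.
    q_int = np.clip(np.round((x - range_min) / delta), 0, 2 ** n_bits - 1)
    # Map the index back to the float scale (fake quantization).
    return delta * q_int + range_min

x = np.array([-1.2, -0.3, 0.0, 0.4, 1.1])
print(uniform_fake_quant(x, range_min=-1.0, range_max=1.0, n_bits=8))  # -1.2 saturates to -1.0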
--- a/model_compression_toolkit/gptq/pytorch/gptq_graph_info.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import torch
-import torch.nn as nn
-from typing import List
-from model_compression_toolkit.gptq.pytorch.quantizer.quantizer_wrapper import WeightQuantizerWrapper
-from model_compression_toolkit.gptq.pytorch.quantizer.gumbel_rounding.base_gumbel_weights_quantizer import BaseGumbelWeightQuantizer
-from model_compression_toolkit.core.pytorch.constants import BIAS
-
-
-def get_trainable_parameters(fxp_model: nn.Module,
-                             add_bias: bool = False,
-                             quantization_parameters_learning: bool = False,
-                             is_gumbel: bool = False) -> (List[nn.Parameter], List[nn.Parameter], List[nn.Parameter]):
-    """
-    Get trainable parameters from all layers in a model
-
-    Args:
-        fxp_model: Model to get its trainable parameters.
-        add_bias: Whether to include biases of the model (if there are) or not.
-        quantization_parameters_learning: Whether to include quantization parameters of the model or not.
-        is_gumbel: Whether the fxp model is quantized using Gumbel Rounding
-    Returns:
-        A list of trainable variables in a model. Each item is a list of a layers weights.
-    """
-
-    trainable_aux_weights = nn.ParameterList()
-    trainable_threshold = nn.ParameterList()
-    trainable_bias = nn.ParameterList()
-    trainable_temperature = nn.ParameterList()
-
-    for layer in fxp_model.modules():
-        if isinstance(layer, WeightQuantizerWrapper):
-            trainable_aux_weights.append(layer.weight_quantizer.get_aux_variable())
-            if quantization_parameters_learning:
-                trainable_threshold.extend(layer.weight_quantizer.get_quantization_variable())
-            if is_gumbel:
-                trainable_temperature.append(layer.weight_quantizer.get_temperature_variable())
-            if add_bias and hasattr(layer.op, BIAS):
-                bias = getattr(layer.op, BIAS)
-                trainable_bias.append(bias)
-
-    return trainable_aux_weights, trainable_bias, trainable_threshold, trainable_temperature
-
-
-def get_gumbel_probability(fxp_model: nn.Module) -> List[torch.Tensor]:
-    """
-    This function return the gumbel softmax probability of GumRounding
-    Args:
-        fxp_model: A model to be quantized with GumRounding
-
-    Returns: A list of tensors.
-
-    """
-    gumbel_prob_aux = []
-    for layer in fxp_model.modules():
-        if isinstance(layer, WeightQuantizerWrapper) and isinstance(layer.weight_quantizer, BaseGumbelWeightQuantizer):
-            gumbel_prob_aux.append(layer.weight_quantizer.get_gumbel_probability())
-    return gumbel_prob_aux
-
-
-def get_weights_for_loss(fxp_model: nn.Module) -> [List, List]:
-    """
-    Get all float and quantized kernels for the GPTQ loss
-
-    Args:
-        fxp_model: Model to get its float and quantized weights.
-
-    Returns:
-        A list of float kernels, each item is the float kernel of the layer
-        A list of quantized kernels, each item is the quantized kernel of the layer
-    """
-
-    flp_weights_list, fxp_weights_list = [], []
-    for layer in fxp_model.modules():
-        if isinstance(layer, WeightQuantizerWrapper):
-            # Collect pairs of float and quantized weights per layer
-            weights = layer.op.weight
-            flp_weights_list.append(weights)
-            fxp_weights_list.append(layer.weight_quantizer(weights, training=False))
-
-    return flp_weights_list, fxp_weights_list
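All three removed gptq_graph_info helpers follow one pattern: walk model.modules(), filter on the wrapper type, and collect tensors from each match. A generic sketch of that pattern in plain PyTorch (collect_bias_params is a hypothetical stand-in, not the removed API):

import torch.nn as nn

def collect_bias_params(model: nn.Module) -> nn.ParameterList:
    # Walk every submodule and gather its bias parameters, mirroring the
    # modules()-and-filter pattern of the removed helpers.
    biases = nn.ParameterList()
    for module in model.modules():
        bias = getattr(module, "bias", None)
        if isinstance(bias, nn.Parameter):
            biases.append(bias)
    return biases

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
print(len(collect_bias_params(model)))  # 2: one bias per Linear layer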
--- a/model_compression_toolkit/gptq/pytorch/gptq_model_builder.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import torch
-from typing import Tuple, List
-from model_compression_toolkit.core.common.user_info import UserInformation
-from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common.graph.base_graph import BaseNode
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
-from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, PytorchModel
-from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
-from model_compression_toolkit.gptq.pytorch.quantizer.quantizer_wrapper import quantizer_wrapper
-from model_compression_toolkit.core.pytorch.utils import get_working_device
-from model_compression_toolkit.core.pytorch.constants import BUFFER
-from model_compression_toolkit.core.pytorch.reader.node_holders import BufferHolder
-from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
-
-
-class GPTQPytorchModel(PytorchModel):
-    """
-    Class for GPTQ PyTorch model.
-    """
-
-    def __init__(self,
-                 graph: common.Graph,
-                 gptq_config: GradientPTQConfig,
-                 append2output=None,
-                 return_float_outputs: bool = True):
-        """
-        Args:
-            graph: Graph to build the model from.
-            gptq_config: Configuration for GPTQ optimization.
-            append2output: Nodes to append to model's output.
-            return_float_outputs: Whether the model returns float tensors or not.
-        """
-
-        super().__init__(graph,
-                         append2output,
-                         DEFAULT_PYTORCH_INFO,
-                         return_float_outputs)
-
-        for node in graph.nodes():
-            if not isinstance(node, FunctionalNode):
-                if node.type == BufferHolder:
-                    self.add_module(node.name, node_builder(node))
-                    self.get_submodule(node.name).register_buffer(node.name, torch.Tensor(node.get_weights_by_keys(BUFFER)).to(get_working_device()))
-                else:
-                    self.add_module(node.name, quantizer_wrapper(node, gptq_config))
-
-
-    def _quantize_node_activations(self,
-                                   node: BaseNode,
-                                   input_tensors: List[torch.Tensor]) -> List[torch.Tensor]:
-        """
-        Quantize node's activation given input tensors.
-
-        Args:
-            node: Node to quantize its outputs.
-            input_tensors: Input tensors of the node.
-
-        Returns:
-            Output of the node.
-
-        """
-        return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
-
-
-class GPTQPytorchModelBuilder(PyTorchModelBuilder):
-    """
-    Builder of GPTQ Pytorch models.
-    """
-
-    def __init__(self,
-                 graph: common.Graph,
-                 gptq_config: GradientPTQConfig,
-                 append2output=None,
-                 return_float_outputs: bool = True):
-        """
-
-        Args:
-            graph: Graph to build the model from.
-            gptq_config: Configuration for GPTQ optimization.
-            append2output: Nodes to append to model's output.
-            return_float_outputs: Whether the model returns float tensors or not.
-        """
-        super().__init__(graph,
-                         append2output,
-                         DEFAULT_PYTORCH_INFO,
-                         return_float_outputs)
-        self.gptq_config = gptq_config
-
-    def build_model(self) -> Tuple[PytorchModel, UserInformation]:
-        """
-        Build a GPTQ PyTorch model and return it.
-        Returns:
-            GPTQ PyTorch model and user information.
-        """
-        return GPTQPytorchModel(self.graph,
-                                self.gptq_config,
-                                self.append2output,
-                                self.return_float_outputs), self.graph.user_info
--- a/model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import torch
-import torch.nn as nn
-from typing import List, Union, Dict, Any
-from abc import abstractmethod
-from model_compression_toolkit.core.common import Logger
-
-
-class BaseWeightQuantizer(nn.Module):
-
-    def __init__(self):
-        """
-        Construct a Base Pytorch model that utilizes a fake weight quantizer
-        """
-        super().__init__()
-        self.trainable_params = dict()
-
-    def get_trainable_params(self) -> List:
-        """
-        A function to get a list of trainable parameters of the quantizer for GPTQ retraining
-        Returns:
-            A list of trainable tensors
-        """
-        return [value for value in self.trainable_params.values() if value is not None]
-
-    @abstractmethod
-    def get_aux_variable(self) -> torch.Tensor:
-        """
-        Returns auxiliary trainable variables
-        """
-        raise Logger.error(f'{self.__class__.__name__} have to implement the this abstract function.')
-
-    @abstractmethod
-    def get_quantization_variable(self) -> Union[torch.Tensor, List]:
-        """
-        Returns quantization trainable variables
-        """
-        raise Logger.error(f'{self.__class__.__name__} have to implement the this abstract function.')
-
-
-    @abstractmethod
-    def get_weight_quantization_params(self) -> Dict[str, Any]:
-        """
-        Returns weight quantization dictionary params
-        """
-        raise Logger.error(f'{self.__class__.__name__} have to implement the this abstract function.')
-
-    @abstractmethod
-    def forward(self, w:nn.parameter, training:bool = True) -> torch.Tensor:
-        """
-        Forward-Pass
-        Args:
-            w: weights to quantize.
-            training: whether in training mode or not
-        Returns:
-            quantized weights
-        """
-        raise Logger.error(f'{self.__class__.__name__} have to implement the this abstract function.')
--- a/model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================