mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
|
@@ -14,8 +14,6 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
|
|
17
|
-
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GumbelConfig, \
|
|
18
|
-
GradientPTQConfigV2
|
|
19
17
|
from model_compression_toolkit.core.common.quantization import quantization_config
|
|
20
18
|
from model_compression_toolkit.core.common.mixed_precision import mixed_precision_quantization_config
|
|
21
19
|
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
|
|
@@ -26,6 +24,7 @@ from model_compression_toolkit.core.tpc_models.get_target_platform_capabilities
|
|
|
26
24
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
27
25
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
28
26
|
MixedPrecisionQuantizationConfig, MixedPrecisionQuantizationConfigV2
|
|
27
|
+
from model_compression_toolkit.qat.common.qat_config import QATConfig, TrainingMethod
|
|
29
28
|
from model_compression_toolkit.core.common.logger import set_log_folder
|
|
30
29
|
from model_compression_toolkit.core.common.data_loader import FolderImageLoader
|
|
31
30
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo, ChannelAxis
|
|
@@ -35,21 +34,21 @@ from model_compression_toolkit.core.common import network_editors as network_edi
|
|
|
35
34
|
from model_compression_toolkit.core.keras.quantization_facade import keras_post_training_quantization, \
|
|
36
35
|
keras_post_training_quantization_mixed_precision
|
|
37
36
|
from model_compression_toolkit.ptq.keras.quantization_facade import keras_post_training_quantization_experimental
|
|
38
|
-
from model_compression_toolkit.
|
|
39
|
-
|
|
40
|
-
from model_compression_toolkit.
|
|
41
|
-
from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init, \
|
|
42
|
-
keras_quantization_aware_training_finalize
|
|
43
|
-
from model_compression_toolkit.core.pytorch.quantization_facade import pytorch_post_training_quantization, \
|
|
44
|
-
pytorch_post_training_quantization_mixed_precision
|
|
37
|
+
from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init, keras_quantization_aware_training_finalize
|
|
38
|
+
from model_compression_toolkit.qat.pytorch.quantization_facade import pytorch_quantization_aware_training_init, pytorch_quantization_aware_training_finalize
|
|
39
|
+
from model_compression_toolkit.core.pytorch.quantization_facade import pytorch_post_training_quantization, pytorch_post_training_quantization_mixed_precision
|
|
45
40
|
from model_compression_toolkit.ptq.pytorch.quantization_facade import pytorch_post_training_quantization_experimental
|
|
46
|
-
from model_compression_toolkit.gptq.pytorch.quantization_facade import \
|
|
47
|
-
pytorch_gradient_post_training_quantization_experimental
|
|
48
|
-
from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
|
|
49
41
|
|
|
50
42
|
from model_compression_toolkit.core.keras.kpi_data_facade import keras_kpi_data, keras_kpi_data_experimental
|
|
51
43
|
from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data, pytorch_kpi_data_experimental
|
|
52
44
|
|
|
53
|
-
from model_compression_toolkit.
|
|
45
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
|
|
54
46
|
|
|
55
|
-
|
|
47
|
+
|
|
48
|
+
from model_compression_toolkit import exporter
|
|
49
|
+
|
|
50
|
+
from model_compression_toolkit import gptq
|
|
51
|
+
from model_compression_toolkit.gptq import GradientPTQConfig
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
__version__ = "1.8.0"
|
|
@@ -51,4 +51,4 @@ class BaseModelBuilder(ABC):
|
|
|
51
51
|
Returns: A framework's model built from its graph.
|
|
52
52
|
|
|
53
53
|
"""
|
|
54
|
-
raise NotImplemented(f'{self.__class__.__name__} have to implement build_model method.')
|
|
54
|
+
raise NotImplemented(f'{self.__class__.__name__} have to implement build_model method.') # pragma: no cover
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
import numpy as np
|
|
17
|
+
from model_compression_toolkit.core.common.logger import Logger
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class BaseCollector(object):
|
|
@@ -33,7 +34,8 @@ class BaseCollector(object):
|
|
|
33
34
|
|
|
34
35
|
"""
|
|
35
36
|
|
|
36
|
-
raise
|
|
37
|
+
raise NotImplemented(
|
|
38
|
+
f'{self.__class__.__name__} needs to implement scale operation for its state.') # pragma: no cover
|
|
37
39
|
|
|
38
40
|
def shift(self, shift_value: np.ndarray):
|
|
39
41
|
"""
|
|
@@ -43,7 +45,8 @@ class BaseCollector(object):
|
|
|
43
45
|
|
|
44
46
|
"""
|
|
45
47
|
|
|
46
|
-
raise
|
|
48
|
+
raise NotImplemented(
|
|
49
|
+
f'{self.__class__.__name__} needs to implement shift operation for its state.') # pragma: no cover
|
|
47
50
|
|
|
48
51
|
def update_legal_status(self, is_illegal: bool):
|
|
49
52
|
"""
|
|
@@ -63,5 +66,5 @@ class BaseCollector(object):
|
|
|
63
66
|
"""
|
|
64
67
|
|
|
65
68
|
if not self.is_legal:
|
|
66
|
-
|
|
67
|
-
|
|
69
|
+
Logger.exception(f'{self.__class__.__name__} was manipulated per-channel,'
|
|
70
|
+
'but collected per-tensor. Data is invalid.') # pragma: no cover
|
|
@@ -37,7 +37,7 @@ class BaseStatsCollector(object):
|
|
|
37
37
|
Returns whether this tensor requires statistics collection or not.
|
|
38
38
|
Should be implemented in extending classes.
|
|
39
39
|
"""
|
|
40
|
-
raise
|
|
40
|
+
raise NotImplemented(f'require_collection is not implemented in {self.__class__.__name__}') # pragma: no cover
|
|
41
41
|
|
|
42
42
|
def update_statistics(self,
|
|
43
43
|
x: Any):
|
|
@@ -47,7 +47,7 @@ class BaseStatsCollector(object):
|
|
|
47
47
|
Args:
|
|
48
48
|
x: Tensor.
|
|
49
49
|
"""
|
|
50
|
-
raise
|
|
50
|
+
raise NotImplemented(f'update_statistics is not implemented in {self.__class__.__name__}') # pragma: no cover
|
|
51
51
|
|
|
52
52
|
|
|
53
53
|
class StatsCollector(BaseStatsCollector):
|
|
@@ -21,10 +21,11 @@ FOUND_TF = importlib.util.find_spec(TENSORFLOW) is not None and importlib.util.f
|
|
|
21
21
|
"tensorflow_model_optimization") is not None
|
|
22
22
|
FOUND_TORCH = importlib.util.find_spec("torch") is not None
|
|
23
23
|
FOUND_ONNX = importlib.util.find_spec("onnx") is not None
|
|
24
|
+
FOUND_ONNXRUNTIME = importlib.util.find_spec("onnxruntime") is not None
|
|
24
25
|
|
|
25
26
|
WEIGHTS_SIGNED = True
|
|
26
27
|
# Minimal threshold to use for quantization ranges:
|
|
27
|
-
MIN_THRESHOLD = (2 ** -
|
|
28
|
+
MIN_THRESHOLD = (2 ** -16)
|
|
28
29
|
EPS = 1e-8
|
|
29
30
|
MULTIPLIER_N_BITS = 8
|
|
30
31
|
|
|
@@ -114,12 +115,16 @@ ACTIVATION_QUANT_PARAMS_FN = 'activation_quantization_params_fn'
|
|
|
114
115
|
WEIGHTS_QUANT_PARAMS_FN = 'weights_quantization_params_fn'
|
|
115
116
|
WEIGHTS_CHANNELS_AXIS = 'weights_channels_axis'
|
|
116
117
|
|
|
117
|
-
# GPTQ Parameters
|
|
118
|
-
GUMBEL_MAX_ITER = 10000
|
|
119
|
-
|
|
120
118
|
# Memory graph constants
|
|
121
119
|
DUMMY_NODE = 'dummy_node'
|
|
122
120
|
DUMMY_TENSOR = 'dummy_tensor'
|
|
123
121
|
|
|
124
122
|
# TP Model constants
|
|
125
123
|
OPS_SET_LIST = 'ops_set_list'
|
|
124
|
+
|
|
125
|
+
# TF Input node base name
|
|
126
|
+
INPUT_BASE_NAME = 'base_input'
|
|
127
|
+
|
|
128
|
+
# Jacobian-weights constants
|
|
129
|
+
MIN_JACOBIANS_ITER = 10
|
|
130
|
+
JACOBIANS_COMP_TOLERANCE = 1e-3
|
|
@@ -44,7 +44,7 @@ class FrameworkImplementation(ABC):
|
|
|
44
44
|
Returns: Module of the framework constants.
|
|
45
45
|
|
|
46
46
|
"""
|
|
47
|
-
raise
|
|
47
|
+
raise NotImplemented(f'{self.__class__.__name__} did not supply a constants module.') # pragma: no cover
|
|
48
48
|
|
|
49
49
|
@abstractmethod
|
|
50
50
|
def to_numpy(self, tensor: Any) -> np.ndarray:
|
|
@@ -57,7 +57,7 @@ class FrameworkImplementation(ABC):
|
|
|
57
57
|
Numpy array converted from the input tensor.
|
|
58
58
|
"""
|
|
59
59
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
60
|
-
f'framework\'s to_numpy method.')
|
|
60
|
+
f'framework\'s to_numpy method.') # pragma: no cover
|
|
61
61
|
|
|
62
62
|
@abstractmethod
|
|
63
63
|
def to_tensor(self, tensor: np.ndarray) -> Any:
|
|
@@ -70,7 +70,7 @@ class FrameworkImplementation(ABC):
|
|
|
70
70
|
Framework's tensor converted from the input Numpy array.
|
|
71
71
|
"""
|
|
72
72
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
73
|
-
f'framework\'s to_tensor method.')
|
|
73
|
+
f'framework\'s to_tensor method.') # pragma: no cover
|
|
74
74
|
|
|
75
75
|
@abstractmethod
|
|
76
76
|
def model_reader(self,
|
|
@@ -86,7 +86,7 @@ class FrameworkImplementation(ABC):
|
|
|
86
86
|
Graph representing the input model.
|
|
87
87
|
"""
|
|
88
88
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
89
|
-
f'framework\'s model_reader method.')
|
|
89
|
+
f'framework\'s model_reader method.') # pragma: no cover
|
|
90
90
|
|
|
91
91
|
@abstractmethod
|
|
92
92
|
def model_builder(self,
|
|
@@ -111,7 +111,7 @@ class FrameworkImplementation(ABC):
|
|
|
111
111
|
A tuple of the model that was built and an UserInformation object.
|
|
112
112
|
"""
|
|
113
113
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
114
|
-
f'framework\'s model_builder method.')
|
|
114
|
+
f'framework\'s model_builder method.') # pragma: no cover
|
|
115
115
|
|
|
116
116
|
@abstractmethod
|
|
117
117
|
def run_model_inference(self,
|
|
@@ -128,7 +128,7 @@ class FrameworkImplementation(ABC):
|
|
|
128
128
|
The frameworks model's output.
|
|
129
129
|
"""
|
|
130
130
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
131
|
-
f'framework\'s run_model_inference method.')
|
|
131
|
+
f'framework\'s run_model_inference method.') # pragma: no cover
|
|
132
132
|
|
|
133
133
|
@abstractmethod
|
|
134
134
|
def shift_negative_correction(self,
|
|
@@ -147,7 +147,7 @@ class FrameworkImplementation(ABC):
|
|
|
147
147
|
Graph after SNC.
|
|
148
148
|
"""
|
|
149
149
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
150
|
-
f'framework\'s apply_shift_negative_correction method.')
|
|
150
|
+
f'framework\'s apply_shift_negative_correction method.') # pragma: no cover
|
|
151
151
|
|
|
152
152
|
@abstractmethod
|
|
153
153
|
def attach_sc_to_node(self, node: BaseNode, fw_info: FrameworkInfo) -> BaseStatsCollector:
|
|
@@ -163,7 +163,7 @@ class FrameworkImplementation(ABC):
|
|
|
163
163
|
Statistics collector for the node.
|
|
164
164
|
"""
|
|
165
165
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
166
|
-
f'framework\'s attach_sc_to_node method.')
|
|
166
|
+
f'framework\'s attach_sc_to_node method.') # pragma: no cover
|
|
167
167
|
|
|
168
168
|
@abstractmethod
|
|
169
169
|
def get_substitutions_channel_equalization(self,
|
|
@@ -180,7 +180,7 @@ class FrameworkImplementation(ABC):
|
|
|
180
180
|
A list of the framework substitutions used after we collect statistics.
|
|
181
181
|
"""
|
|
182
182
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
183
|
-
f'framework\'s get_substitutions_channel_equalization method.')
|
|
183
|
+
f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover
|
|
184
184
|
|
|
185
185
|
@abstractmethod
|
|
186
186
|
def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List[common.BaseSubstitution]:
|
|
@@ -190,7 +190,7 @@ class FrameworkImplementation(ABC):
|
|
|
190
190
|
|
|
191
191
|
"""
|
|
192
192
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
193
|
-
f'framework\'s get_substitutions_prepare_graph method.')
|
|
193
|
+
f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover
|
|
194
194
|
|
|
195
195
|
@abstractmethod
|
|
196
196
|
def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig) -> \
|
|
@@ -204,7 +204,7 @@ class FrameworkImplementation(ABC):
|
|
|
204
204
|
|
|
205
205
|
"""
|
|
206
206
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
207
|
-
f'framework\'s get_substitutions_pre_statistics_collection method.')
|
|
207
|
+
f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover
|
|
208
208
|
|
|
209
209
|
@abstractmethod
|
|
210
210
|
def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
|
|
@@ -212,7 +212,7 @@ class FrameworkImplementation(ABC):
|
|
|
212
212
|
Returns: linear collapsing substitution
|
|
213
213
|
"""
|
|
214
214
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
215
|
-
f'framework\'s get_linear_collapsing_substitution method.')
|
|
215
|
+
f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover
|
|
216
216
|
|
|
217
217
|
@abstractmethod
|
|
218
218
|
def get_substitutions_statistics_correction(self, quant_config: QuantizationConfig) -> \
|
|
@@ -227,7 +227,7 @@ class FrameworkImplementation(ABC):
|
|
|
227
227
|
A list of the framework substitutions used for statistics correction.
|
|
228
228
|
"""
|
|
229
229
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
230
|
-
f'framework\'s get_substitutions_statistics_correction method.')
|
|
230
|
+
f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover
|
|
231
231
|
|
|
232
232
|
@abstractmethod
|
|
233
233
|
def get_residual_collapsing_substitution(self) -> List[common.BaseSubstitution]:
|
|
@@ -235,7 +235,7 @@ class FrameworkImplementation(ABC):
|
|
|
235
235
|
Returns: A list of the framework substitutions used for residual collapsing
|
|
236
236
|
"""
|
|
237
237
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
238
|
-
f'framework\'s get_residual_collapsing_substitution method.')
|
|
238
|
+
f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover
|
|
239
239
|
|
|
240
240
|
@abstractmethod
|
|
241
241
|
def get_substitutions_pre_build(self) -> List[common.BaseSubstitution]:
|
|
@@ -245,7 +245,7 @@ class FrameworkImplementation(ABC):
|
|
|
245
245
|
|
|
246
246
|
"""
|
|
247
247
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
248
|
-
f'framework\'s get_substitutions_pre_build method.')
|
|
248
|
+
f'framework\'s get_substitutions_pre_build method.') # pragma: no cover
|
|
249
249
|
|
|
250
250
|
@abstractmethod
|
|
251
251
|
def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) -> List[
|
|
@@ -260,7 +260,7 @@ class FrameworkImplementation(ABC):
|
|
|
260
260
|
A list of the framework substitutions used after we collect statistics.
|
|
261
261
|
"""
|
|
262
262
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
263
|
-
f'framework\'s get_substitutions_post_statistics_collection method.')
|
|
263
|
+
f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover
|
|
264
264
|
|
|
265
265
|
@abstractmethod
|
|
266
266
|
def get_substitutions_virtual_weights_activation_coupling(self) -> List[common.BaseSubstitution]:
|
|
@@ -269,7 +269,8 @@ class FrameworkImplementation(ABC):
|
|
|
269
269
|
"""
|
|
270
270
|
|
|
271
271
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
272
|
-
f'framework\'s get_substitutions_virtual_weights_activation_coupling method.')
|
|
272
|
+
f'framework\'s get_substitutions_virtual_weights_activation_coupling '
|
|
273
|
+
f'method.') # pragma: no cover
|
|
273
274
|
|
|
274
275
|
@abstractmethod
|
|
275
276
|
def get_substitutions_after_second_moment_correction(self, quant_config: QuantizationConfig) \
|
|
@@ -284,7 +285,8 @@ class FrameworkImplementation(ABC):
|
|
|
284
285
|
A list of the framework substitutions used after we apply second moment statistics.
|
|
285
286
|
"""
|
|
286
287
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
287
|
-
f'framework\'s get_substitutions_after_second_moment_correction method.')
|
|
288
|
+
f'framework\'s get_substitutions_after_second_moment_correction '
|
|
289
|
+
f'method.') # pragma: no cover
|
|
288
290
|
|
|
289
291
|
@abstractmethod
|
|
290
292
|
def get_gptq_trainer_obj(self):
|
|
@@ -292,7 +294,7 @@ class FrameworkImplementation(ABC):
|
|
|
292
294
|
Returns: GPTQTrainer object
|
|
293
295
|
"""
|
|
294
296
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
295
|
-
f'framework\'s get_gptq_trainer method.')
|
|
297
|
+
f'framework\'s get_gptq_trainer method.') # pragma: no cover
|
|
296
298
|
|
|
297
299
|
@abstractmethod
|
|
298
300
|
def get_sensitivity_evaluator(self,
|
|
@@ -317,7 +319,7 @@ class FrameworkImplementation(ABC):
|
|
|
317
319
|
"""
|
|
318
320
|
|
|
319
321
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
320
|
-
f'framework\'s get_sensitivity_evaluator method.')
|
|
322
|
+
f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover
|
|
321
323
|
|
|
322
324
|
def get_node_prior_info(self, node: BaseNode,
|
|
323
325
|
fw_info: FrameworkInfo,
|
|
@@ -335,7 +337,7 @@ class FrameworkImplementation(ABC):
|
|
|
335
337
|
"""
|
|
336
338
|
|
|
337
339
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
338
|
-
f'framework\'s get_node_prior_info method.')
|
|
340
|
+
f'framework\'s get_node_prior_info method.') # pragma: no cover
|
|
339
341
|
|
|
340
342
|
def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
|
|
341
343
|
"""
|
|
@@ -346,7 +348,7 @@ class FrameworkImplementation(ABC):
|
|
|
346
348
|
"""
|
|
347
349
|
|
|
348
350
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
349
|
-
f'framework\'s count_node_for_mixed_precision_interest_points method.')
|
|
351
|
+
f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover
|
|
350
352
|
|
|
351
353
|
def get_node_distance_fn(self, layer_class: type,
|
|
352
354
|
framework_attrs: Dict[str, Any],
|
|
@@ -365,7 +367,7 @@ class FrameworkImplementation(ABC):
|
|
|
365
367
|
"""
|
|
366
368
|
|
|
367
369
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
368
|
-
f'framework\'s get_node_distance_fn method.')
|
|
370
|
+
f'framework\'s get_node_distance_fn method.') # pragma: no cover
|
|
369
371
|
|
|
370
372
|
@abstractmethod
|
|
371
373
|
def get_model_layers_names(self,
|
|
@@ -381,7 +383,7 @@ class FrameworkImplementation(ABC):
|
|
|
381
383
|
"""
|
|
382
384
|
|
|
383
385
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
384
|
-
f'framework\'s get_model_layers_names method.')
|
|
386
|
+
f'framework\'s get_model_layers_names method.') # pragma: no cover
|
|
385
387
|
|
|
386
388
|
@abstractmethod
|
|
387
389
|
def get_model_layer_by_name(self,
|
|
@@ -399,7 +401,7 @@ class FrameworkImplementation(ABC):
|
|
|
399
401
|
"""
|
|
400
402
|
|
|
401
403
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
402
|
-
f'framework\'s get_model_layer_by_name method.')
|
|
404
|
+
f'framework\'s get_model_layer_by_name method.') # pragma: no cover
|
|
403
405
|
|
|
404
406
|
@abstractmethod
|
|
405
407
|
def model_grad(self,
|
|
@@ -433,7 +435,7 @@ class FrameworkImplementation(ABC):
|
|
|
433
435
|
"""
|
|
434
436
|
|
|
435
437
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
436
|
-
f'framework\'s model_grad method.')
|
|
438
|
+
f'framework\'s model_grad method.') # pragma: no cover
|
|
437
439
|
|
|
438
440
|
@abstractmethod
|
|
439
441
|
def is_node_compatible_for_metric_outputs(self,
|
|
@@ -450,7 +452,7 @@ class FrameworkImplementation(ABC):
|
|
|
450
452
|
"""
|
|
451
453
|
|
|
452
454
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
453
|
-
f'framework\'s is_node_compatible_for_metric_outputs method.')
|
|
455
|
+
f'framework\'s is_node_compatible_for_metric_outputs method.') # pragma: no cover
|
|
454
456
|
|
|
455
457
|
@abstractmethod
|
|
456
458
|
def get_node_mac_operations(self,
|
|
@@ -467,7 +469,7 @@ class FrameworkImplementation(ABC):
|
|
|
467
469
|
"""
|
|
468
470
|
|
|
469
471
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
470
|
-
f'framework\'s get_node_mac_operations method.')
|
|
472
|
+
f'framework\'s get_node_mac_operations method.') # pragma: no cover
|
|
471
473
|
|
|
472
474
|
@abstractmethod
|
|
473
475
|
def apply_second_moment_correction(self,
|
|
@@ -488,7 +490,7 @@ class FrameworkImplementation(ABC):
|
|
|
488
490
|
A Graph after second moment correction.
|
|
489
491
|
"""
|
|
490
492
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
491
|
-
f'framework\'s apply_second_moment_correction method.')
|
|
493
|
+
f'framework\'s apply_second_moment_correction method.') # pragma: no cover
|
|
492
494
|
|
|
493
495
|
@abstractmethod
|
|
494
496
|
def sensitivity_eval_inference(self,
|
|
@@ -505,4 +507,4 @@ class FrameworkImplementation(ABC):
|
|
|
505
507
|
The output of the model inference on the given input.
|
|
506
508
|
"""
|
|
507
509
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
508
|
-
f'framework\'s sensitivity_eval_inference method.')
|
|
510
|
+
f'framework\'s sensitivity_eval_inference method.') # pragma: no cover
|
|
@@ -75,7 +75,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
|
75
75
|
self.fused_nodes = []
|
|
76
76
|
|
|
77
77
|
def set_fw_info(self,
|
|
78
|
-
|
|
78
|
+
fw_info: FrameworkInfo):
|
|
79
79
|
"""
|
|
80
80
|
Set the graph's framework info.
|
|
81
81
|
Args:
|
|
@@ -93,7 +93,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
|
93
93
|
"""
|
|
94
94
|
self.tpc = tpc
|
|
95
95
|
|
|
96
|
-
|
|
97
96
|
def get_topo_sorted_nodes(self):
|
|
98
97
|
"""
|
|
99
98
|
Returns: a list of toposorted nodes.
|
|
@@ -216,7 +215,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
|
216
215
|
|
|
217
216
|
sc = self.node_to_in_stats_collector.get(n)
|
|
218
217
|
if sc is None:
|
|
219
|
-
|
|
218
|
+
Logger.error(f'Input statistics collector of node {n.name} is None') # pragma: no cover
|
|
220
219
|
return sc
|
|
221
220
|
|
|
222
221
|
def scale_stats_collector(self,
|
|
@@ -350,7 +349,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
|
350
349
|
input_nodes_output_index = [0] * len(input_nodes)
|
|
351
350
|
|
|
352
351
|
if len(input_nodes_output_index) != len(input_nodes):
|
|
353
|
-
|
|
352
|
+
Logger.error('Graph.add_node_with_in_edges: input_nodes & input_nodes_output_index must be the same '
|
|
353
|
+
'length') # pragma: no cover
|
|
354
354
|
|
|
355
355
|
self.add_node(new_node)
|
|
356
356
|
for sink_index, (in_node, source_index) in enumerate(zip(input_nodes, input_nodes_output_index)):
|
|
@@ -420,12 +420,14 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
|
420
420
|
output_nodes = [ot.node for ot in self.get_outputs()] # get output nodes from namedtuples
|
|
421
421
|
if node_to_remove in output_nodes: # If node is in the graph's outputs, the outputs should be updated
|
|
422
422
|
if new_graph_outputs is None:
|
|
423
|
-
Logger.critical(
|
|
423
|
+
Logger.critical(
|
|
424
|
+
f'{node_to_remove.name} is in graph outputs, but new outputs were not given.') # pragma: no cover
|
|
424
425
|
self.set_outputs(new_graph_outputs)
|
|
425
426
|
|
|
426
427
|
if node_to_remove in self.get_inputs(): # If node is in the graph's inputs, the inputs should be updated
|
|
427
428
|
if new_graph_inputs is None:
|
|
428
|
-
Logger.critical(
|
|
429
|
+
Logger.critical(
|
|
430
|
+
f'{node_to_remove.name} is in graph inputs, but new inputs were not given.') # pragma: no cover
|
|
429
431
|
self.set_inputs(new_graph_inputs)
|
|
430
432
|
|
|
431
433
|
# Make sure there are no connected edges left to the node before removing it.
|
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
import logging
|
|
18
18
|
import os
|
|
19
19
|
from datetime import datetime
|
|
20
|
-
from os import path
|
|
21
20
|
from pathlib import Path
|
|
22
21
|
|
|
23
22
|
LOGGER_NAME = 'Constrained Model Optimization'
|
|
@@ -43,7 +42,7 @@ class Logger:
|
|
|
43
42
|
|
|
44
43
|
"""
|
|
45
44
|
|
|
46
|
-
if not path.exists(log_path):
|
|
45
|
+
if not os.path.exists(log_path):
|
|
47
46
|
Path(log_path).mkdir(parents=True, exist_ok=True)
|
|
48
47
|
|
|
49
48
|
@staticmethod
|
|
@@ -93,6 +92,15 @@ class Logger:
|
|
|
93
92
|
|
|
94
93
|
print(f'log file is in {log_name}')
|
|
95
94
|
|
|
95
|
+
@staticmethod
|
|
96
|
+
def shutdown():
|
|
97
|
+
"""
|
|
98
|
+
An orderly command to shutdown by flushing and closing all logging handlers.
|
|
99
|
+
|
|
100
|
+
"""
|
|
101
|
+
Logger.LOG_PATH = None
|
|
102
|
+
logging.shutdown()
|
|
103
|
+
|
|
96
104
|
########################################
|
|
97
105
|
# Delegating methods to wrapped logger
|
|
98
106
|
########################################
|
|
@@ -41,19 +41,19 @@ class BaseMatcher(object):
|
|
|
41
41
|
"""
|
|
42
42
|
Return a matcher to check the logic AND of two BaseMatchers on an object.
|
|
43
43
|
"""
|
|
44
|
-
raise NotImplemented
|
|
44
|
+
raise NotImplemented # pragma: no cover
|
|
45
45
|
|
|
46
46
|
def __or__(self, other: Any):
|
|
47
47
|
"""
|
|
48
48
|
Return a matcher to check the logic OR of BaseMatchers on an object.
|
|
49
49
|
"""
|
|
50
|
-
raise NotImplemented
|
|
50
|
+
raise NotImplemented # pragma: no cover
|
|
51
51
|
|
|
52
52
|
def logic_not(self):
|
|
53
53
|
"""
|
|
54
54
|
Return a matcher to check the logic NOT of the BaseMatcher on an object.
|
|
55
55
|
"""
|
|
56
|
-
raise NotImplemented
|
|
56
|
+
raise NotImplemented # pragma: no cover
|
|
57
57
|
|
|
58
58
|
def logic_and(self, other: Any):
|
|
59
59
|
"""
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py
CHANGED
|
@@ -127,7 +127,8 @@ class MixedPrecisionQuantizationConfig(QuantizationConfig):
|
|
|
127
127
|
elif hasattr(_dummy_mp_config_experimental, k):
|
|
128
128
|
mp_dict.update({k: v})
|
|
129
129
|
else:
|
|
130
|
-
|
|
130
|
+
Logger.error(f'Attribute "{k}" mismatch: exists in MixedPrecisionQuantizationConfig but not in '
|
|
131
|
+
f'MixedPrecisionQuantizationConfigV2') # pragma: no cover
|
|
131
132
|
|
|
132
133
|
return QuantizationConfig(**qc_dict), MixedPrecisionQuantizationConfigV2(**mp_dict)
|
|
133
134
|
|
|
@@ -75,7 +75,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
|
|
|
75
75
|
|
|
76
76
|
# target_kpi have to be passed. If it was not passed, the facade is not supposed to get here by now.
|
|
77
77
|
if target_kpi is None:
|
|
78
|
-
Logger.critical('Target KPI have to be passed for search_methods bit-width configuration')
|
|
78
|
+
Logger.critical('Target KPI have to be passed for search_methods bit-width configuration') # pragma: no cover
|
|
79
79
|
|
|
80
80
|
# Set graph for MP search
|
|
81
81
|
graph = copy.deepcopy(graph_to_search_cfg) # Copy graph before searching
|
|
@@ -114,7 +114,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
|
|
|
114
114
|
if search_method in search_methods: # Get a specific search function
|
|
115
115
|
search_method_fn = search_methods.get(search_method)
|
|
116
116
|
else:
|
|
117
|
-
raise NotImplemented
|
|
117
|
+
raise NotImplemented # pragma: no cover
|
|
118
118
|
|
|
119
119
|
# Search for the desired mixed-precision configuration
|
|
120
120
|
result_bit_cfg = search_method_fn(search_manager,
|
|
@@ -350,8 +350,8 @@ class ConfigReconstructionHelper:
|
|
|
350
350
|
|
|
351
351
|
if changed_virtual_nodes_idx is not None:
|
|
352
352
|
if original_base_config is None:
|
|
353
|
-
Logger.critical("Must provide a base original config in order to run config reconstruction for partial"
|
|
354
|
-
"set of nodes.")
|
|
353
|
+
Logger.critical("Must provide a base original config in order to run config reconstruction for partial"
|
|
354
|
+
"set of nodes.") # pragma: no cover
|
|
355
355
|
|
|
356
356
|
updated_virtual_nodes = \
|
|
357
357
|
[(idx, self.virtual_graph.get_configurable_sorted_nodes()[idx]) for idx in changed_virtual_nodes_idx]
|
|
@@ -22,6 +22,8 @@ from model_compression_toolkit.core.common import Logger
|
|
|
22
22
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI, KPITarget
|
|
23
23
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager
|
|
24
24
|
|
|
25
|
+
# Limit ILP solver runtime in seconds
|
|
26
|
+
SOLVER_TIME_LIMIT = 60
|
|
25
27
|
|
|
26
28
|
def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,
|
|
27
29
|
target_kpi: KPI = None) -> List[int]:
|
|
@@ -64,7 +66,10 @@ def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,
|
|
|
64
66
|
target_kpi,
|
|
65
67
|
search_manager)
|
|
66
68
|
|
|
67
|
-
|
|
69
|
+
# Use default PULP solver. Limit runtime in seconds
|
|
70
|
+
solver = PULP_CBC_CMD(timeLimit=SOLVER_TIME_LIMIT)
|
|
71
|
+
lp_problem.solve(solver=solver) # Try to solve the problem.
|
|
72
|
+
|
|
68
73
|
assert lp_problem.status == LpStatusOptimal, Logger.critical(
|
|
69
74
|
"No solution was found during solving the LP problem")
|
|
70
75
|
Logger.info(LpStatus[lp_problem.status])
|
|
@@ -30,7 +30,8 @@ class ModelValidation:
|
|
|
30
30
|
If the model has layers with different output channels index, it should throw an exception.
|
|
31
31
|
|
|
32
32
|
"""
|
|
33
|
-
raise NotImplemented(
|
|
33
|
+
raise NotImplemented(
|
|
34
|
+
f'Framework validation class did not implement validate_output_channel_consistency') # pragma: no cover
|
|
34
35
|
|
|
35
36
|
def validate(self):
|
|
36
37
|
"""
|
|
@@ -17,6 +17,8 @@
|
|
|
17
17
|
from typing import Callable, Any
|
|
18
18
|
|
|
19
19
|
import numpy as np
|
|
20
|
+
|
|
21
|
+
from model_compression_toolkit.core.common.logger import Logger
|
|
20
22
|
from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
|
|
21
23
|
get_activation_quantization_params_fn, get_weights_quantization_params_fn
|
|
22
24
|
|
|
@@ -111,7 +113,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
111
113
|
self.activation_quantization_params)
|
|
112
114
|
|
|
113
115
|
if fake_quant is None:
|
|
114
|
-
|
|
116
|
+
Logger.error('Layer is meant to be quantized but fake_quant function is None') # pragma: no cover
|
|
115
117
|
return fake_quant(tensors)
|
|
116
118
|
|
|
117
119
|
@property
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
from collections.abc import Callable
|
|
17
17
|
from functools import partial
|
|
18
18
|
|
|
19
|
+
from model_compression_toolkit.core.common.logger import Logger
|
|
19
20
|
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
20
21
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.kmeans_params import kmeans_tensor
|
|
21
22
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import \
|
|
@@ -47,8 +48,9 @@ def get_activation_quantization_params_fn(activation_quantization_method: Quanti
|
|
|
47
48
|
elif activation_quantization_method == QuantizationMethod.LUT_POT_QUANTIZER:
|
|
48
49
|
params_fn = lut_kmeans_histogram
|
|
49
50
|
else:
|
|
50
|
-
|
|
51
|
-
f'No params function for the configuration of
|
|
51
|
+
Logger.error(
|
|
52
|
+
f'No params function for the configuration of '
|
|
53
|
+
f'quantization method {activation_quantization_method}') # pragma: no cover
|
|
52
54
|
return params_fn
|
|
53
55
|
|
|
54
56
|
|
|
@@ -75,6 +77,7 @@ def get_weights_quantization_params_fn(weights_quantization_method: Quantization
|
|
|
75
77
|
elif weights_quantization_method == QuantizationMethod.LUT_SYM_QUANTIZER:
|
|
76
78
|
params_fn = partial(lut_kmeans_tensor, is_symmetric=True)
|
|
77
79
|
else:
|
|
78
|
-
|
|
79
|
-
f'No params function for the configuration of
|
|
80
|
+
Logger.error(
|
|
81
|
+
f'No params function for the configuration of '
|
|
82
|
+
f'quantization method {weights_quantization_method}') # pragma: no cover
|
|
80
83
|
return params_fn
|