mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0
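Highlights of this release, as reflected in the file list above: the misspelled `qunatizers_infrastructure` package is renamed to `quantizers_infrastructure` (split into `inferable_infrastructure` and `trainable_infrastructure`), a new `imx500_tpc` target-platform model family is added, GPTQ's Gumbel-rounding quantizers are replaced by soft-rounding ones, and the exporter's quantize-config classes give way to per-node quantizer builders. As a rough sketch of how the new IMX500 target platform capabilities might be loaded (the facade name comes from `core/tpc_models/get_target_platform_capabilities.py`; the exact signature is an assumption, not taken from this diff):

```python
import model_compression_toolkit as mct

# Hypothetical usage sketch; 'tensorflow', 'imx500' and 'v1' mirror the new
# tpc_models/imx500_tpc/v1 layout above, but the argument order is assumed.
tpc = mct.get_target_platform_capabilities('tensorflow', 'imx500', 'v1')
```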
model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py
CHANGED

@@ -60,7 +60,8 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
         the thresholds per channel and the multiplier num bits.
     """
     if n_bits > MULTIPLIER_N_BITS:
-        Logger.critical(f'Look-Up-Table bit configuration has {n_bits} bits, but must be less or equal to
+        Logger.critical(f'Look-Up-Table bit configuration has {n_bits} bits, but must be less or equal to '
+                        f'{MULTIPLIER_N_BITS}')  # pragma: no cover
     # TODO: need to set this externally
     if len(np.unique(tensor_data.flatten())) < 2 ** n_bits:
         n_clusters = len(np.unique(tensor_data.flatten()))
@@ -115,7 +116,8 @@ def lut_kmeans_histogram(bins: np.ndarray,
     """
 
     if n_bits >= MULTIPLIER_N_BITS:
-        Logger.critical(f'Look-Up-Table bit configuration has {n_bits} bits. It must be less then
+        Logger.critical(f'Look-Up-Table bit configuration has {n_bits} bits. It must be less then '
+                        f'{MULTIPLIER_N_BITS}')  # pragma: no cover
 
     bins_with_values = np.abs(bins)[1:][counts > 0]
     if len(np.unique(bins_with_values.flatten())) < 2 ** n_bits:
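To make the tightened guard concrete, here is a minimal sketch of the two checks in the hunks above, assuming an illustrative `MULTIPLIER_N_BITS` value (the real constant lives in MCT's constants module): the LUT index width must not exceed the multiplier width, and the k-means cluster count is capped by the number of distinct tensor values.

```python
import numpy as np

MULTIPLIER_N_BITS = 8  # illustrative assumption

def lut_n_clusters(tensor_data: np.ndarray, n_bits: int) -> int:
    # Guard mirrored from lut_kmeans_tensor above.
    if n_bits > MULTIPLIER_N_BITS:
        raise ValueError(f'LUT has {n_bits} bits, must be <= {MULTIPLIER_N_BITS}')
    # A tensor with fewer distinct values than 2**n_bits cannot fill every cluster.
    return min(2 ** n_bits, len(np.unique(tensor_data.flatten())))
```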
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py
CHANGED

@@ -49,25 +49,22 @@ def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConf
                                                       bins_counts)
     min_value, max_value = out_stats_container.get_min_max_values()
 
-    if nodes_prior_info
-
-        signed = min_value < 0
-    else:
-        signed = np.any(bins_values[:-1][bins_counts > 0] < 0)
-
-    if nodes_prior_info.is_output_bounded():
-        if activation_quant_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
-            activation_quant_cfg.activation_quantization_params_fn = \
-                quantization_params_generation.power_of_two_no_clipping_selection_min_max
-        elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.SYMMETRIC:
-            activation_quant_cfg.activation_quantization_params_fn = \
-                quantization_params_generation.symmetric_no_clipping_selection_min_max
-        elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.UNIFORM:
-            activation_quant_cfg.activation_quantization_params_fn = \
-                quantization_params_generation.uniform_no_clipping_selection_min_max
+    if nodes_prior_info.is_output_bounded():
+        signed = min_value < 0
     else:
         signed = np.any(bins_values[:-1][bins_counts > 0] < 0)
 
+    if nodes_prior_info.is_output_bounded():
+        if activation_quant_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
+            activation_quant_cfg.activation_quantization_params_fn = \
+                quantization_params_generation.power_of_two_no_clipping_selection_min_max
+        elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.SYMMETRIC:
+            activation_quant_cfg.activation_quantization_params_fn = \
+                quantization_params_generation.symmetric_no_clipping_selection_min_max
+        elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.UNIFORM:
+            activation_quant_cfg.activation_quantization_params_fn = \
+                quantization_params_generation.uniform_no_clipping_selection_min_max
+
     activation_params = activation_quant_cfg.activation_quantization_params_fn(bins_values,
                                                                                bins_counts,
                                                                                activation_quant_cfg.l_p_value,
@@ -78,4 +75,4 @@ def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConf
                                                   quant_error_method=activation_quant_cfg.activation_error_method)
     activation_params.update({SIGNED: signed})
 
-    return activation_params
+    return activation_params
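The reordering above computes `signed` and the no-clipping parameter function in two separate passes over `nodes_prior_info.is_output_bounded()`. A condensed, standalone sketch of the restored signedness logic (argument names are illustrative):

```python
import numpy as np

def infer_signedness(output_bounded: bool, min_value: float,
                     bins_values: np.ndarray, bins_counts: np.ndarray) -> bool:
    # Output-bounded nodes trust the collected min value; otherwise check
    # whether any populated histogram bin lies below zero.
    if output_bounded:
        return bool(min_value < 0)
    return bool(np.any(bins_values[:-1][bins_counts > 0] < 0))
```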
model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py
CHANGED

@@ -18,7 +18,7 @@ from typing import Tuple, List
 import numpy as np
 
 from model_compression_toolkit.core.common.constants import MIN_THRESHOLD, EPS
-
+from model_compression_toolkit.core import common
 
 def max_power_of_two(x: np.ndarray,
                      min_threshold: float = MIN_THRESHOLD) -> np.ndarray:
@@ -235,7 +235,14 @@ def get_tensor_max(tensor_data: np.ndarray,
     Returns: maximal value (or values).
 
     """
-
+    if n_bits < 1:
+        common.Logger.error("n_bits must be positive")
+    if is_uniform_quantization:
+        expansion_factor = 1.0
+    elif n_bits == 1:
+        expansion_factor = 0.0
+    else:
+        expansion_factor = np.power(2.0, n_bits - 1) / (np.power(2.0, n_bits - 1) - 1)
     if per_channel:
         output_shape = get_output_shape(tensor_data.shape, channel_axis)
         reshaped_tensor_data = reshape_tensor_for_per_channel_search(tensor_data, channel_axis)
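The new `expansion_factor` reflects that symmetric signed quantization can only represent values up to (2^(n-1) - 1) / 2^(n-1) of the threshold, so the observed maximum is expanded by the reciprocal before threshold selection; uniform quantization needs no expansion, and 1-bit symmetric quantization has no positive level at all. A quick numeric check:

```python
import numpy as np

for n_bits in (2, 4, 8):
    factor = np.power(2.0, n_bits - 1) / (np.power(2.0, n_bits - 1) - 1)
    print(n_bits, factor)  # 2 -> 2.0, 4 -> 8/7 ≈ 1.143, 8 -> 128/127 ≈ 1.008
```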
model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py
CHANGED

@@ -15,6 +15,7 @@
 
 import numpy as np
 
+from model_compression_toolkit.core.common.logger import Logger
 from model_compression_toolkit.core.common.constants import RANGE_MIN, RANGE_MAX, THRESHOLD
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor, \
     quantize_tensor
@@ -51,9 +52,9 @@ def power_of_two_quantizer(tensor_data: np.ndarray,
     """
     threshold = quantization_params.get(THRESHOLD)
     if threshold is None:
-
+        Logger.error(f"{THRESHOLD} parameter must be defined in 'quantization_params'")  # pragma: no cover
     if not threshold_is_power_of_two(threshold, per_channel):
-
+        Logger.error(f"Expects {THRESHOLD} parameter to be a power of two, but got {threshold}")  # pragma: no cover
 
     return quantize_tensor(tensor_data,
                            threshold,
@@ -84,7 +85,7 @@ def symmetric_quantizer(tensor_data: np.ndarray,
     """
     threshold = quantization_params.get(THRESHOLD)
     if threshold is None:
-
+        Logger.error(f"{THRESHOLD} parameter must be defined in 'quantization_params'")  # pragma: no cover
 
     return quantize_tensor(tensor_data,
                            threshold,
@@ -115,6 +116,6 @@ def uniform_quantizer(tensor_data: np.ndarray,
     range_min = quantization_params.get(RANGE_MIN)
     range_max = quantization_params.get(RANGE_MAX)
     if range_min is None or range_max is None:
-
+        Logger.error("'quantization range' parameters must be defined in 'quantization_params'")  # pragma: no cover
 
     return uniform_quantize_tensor(tensor_data, range_min, range_max, n_bits)
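For reference, a simplified sketch of the affine quantization that `uniform_quantize_tensor` performs once the guards above pass (the library helper may apply additional range adjustments that this sketch omits):

```python
import numpy as np

def uniform_quantize_sketch(x: np.ndarray, range_min: float,
                            range_max: float, n_bits: int) -> np.ndarray:
    # Map [range_min, range_max] onto 2**n_bits evenly spaced levels and back.
    scale = (range_max - range_min) / (2 ** n_bits - 1)
    q = np.clip(np.round((x - range_min) / scale), 0, 2 ** n_bits - 1)
    return q * scale + range_min
```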
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
CHANGED

@@ -108,7 +108,7 @@ def create_node_activation_qc(qc: QuantizationConfig,
 
     activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
     if activation_quantization_fn is None:
-        Logger.critical('Unknown quantization method for activations')
+        Logger.critical('Unknown quantization method for activations')  # pragma: no cover
 
     activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
 
@@ -142,14 +142,14 @@ def create_node_qc_candidate(qc: QuantizationConfig,
     weights_quantization_fn = get_weights_quantization_fn(op_cfg.weights_quantization_method)
 
     if weights_quantization_fn is None:
-        Logger.critical('Unknown quantization method for weights')
+        Logger.critical('Unknown quantization method for weights')  # pragma: no cover
 
     weights_quantization_params_fn = get_weights_quantization_params_fn(op_cfg.weights_quantization_method)
 
     # get attributes for activation quantization
     activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
     if activation_quantization_fn is None:
-        Logger.critical('Unknown quantization method for activations')
+        Logger.critical('Unknown quantization method for activations')  # pragma: no cover
 
     activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py
CHANGED

@@ -77,6 +77,13 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
         num_nodes_before_substitution = len(graph.nodes)
         num_edges_before_substitution = len(graph.edges)
 
+        # If the linear operator is part of a reused group (it is the "base" node, or a reused node),
+        # we should skip the substitution.
+        if source_node.reuse or source_node.reuse_group is not None:
+            for qc in source_node.candidates_quantization_cfg:
+                qc.weights_quantization_cfg.weights_second_moment_correction = False
+            return graph
+
         # We apply only on nodes with folded BatchNormalization.
         if source_node.prior_info.std_output is None or source_node.prior_info.mean_output is None:
             for qc in source_node.candidates_quantization_cfg:
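The early exit added above matters because a reused node shares a single set of weights across several graph instances, so the per-node second-moment correction is skipped: the flag is cleared on every candidate and the graph is returned untouched. The reuse test, isolated (attribute names copied from the hunk):

```python
def is_in_reused_group(source_node) -> bool:
    # True for the reused instance itself and for the "base" node that
    # owns the shared weights.
    return source_node.reuse or source_node.reuse_group is not None
```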
model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py
CHANGED

@@ -24,6 +24,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatch
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.target_platform import QuantizationMethod
 from model_compression_toolkit.core.common.constants import THRESHOLD, RANGE_MIN, RANGE_MAX
+from model_compression_toolkit.core.common.logger import Logger
 
 
 class BatchNormalizationRefusing(common.BaseSubstitution):
@@ -95,15 +96,22 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
 
         source_node = edge_nodes[0]
 
+        # We apply only on nodes with reconstructed BatchNormalization.
+        if not source_node.final_weights_quantization_cfg.weights_second_moment_correction:
+            return graph
+
         # If the linear operator is part of a reused group (it is the "base" node, or a reused node),
         # we should skip the substitution.
         if source_node.reuse or source_node.reuse_group is not None:
-
+            Logger.exception("If the linear operator is part of a reused group we should skip the the BN folding "
+                             "substitution and SMC feature")  # pragma: no cover
 
         bn_node = edge_nodes[1]
 
         if len(graph.get_next_nodes(source_node)) > 1 or len(graph.get_prev_nodes(bn_node)) > 1:
-
+            Logger.exception(
+                "If the linear operator has multiple outputs or the bn layer has multiple inputs we should "
+                "skip the the BN folding substitution and SMC feature")  # pragma: no cover
 
         kernel = source_node.get_weights_by_keys(self.kernel_str)
         bias = source_node.get_weights_by_keys(self.bias_str)
@@ -113,9 +121,6 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
         moving_variance = bn_node.get_weights_by_keys(self.moving_variance_str)
         eps = bn_node.framework_attr[self.epsilon_str]
 
-        if bias is None:
-            bias = 0.0
-
         weights_scale = gamma / np.sqrt(moving_variance + eps)
         bias = beta + (bias - moving_mean) * weights_scale
 
@@ -177,7 +182,7 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
             corr_dict[THRESHOLD] = corr_threshold
             conv_bn.final_weights_quantization_cfg.set_weights_quantization_param(corr_dict)
 
-        # In case of
+        # In case of UNIFORM weight quantization method, we update the range_min, range_max by weights_scale
         elif conv_bn.final_weights_quantization_cfg.weights_quantization_method == QuantizationMethod.UNIFORM:
             corr_dict = copy.deepcopy(conv_bn.final_weights_quantization_cfg.weights_quantization_params)
             original_range_min = conv_bn.final_weights_quantization_cfg.weights_quantization_params[RANGE_MIN]
@@ -189,5 +194,5 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
             conv_bn.final_weights_quantization_cfg.set_weights_quantization_param(corr_dict)
 
         else:
-
-
+            Logger.exception("Second moment statistics correction feature disabled for models with weights "
+                             "quantization method of Power of 2")  # pragma: no cover
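A numeric sanity check of the re-fusing arithmetic above, where the BatchNormalization parameters are folded into the preceding linear operator (the matching kernel rescaling happens in the same substitution):

```python
import numpy as np

gamma, beta = 1.5, 0.2
moving_mean, moving_variance, eps = 0.1, 4.0, 1e-3
bias = 0.05

weights_scale = gamma / np.sqrt(moving_variance + eps)      # ≈ 0.7499
folded_bias = beta + (bias - moving_mean) * weights_scale   # ≈ 0.1625
```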
model_compression_toolkit/core/common/substitutions/shift_negative_activation.py
CHANGED

@@ -16,6 +16,7 @@ import copy
 import numpy as np
 from typing import List, Tuple, Any, Callable
 
+from model_compression_toolkit.core.common.logger import Logger
 from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
 from model_compression_toolkit.core.common.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
@@ -25,7 +26,8 @@ from model_compression_toolkit.core.common.quantization.set_node_quantization_co
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
     import get_activations_qparams
-from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import
+from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
+    _mse_error_histogram
 from model_compression_toolkit.core.common.quantization.quantization_params_generation import z_score_filter
 
 """
@@ -73,12 +75,12 @@ def op2d_bias_correction(op2d_node: BaseNode,
 
         # special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters
         if output_channel_index == input_channel_index:
-            axis_not_output_channel.remove(3)
+            axis_not_output_channel.remove(3)  # 3 is the depth multiplier index
 
         bias_correction = shift_to_correct * np.sum(kernel, axis=tuple(axis_not_output_channel))
         op2d_node.set_weights_by_keys(bias_str, bias - bias_correction.flatten())
     else:
-        raise NotImplementedError
+        raise NotImplementedError  # pragma: no cover
 
 
 def insert_node_between_two_nodes(graph: Graph,
@@ -123,7 +125,7 @@ def insert_node_after_node(graph: Graph,
 
     last_nodes = graph.get_next_nodes(first_node)
     if len(last_nodes) != 1:
-
+        Logger.error('Can only insert if there is only one input')  # pragma: no cover
     last_node = last_nodes[0]
     insert_node_between_two_nodes(graph, node_to_insert, first_node, last_node)
 
@@ -145,7 +147,7 @@ def insert_node_before_node(graph: Graph,
     """
     first_nodes = graph.get_prev_nodes(last_node)
     if len(first_nodes) != 1:
-
+        Logger.error('Can only insert if there is only one input')  # pragma: no cover
     first_node = first_nodes[0]
     insert_node_between_two_nodes(graph, node_to_insert, first_node, last_node)
 
@@ -222,8 +224,8 @@ def shift_negative_function(graph: Graph,
     min_to_correct, max_value2compare = graph.get_out_stats_collector(non_linear_node).get_min_max_values()
 
     if not non_linear_node.is_all_activation_candidates_equal():
-
-
+        Logger.error("Shift negative correction is not supported for more than one activation quantization "
+                     "configuration candidate")  # pragma: no cover
 
     # all candidates have same activation config, so taking the first candidate for calculations
     non_linear_node_cfg_candidate = non_linear_node.candidates_quantization_cfg[0].activation_quantization_cfg
@@ -241,7 +243,8 @@ def shift_negative_function(graph: Graph,
     # taking the minimal quantized point that is still positive.
     num_q_points = 2 ** non_linear_node_cfg_candidate.activation_n_bits
     lsb = activation_threshold / num_q_points
-    q_points = np.linspace(0, activation_threshold - lsb, num_q_points).astype(
+    q_points = np.linspace(0, activation_threshold - lsb, num_q_points).astype(
+        'float32')  # Change to type float32 to support tensorflow dtypes
 
     delta = q_points + min_to_correct
     delta[delta < 0] = np.inf
@@ -253,14 +256,16 @@ def shift_negative_function(graph: Graph,
                                hist_bins, hist_count)
 
     min_mse, _th, _shift = np.inf, None, None
-    for _activation_threshold in [activation_threshold, 2*activation_threshold]:
+    for _activation_threshold in [activation_threshold, 2 * activation_threshold]:
         qparams = {THRESHOLD: _activation_threshold, SIGNED: False}
         _lsb = _activation_threshold / num_q_points
-        _q_points = np.linspace(0, _activation_threshold - _lsb, num_q_points).astype(
+        _q_points = np.linspace(0, _activation_threshold - _lsb, num_q_points).astype(
+            'float32')  # Change to type float32 to support tensorflow dtypes
         for _shift_value in _q_points:
             _hist_bins = hist_bins.astype(np.float32) + _shift_value
-            q_bins = non_linear_node_cfg_candidate.activation_quantization_fn(
-
+            q_bins = non_linear_node_cfg_candidate.activation_quantization_fn(
+                non_linear_node_cfg_candidate.activation_n_bits,
+                qparams)(_hist_bins)
             mse = _mse_error_histogram(q_bins, None, _hist_bins, hist_count)
             if mse < min_mse:
                 min_mse = mse
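The restored search loop reads as: for the original activation threshold and its double, sweep every quantization grid point as a candidate shift, fake-quantize the shifted histogram, and keep the (threshold, shift) pair with the lowest MSE. A self-contained sketch of that idea, with a stand-in quantizer and error metric in place of the node's `activation_quantization_fn` and `_mse_error_histogram`:

```python
import numpy as np

def fake_quant(x, threshold, n_bits):
    # Stand-in unsigned quantizer for illustration only.
    lsb = threshold / 2 ** n_bits
    return np.clip(np.round(x / lsb) * lsb, 0, threshold - lsb)

def search_threshold_and_shift(hist_bins, hist_count, activation_threshold, n_bits):
    num_q_points = 2 ** n_bits
    best_mse, best_th, best_shift = np.inf, None, None
    for th in (activation_threshold, 2 * activation_threshold):
        lsb = th / num_q_points
        for shift in np.linspace(0, th - lsb, num_q_points).astype('float32'):
            shifted = hist_bins.astype(np.float32) + shift
            q = fake_quant(shifted, th, n_bits)
            # Histogram-weighted quantization error (hist_bins has one more
            # entry than hist_count, hence the [:-1] slices).
            mse = np.sum(hist_count * (q[:-1] - shifted[:-1]) ** 2) / np.sum(hist_count)
            if mse < best_mse:
                best_mse, best_th, best_shift = mse, th, shift
    return best_th, best_shift
```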
model_compression_toolkit/core/common/substitutions/weights_activation_split.py
CHANGED

@@ -61,7 +61,7 @@ class BaseWeightsActivationSplit(BaseSubstitution):
             # Node is not composite, therefore, can't be split
             Logger.critical(f"The graph contains a node {node.name} with non composite candidates."
                             f"In order to run mixed-precision search with BOPS target KPI, "
-                            f"all model layers should be composite.")
+                            f"all model layers should be composite.")  # pragma: no cover
 
         weights_node = VirtualSplitWeightsNode(node)
         activation_node = VirtualSplitActivationNode(node, self.activation_layer_type, self.fw_attr)
model_compression_toolkit/core/common/target_platform/current_tp_model.py
CHANGED

@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 
+from model_compression_toolkit.core.common.logger import Logger
+
 def get_current_tp_model():
     """
 
@@ -38,7 +40,7 @@ class CurrentTPModel:
 
         """
         if self.tp_model is None:
-
+            Logger.error('Target platform model is not initialized.')  # pragma: no cover
         return self.tp_model
 
     def reset(self):
model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py
CHANGED

@@ -16,6 +16,8 @@
 import operator
 from typing import Any, Callable, Dict
 
+from model_compression_toolkit.core.common.logger import Logger
+
 
 class Filter:
     """
@@ -31,7 +33,7 @@ class Filter:
         Returns:
             Whether the passed configuration matches the filter or not.
         """
-        raise
+        raise NotImplemented('Filter did not implement match')  # pragma: no cover
 
 
 class AttributeFilter(Filter):
@@ -85,7 +87,7 @@ class AttributeFilter(Filter):
         """
 
         if not isinstance(other, AttributeFilter):
-
+            Logger.error("Not an attribute filter. Can not run an OR operation.")  # pragma: no cover
         return OrAttributeFilter(self, other)
 
     def __and__(self, other: Any):
@@ -99,7 +101,7 @@ class AttributeFilter(Filter):
            AndAttributeFilter that filters with AND between the current AttributeFilter and the passed AttributeFilter.
        """
        if not isinstance(other, AttributeFilter):
-
+            Logger.error("Not an attribute filter. Can not run an AND operation.")  # pragma: no cover
        return AndAttributeFilter(self, other)
 
     def match(self,
@@ -123,7 +125,7 @@ class AttributeFilter(Filter):
         Returns: A string representation for the filter.
 
         """
-        raise
+        raise NotImplemented("Filter must implement op_as_str ")  # pragma: no cover
 
     def __repr__(self):
         return f'{self.attr} {self.op_as_str()} {self.value}'
@@ -267,3 +269,14 @@ class Eq(AttributeFilter):
         super().__init__(attr=attr, value=value, op=operator.eq)
 
     def op_as_str(self): return "="
+
+
+class Contains(AttributeFilter):
+    """
+    Filter configurations such that it matches configurations that have an attribute with a value that contains the value that Contains holds.
+    """
+
+    def __init__(self, attr: str, value: Any):
+        super().__init__(attr=attr, value=value, op=operator.contains)
+
+    def op_as_str(self): return " in "
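Hypothetical usage of the new `Contains` filter together with the module's existing combinators (`&` builds an `AndAttributeFilter`, `|` an `OrAttributeFilter`). Since `operator.contains(a, b)` tests `b in a`, the filter matches configurations whose attribute value contains the filter's value:

```python
# Match layers whose 'activation' attribute contains 'relu'
# AND whose 'use_bias' attribute equals True.
f = Contains('activation', 'relu') & Eq('use_bias', True)
```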
model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py
CHANGED

@@ -131,9 +131,7 @@ class OperationsToLayers:
         for layer in ops2layers.layers:
             qco_by_opset_name = _current_tpc.get().tp_model.get_config_options_by_operators_set(ops2layers.name)
             if layer in existing_layers:
-
-
+                Logger.error(f'Found layer {layer.__name__} in more than one '
+                             f'OperatorsSet')  # pragma: no cover
             else:
                 existing_layers.update({layer: qco_by_opset_name})
-
-
model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py
CHANGED

@@ -131,7 +131,8 @@ class TargetPlatformCapabilities(ImmutableClass):
         if isinstance(tpc_component, OperationsSetToLayers):
             self.op_sets_to_layers += tpc_component
         else:
-
+            Logger.error(f'Trying to append an unfamiliar TargetPlatformCapabilitiesComponent of type: '
+                         f'{type(tpc_component)}')  # pragma: no cover
 
     def __enter__(self):
         """
@@ -175,7 +176,7 @@ class TargetPlatformCapabilities(ImmutableClass):
         QuantizationConfigOptions of the node.
         """
         if node is None:
-
+            Logger.error(f'Can not retrieve QC options for None node')  # pragma: no cover
         for fl, qco in self.filterlayer2qco.items():
             if fl.match(node):
                 return qco
@@ -205,7 +206,6 @@ class TargetPlatformCapabilities(ImmutableClass):
             layer2qco.update({l: qco})
         return layer2qco, filterlayer2qco
 
-
     def remove_fusing_names_from_not_used_list(self):
         """
         Remove OperatorSets names from the list of the unused sets (so a warning
@@ -235,5 +235,3 @@ class TargetPlatformCapabilities(ImmutableClass):
         """
         for op in self.__tp_model_opsets_not_used:
             Logger.warning(f'{op} is defined in TargetPlatformModel, but is not used in TargetPlatformCapabilities.')
-
-
model_compression_toolkit/core/keras/back2framework/instance_builder.py
CHANGED

@@ -20,38 +20,25 @@ from typing import List, Dict, Callable
 from networkx.algorithms.dag import topological_sort
 
 import tensorflow as tf
-from tensorflow.keras.layers import Layer
+from tensorflow.keras.layers import Layer, InputLayer
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.keras.constants import LAYER_NAME
 
 
-def identity_wrapper(node: BaseNode, layer: Layer):
-    """
-    A function which takes a computational graph node and a keras layer and return an identity wrapping which return the layer itself
-    Args:
-        node: A node of mct graph.
-        layer: A keras layer
-
-    Returns: keras layer
-
-    """
-    return layer
-
-
 class OperationHandler:
     """
     Class to handle conversions from graph nodes to Keras operators and retrieving them.
     """
 
-    def __init__(self, graph: Graph
+    def __init__(self, graph: Graph):
         # hold nodes after sorting them
         self.node_sort = list(topological_sort(graph))
 
         self.layer_to_node_dict = {}
 
         # hold dictionary from node to its equivalent Keras layer
-        self.node_to_fw_op_dict = instance_builder(self.node_sort
+        self.node_to_fw_op_dict = instance_builder(self.node_sort)
 
     def get_node_op_function(self, n: BaseNode) -> Layer:
         """
@@ -86,10 +73,15 @@ def node_builder(n: common.BaseNode) -> Layer:
     Returns:
         Keras layer that was built from the node.
     """
-
     framework_attr = copy.copy(n.framework_attr)
+    if n.layer_class is InputLayer:
+        # replace input node with identity, so can wrap it with QuantizationWrapper
+        _layer_class = Layer  # Identity
+        framework_attr = {}
+    else:
+        _layer_class = n.layer_class
     framework_attr[LAYER_NAME] = n.name  # Overwrite framework name to identical graph node name
-    node_instance =
+    node_instance = _layer_class.from_config(framework_attr)  # Build layer from node's configuration.
     with tf.name_scope(n.name):
         # Add layer name to default weight name to avoid name duplications
         node_instance.build(n.input_shape)
@@ -98,13 +90,12 @@ def node_builder(n: common.BaseNode) -> Layer:
     return node_instance
 
 
-def instance_builder(toposort: List[BaseNode]
+def instance_builder(toposort: List[BaseNode]) -> Dict[BaseNode, Layer]:
     """
     Build a dictionary of nodes to their corresponding Keras
     layers, given a list of nodes.
 
     Args:
-        wrapper: A function wrapper keras Layers.
         toposort: List of nodes sorted topological to build their layers.
 
     Returns:
@@ -114,7 +105,7 @@ def instance_builder(toposort: List[BaseNode], wrapper: Callable) -> Dict[BaseNo
     nodes_dict = dict()
     for n in toposort:
         if not n.reuse:  # Hold a single node in dictionary for all reused nodes from the same layer.
-            keras_node =
+            keras_node = node_builder(n)
         nodes_dict.update({n: keras_node})
 
     return nodes_dict
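The rebuilt `node_builder` now reconstructs each layer through Keras' standard `from_config` round trip, swapping `InputLayer` for a bare `Layer` so the input node can later be wrapped by a quantization wrapper. A standalone illustration of that round trip:

```python
from tensorflow.keras.layers import Dense

cfg = Dense(4, name='fc').get_config()
cfg['name'] = 'graph_node_name'    # node_builder overwrites LAYER_NAME the same way
layer = Dense.from_config(cfg)     # rebuild the layer from its configuration
layer.build((None, 8))             # build with the node's input shape
print(layer.kernel.shape)          # (8, 4)
```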