mct-nightly 1.7.1.31122022.post351__py3-none-any.whl → 1.8.0.1042023.post423__py3-none-any.whl
This diff shows the content of publicly available package versions as released to their respective public registries. It is provided for informational purposes only and reflects the changes between the two package versions.
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/METADATA +16 -16
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/RECORD +193 -150
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +13 -14
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/statistics_collector.py +2 -2
- model_compression_toolkit/core/common/constants.py +9 -4
- model_compression_toolkit/core/common/framework_implementation.py +32 -30
- model_compression_toolkit/core/common/graph/base_graph.py +8 -6
- model_compression_toolkit/core/common/logger.py +10 -2
- model_compression_toolkit/core/common/matchers/base_matcher.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_validation.py +2 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +7 -4
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +14 -17
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +9 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +5 -4
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +7 -0
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +13 -8
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +17 -12
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/target_platform/current_tp_model.py +3 -1
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/attribute_filter.py +17 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/core/common/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -5
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +12 -21
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +40 -14
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +51 -27
- model_compression_toolkit/core/keras/constants.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +2 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/keras/quantization_facade.py +3 -3
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +15 -9
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +2 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +14 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +88 -46
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +27 -12
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/wrapper_quantize_config.py +2 -3
- model_compression_toolkit/core/pytorch/constants.py +5 -0
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +9 -14
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantization_facade.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +7 -5
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/tpc_models/get_target_platform_capabilities.py +6 -2
- model_compression_toolkit/{exporter/model_wrapper/keras/quantize_configs → core/tpc_models/imx500_tpc}/__init__.py +1 -1
- model_compression_toolkit/core/tpc_models/imx500_tpc/latest/__init__.py +24 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/target_platform_capabilities.py +45 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/__init__.py +16 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tp_model.py +156 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_keras.py +101 -0
- model_compression_toolkit/core/tpc_models/imx500_tpc/v1/tpc_pytorch.py +95 -0
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -12
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +12 -39
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +39 -27
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +10 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +6 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +48 -35
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +3 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +180 -0
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +44 -26
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +34 -137
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +143 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +46 -0
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +56 -22
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +29 -112
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +83 -79
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +47 -0
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +44 -0
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -127
- model_compression_toolkit/gptq/common/gptq_constants.py +20 -6
- model_compression_toolkit/gptq/common/gptq_graph.py +22 -0
- model_compression_toolkit/gptq/common/gptq_training.py +32 -26
- model_compression_toolkit/gptq/keras/gptq_loss.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +73 -39
- model_compression_toolkit/gptq/keras/graph_info.py +24 -43
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -18
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +2 -1
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +13 -14
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +78 -0
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/{optimizers → quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +256 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +68 -168
- model_compression_toolkit/gptq/pytorch/gptq_training.py +78 -39
- model_compression_toolkit/gptq/pytorch/graph_info.py +81 -0
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +12 -18
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +5 -1
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +92 -0
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -119
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +75 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/{exporter/model_wrapper/keras/quantizers → gptq/pytorch/quantizer/soft_rounding}/__init__.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +244 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +182 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -3
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -6
- model_compression_toolkit/qat/common/qat_config.py +68 -0
- model_compression_toolkit/qat/keras/quantization_facade.py +55 -48
- model_compression_toolkit/qat/keras/quantizer/__init__.py +3 -0
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +48 -0
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +77 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +283 -0
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +158 -46
- model_compression_toolkit/qat/pytorch/quantization_facade.py +190 -11
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +49 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +74 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +136 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +204 -0
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +190 -0
- model_compression_toolkit/quantizers_infrastructure/__init__.py +23 -0
- model_compression_toolkit/{gptq/keras/quantizer/configs → quantizers_infrastructure/inferable_infrastructure}/__init__.py +1 -1
- model_compression_toolkit/{gptq/keras/quantizer/gumbel_rounding → quantizers_infrastructure/inferable_infrastructure/common}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +41 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +31 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +53 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +49 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/{qunatizers_infrastructure → quantizers_infrastructure/inferable_infrastructure}/keras/load_model.py +26 -8
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +345 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +85 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +27 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +148 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +86 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +111 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +56 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +25 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +79 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +179 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +67 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +87 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +163 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +269 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +152 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +35 -0
- model_compression_toolkit/{exporter/model_wrapper/pytorch/quantizers → quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers}/__init__.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +62 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +83 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +100 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +95 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +70 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +57 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +26 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +77 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +106 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +66 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +104 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +109 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +200 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +116 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +65 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/quant_utils.py +36 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +97 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +90 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +80 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/quantizer_utils.py +48 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +14 -0
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +66 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantize_config_to_node.py +0 -66
- model_compression_toolkit/exporter/model_wrapper/keras/builder/quantizer_to_node.py +0 -134
- model_compression_toolkit/exporter/model_wrapper/keras/extended_quantize_wrapper.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/activation_quantize_config.py +0 -81
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_activation_quantize_config.py +0 -128
- model_compression_toolkit/exporter/model_wrapper/keras/quantize_configs/weights_quantize_config.py +0 -107
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/fq_quantizer.py +0 -99
- model_compression_toolkit/exporter/model_wrapper/keras/quantizers/weights_uniform_quantizer.py +0 -105
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantize_config.py +0 -61
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/fq_quantizer.py +0 -59
- model_compression_toolkit/exporter/model_wrapper/pytorch/quantizers/uniform_weights_quantizer.py +0 -67
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/activation_quantize_config.py +0 -52
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/no_quantization_quantize_config.py +0 -46
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_activation_quantize_config.py +0 -54
- model_compression_toolkit/exporter/model_wrapper/pytorch/wrappers_quantize_configs/weights_quantize_config.py +0 -52
- model_compression_toolkit/gptq/keras/gptq_model_builder.py +0 -104
- model_compression_toolkit/gptq/keras/optimizers/sam_optimizer.py +0 -119
- model_compression_toolkit/gptq/keras/quantizer/config_factory.py +0 -62
- model_compression_toolkit/gptq/keras/quantizer/configs/base_quantizer_gptq_config.py +0 -65
- model_compression_toolkit/gptq/keras/quantizer/configs/weight_quantizer_gptq_config.py +0 -269
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/base_gumbel_rounding.py +0 -263
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/gumbel_softmax.py +0 -75
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/symmetric_gumbel.py +0 -266
- model_compression_toolkit/gptq/keras/quantizer/gumbel_rounding/uniform_gumbel.py +0 -247
- model_compression_toolkit/gptq/keras/quantizer/kernel_functions.py +0 -50
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/uniform_ste.py +0 -49
- model_compression_toolkit/gptq/pytorch/gptq_graph_info.py +0 -94
- model_compression_toolkit/gptq/pytorch/gptq_model_builder.py +0 -113
- model_compression_toolkit/gptq/pytorch/quantizer/gptq_quantizer.py +0 -71
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/__init__.py +0 -14
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/base_gumbel_weights_quantizer.py +0 -157
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/sym_gumbel_weights_quantizer.py +0 -150
- model_compression_toolkit/gptq/pytorch/quantizer/gumbel_rounding/uniform_gumbel_weights_quantizer.py +0 -143
- model_compression_toolkit/gptq/pytorch/quantizer/quantizer_wrapper.py +0 -103
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/ste_weights_quantizer.py +0 -103
- model_compression_toolkit/qat/keras/qat_model_builder.py +0 -105
- model_compression_toolkit/qat/keras/quantizer/quantization_dispatcher_builder.py +0 -56
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetirc_ste.py +0 -145
- model_compression_toolkit/qunatizers_infrastructure/__init__.py +0 -8
- model_compression_toolkit/qunatizers_infrastructure/common/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/common/base_quantizer.py +0 -123
- model_compression_toolkit/qunatizers_infrastructure/common/node_quantization_dispatcher.py +0 -65
- model_compression_toolkit/qunatizers_infrastructure/keras/__init__.py +0 -14
- model_compression_toolkit/qunatizers_infrastructure/keras/base_keras_quantizer.py +0 -75
- model_compression_toolkit/qunatizers_infrastructure/keras/config_serialization.py +0 -83
- model_compression_toolkit/qunatizers_infrastructure/keras/keras_node_quantization_dispatcher.py +0 -74
- model_compression_toolkit/qunatizers_infrastructure/keras/quantize_wrapper.py +0 -194
- model_compression_toolkit/qunatizers_infrastructure/pytorch/__init__.py +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.7.1.31122022.post351.dist-info → mct_nightly-1.8.0.1042023.post423.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{exporter/model_wrapper/pytorch/wrappers_quantize_configs → qat/pytorch/quantizer/ste_rounding}/__init__.py +0 -0

model_compression_toolkit/gptq/common/gptq_config.py

```diff
@@ -16,61 +16,46 @@ from enum import Enum
 from typing import Callable, Any, Dict
 from model_compression_toolkit.core.common.defaultdict import DefaultDict
 from model_compression_toolkit.core import common
-
-MAX_LSBS_CHANGE_MAP = {8: 2,
-                       4: 1,
-                       2: 1}
-
-N_CYCLES = 4
-MIM_TEMP = 0.5
-MAX_TEMP = 1.0
-GAMMA_TEMPERATURE = 0.1
-GUMBEL_SCALE = 0.5
+from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR, MAX_LSB_STR, REG_DEFAULT
 
 
 class RoundingType(Enum):
     """
     An enum for choosing the GPTQ rounding methods
     0. STRAIGHT-THROUGH ESTIMATOR
-    1.
+    1. SoftQuantizer
     """
     STE = 0
-
+    SoftQuantizer = 1
 
 
-class
+class GPTQHessianWeightsConfig:
     """
-    Configuration to use for
+    Configuration to use for computing the Hessian-based weights for GPTQ loss metric.
     """
 
     def __init__(self,
-
-
-
-
-
-                 gumbel_scale: float = GUMBEL_SCALE,
-                 gumbel_scale_per_bitwidth: Dict[int, float] = None):
-        """
-        Initialize a GumbelConfig.
-
+                 hessians_num_samples: int = 16,
+                 norm_weights: bool = True,
+                 log_norm: bool = True,
+                 scale_log_norm: bool = False,
+                 hessians_n_iter: int = 50):
 
+        """
+        Initialize a GPTQHessianWeightsConfig.
         Args:
-
-
-
-
-
-            gumbel_scale (float): A normalization factor for the gumbel tensor values.
-            gumbel_scale_per_bitwidth (dict): An optional mapping between a bit-width and a gumbel scale value for Gumbel Rounding,
+            hessians_num_samples (int): Number of samples to use for computing the Hessian-based weights.
+            norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
+            log_norm (bool): Whether to use log normalization to the GPTQ Hessian-based weights.
+            scale_log_norm (bool): Whether to scale the final vector of the Hessian weights.
+            hessians_n_iter (int): Number of random iterations to run Hessian approximation for GPTQ weights.
         """
-
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.gumbel_scale_per_bitwidth = gumbel_scale_per_bitwidth
+
+        self.hessians_num_samples = hessians_num_samples
+        self.norm_weights = norm_weights
+        self.log_norm = log_norm
+        self.scale_log_norm = scale_log_norm
+        self.hessians_n_iter = hessians_n_iter
 
 
 class GradientPTQConfig:
@@ -78,27 +63,19 @@ class GradientPTQConfig:
     Configuration to use for quantization with GradientPTQ (experimental).
     """
 
-    def __init__(self,
-                 n_iter: int,
+    def __init__(self, n_iter: int,
                  optimizer: Any,
                  optimizer_rest: Any = None,
                  loss: Callable = None,
                  log_function: Callable = None,
                  train_bias: bool = True,
-
-
-                 rounding_type: RoundingType = RoundingType.GumbelRounding,
-                 rho: float = 0.01,
-                 lsb_change_per_bit_width: dict = DefaultDict(MAX_LSBS_CHANGE_MAP, lambda: 1),
-                 eps: float = 1e-6,
-                 use_jac_based_weights: bool = True,
-                 num_samples_for_loss: int = 16,
-                 norm_weights: bool = False,
-                 quantizer_config: GumbelConfig = GumbelConfig(),
+                 rounding_type: RoundingType = RoundingType.SoftQuantizer,
+                 use_hessian_based_weights: bool = True,
                  optimizer_quantization_parameter: Any = None,
                  optimizer_bias: Any = None,
-
-
+                 regularization_factor: float = REG_DEFAULT,
+                 hessian_weights_config: GPTQHessianWeightsConfig = GPTQHessianWeightsConfig(),
+                 gptq_quantizer_params_override: Dict[str, Any] = None):
         """
         Initialize a GradientPTQConfig.
 
@@ -111,20 +88,13 @@ class GradientPTQConfig:
             accordingly. see example in multiple_tensors_mse_loss
             log_function (Callable): Function to log information about the GPTQ process.
             train_bias (bool): Whether to update the bias during the training or not.
-
-
-            rounding_type (RoundingType): An enum that defines the rounding type (STE or GumbelRoudning).
-            rho (rho): A floating point number that defines the sam optimization lookahead.
-            lsb_change_per_bit_width (dict): Whether to update the bias during the training or not.
-            eps (float): A floating point value for numeric stability.
-            use_jac_based_weights (bool): Whether to use jacobian-based weights for weighted average loss.
-            num_samples_for_loss (int): Number of samples to use for computing the jacobian-based weights.
-            norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
-            quantizer_config (Any): A class the contins the quantizer specific config.
+            rounding_type (RoundingType): An enum that defines the rounding type.
+            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
             optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
-            optimizer_bias (Any): Optimizer to override the rest
-
-
+            optimizer_bias (Any): Optimizer to override the rest optimizer for bias.
+            regularization_factor (float): A floating point number that defines the regularization factor.
+            hessian_weights_config (GPTQHessianWeightsConfig): A configuration that include all necessary arguments to run a computation of Hessian weights for the GPTQ loss.
+            gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
 
         """
         self.n_iter = n_iter
@@ -133,58 +103,35 @@ class GradientPTQConfig:
         self.loss = loss
         self.log_function = log_function
         self.train_bias = train_bias
-
+
         self.rounding_type = rounding_type
-        self.
-        self.rho = rho
-        self.lsb_change_per_bit_width = lsb_change_per_bit_width
-        self.eps = eps
-        self.use_jac_based_weights = use_jac_based_weights
-        self.num_samples_for_loss = num_samples_for_loss
-        self.norm_weights = norm_weights
-        if not isinstance(quantizer_config, GumbelConfig) and self.is_gumbel:
-            common.Logger.error("Please use GumbelConfig as quantizer config when using Gumbel Rounding")
-        self.quantizer_config = quantizer_config
+        self.use_hessian_based_weights = use_hessian_based_weights
         self.optimizer_quantization_parameter = optimizer_quantization_parameter
         self.optimizer_bias = optimizer_bias
-        self.
-        self.
+        self.regularization_factor = regularization_factor
+        self.hessian_weights_config = hessian_weights_config
 
-
-
-        """
-        This function state if Gumbel Rounding is in use.
-        Returns: boolean
-
-        """
-        return self.rounding_type == RoundingType.GumbelRounding
+        self.gptq_quantizer_params_override = {} if gptq_quantizer_params_override is None \
+            else gptq_quantizer_params_override
 
 
 class GradientPTQConfigV2(GradientPTQConfig):
     """
     Configuration to use for quantization with GradientPTQV2 (experimental).
     """
-    def __init__(self,
-                 n_epochs: int,
+    def __init__(self, n_epochs: int,
                  optimizer: Any,
                  optimizer_rest: Any = None,
                  loss: Callable = None,
                  log_function: Callable = None,
                  train_bias: bool = True,
-
-
-                 rounding_type: RoundingType = RoundingType.GumbelRounding,
-                 rho: float = 0.01,
-                 lsb_change_per_bit_width: dict = DefaultDict(MAX_LSBS_CHANGE_MAP, lambda: 1),
-                 eps: float = 1e-6,
-                 use_jac_based_weights: bool = True,
-                 num_samples_for_loss: int = 16,
-                 norm_weights: bool = False,
-                 quantizer_config: GumbelConfig = GumbelConfig(),
+                 rounding_type: RoundingType = RoundingType.SoftQuantizer,
+                 use_hessian_based_weights: bool = True,
                  optimizer_quantization_parameter: Any = None,
                  optimizer_bias: Any = None,
-
-
+                 regularization_factor: float = REG_DEFAULT,
+                 hessian_weights_config: GPTQHessianWeightsConfig = GPTQHessianWeightsConfig(),
+                 gptq_quantizer_params_override: Dict[str, Any] = None):
         """
         Initialize a GradientPTQConfigV2.
 
@@ -197,20 +144,13 @@ class GradientPTQConfigV2(GradientPTQConfig):
             accordingly. see example in multiple_tensors_mse_loss
             log_function (Callable): Function to log information about the GPTQ process.
             train_bias (bool): Whether to update the bias during the training or not.
-
-
-            rounding_type (RoundingType): An enum that defines the rounding type (STE or GumbelRoudning).
-            rho (rho): A floating point number that defines the sam optimization lookahead.
-            lsb_change_per_bit_width (dict): Whether to update the bias during the training or not.
-            eps (float): A floating point value for numeric stability.
-            use_jac_based_weights (bool): Whether to use jacobian-based weights for weighted average loss.
-            num_samples_for_loss (int): Number of samples to use for computing the jacobian-based weights.
-            norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
-            quantizer_config (Any): A class the contins the quantizer specific config.
+            rounding_type (RoundingType): An enum that defines the rounding type.
+            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
             optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
             optimizer_bias (Any): Optimizer to override the rest optimizerfor bias.
-
-
+            regularization_factor (float): A floating point number that defines the regularization factor.
+            hessian_weights_config (GPTQHessianWeightsConfig): A configuration that include all necessary arguments to run a computation of Hessian weights for the GPTQ loss.
+            gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
 
         """
 
@@ -220,20 +160,13 @@ class GradientPTQConfigV2(GradientPTQConfig):
                          loss=loss,
                          log_function=log_function,
                          train_bias=train_bias,
-                         quantization_parameters_learning=quantization_parameters_learning,
-                         sam_optimization=sam_optimization,
                          rounding_type=rounding_type,
-
-                         lsb_change_per_bit_width=lsb_change_per_bit_width,
-                         eps=eps,
-                         use_jac_based_weights=use_jac_based_weights,
-                         num_samples_for_loss=num_samples_for_loss,
-                         norm_weights=norm_weights,
-                         quantizer_config=quantizer_config,
+                         use_hessian_based_weights=use_hessian_based_weights,
                          optimizer_quantization_parameter=optimizer_quantization_parameter,
                          optimizer_bias=optimizer_bias,
-
-
+                         regularization_factor=regularization_factor,
+                         hessian_weights_config=hessian_weights_config,
+                         gptq_quantizer_params_override=gptq_quantizer_params_override)
         self.n_epochs = n_epochs
 
     @classmethod
@@ -248,8 +181,5 @@ class GradientPTQConfigV2(GradientPTQConfig):
         """
         n_epochs = int(round(config_v1.n_iter) / n_ptq_iter)
         v1_params = config_v1.__dict__
-        v1_params.
+        v1_params = {k: v for k, v in v1_params.items() if k != 'n_iter'}
         return cls(n_epochs, **v1_params)
-
-
-
```
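Taken together, these hunks drop the Gumbel-rounding configuration (GumbelConfig, the SAM lookahead `rho`, the LSB-change map) in favor of soft rounding plus a dedicated Hessian-weighting config object. A minimal sketch of how the new configuration objects compose, based only on the signatures above (the optimizer choice and hyperparameter values are illustrative assumptions, not recommended settings):

```python
from tensorflow.keras.optimizers import Adam

from model_compression_toolkit.gptq.common.gptq_config import (
    GPTQHessianWeightsConfig, GradientPTQConfigV2, RoundingType)

# The Hessian-weighting knobs that previously lived directly on GradientPTQConfig
# (use_jac_based_weights, num_samples_for_loss, norm_weights) now sit in one object.
hessian_cfg = GPTQHessianWeightsConfig(hessians_num_samples=16,
                                       log_norm=True,
                                       scale_log_norm=False)

# SoftQuantizer replaces GumbelRounding as the default rounding type.
gptq_cfg = GradientPTQConfigV2(n_epochs=5,
                               optimizer=Adam(learning_rate=1e-4),
                               optimizer_rest=Adam(learning_rate=1e-4),
                               rounding_type=RoundingType.SoftQuantizer,
                               use_hessian_based_weights=True,
                               hessian_weights_config=hessian_cfg)
```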
model_compression_toolkit/gptq/common/gptq_constants.py

```diff
@@ -1,11 +1,25 @@
+# Parameters names
 AUXVAR = 'auxvar_tensor'
 ITERVAR = 'iteration_variable'
-THRESHOLD_TENSOR = "ptq_threshold_tensor"
 SCALE_TENSOR = "scale_ptq_tensor"
-
-AUXSHIFT = '_shift'
-TEMP = '_temp'
+AUXSHIFT = 'shift'
 WEIGHTS_QUANTIZATION_PARAMS = 'weights_quantization_params'
-PTQ_MIN_RANGE = "
-PTQ_MAX_RANGE = "
+PTQ_MIN_RANGE = "min_range"
+PTQ_MAX_RANGE = "max_range"
+PTQ_THRESHOLD = "ptq_threshold"
+SCALE_PTQ = "scale"
 
+# Default quantizer values
+N_CYCLES = 4
+MIM_TEMP = 0.5
+MAX_TEMP = 1.0
+REG_DEFAULT = 0.01
+MAX_LSB_CHANGE = 1
+
+# Soft rounding arguments values
+SOFT_ROUNDING_GAMMA = -0.1
+SOFT_ROUNDING_ZETA = 1.1
+
+# GPTQ config constant
+QUANT_PARAM_LEARNING_STR = 'quantization_parameter_learning'
+MAX_LSB_STR = 'max_lsbs_change_map'
```
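The SOFT_ROUNDING_GAMMA/SOFT_ROUNDING_ZETA pair are the stretch bounds of the rectified sigmoid used in AdaRound-style soft rounding, which the new soft_rounding quantizers listed above implement. A minimal NumPy sketch of how such stretch constants are typically applied (the helper below is illustrative only, not MCT's actual quantizer code):

```python
import numpy as np

# Stretch bounds from gptq_constants.py.
SOFT_ROUNDING_GAMMA = -0.1
SOFT_ROUNDING_ZETA = 1.1

def rectified_sigmoid(aux: np.ndarray) -> np.ndarray:
    """Map a trainable auxiliary variable to a soft rounding value in [0, 1].

    Stretching sigmoid(aux) to (gamma, zeta) and clipping back to [0, 1]
    lets the optimizer drive values to exactly 0 or 1, so the learned
    soft rounding collapses to hard rounding by the end of training.
    """
    sig = 1.0 / (1.0 + np.exp(-aux))
    stretched = sig * (SOFT_ROUNDING_ZETA - SOFT_ROUNDING_GAMMA) + SOFT_ROUNDING_GAMMA
    return np.clip(stretched, 0.0, 1.0)
```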
model_compression_toolkit/gptq/common/gptq_graph.py

```diff
@@ -13,6 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 from typing import Tuple, List
+
+from model_compression_toolkit import FrameworkInfo
+from model_compression_toolkit.core.common import Logger
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 
@@ -42,3 +45,22 @@ def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str], L
             compare_points_std.append(n.prior_info.std_output)
             compare_points_mean.append(n.prior_info.mean_output)
     return compare_points, compare_points_name, compare_points_mean, compare_points_std
+
+
+def get_kernel_attribute_name_for_gptq(layer_type: type, fw_info: FrameworkInfo) -> str:
+    """
+    Returns a layer's kernel attribute name for GPTQ training purposes.
+
+    Args:
+        layer_type: A type of model's layer.
+        fw_info: A FrameworkInfo object.
+
+    Returns: The name of the kernel attribute.
+
+    """
+    kernel_attribute = fw_info.get_kernel_op_attributes(layer_type)
+    if len(kernel_attribute) != 1:
+        Logger.error(  # pragma: no cover
+            f"In GPTQ training only the kernel weights attribute should be trained, but number of kernel "
+            f"attributes is {len(kernel_attribute)}.")
+    return kernel_attribute[0]
```
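A hedged usage sketch for the new helper (the DEFAULT_KERAS_INFO import path and the 'kernel' result are assumptions based on MCT's Keras framework-info at this version):

```python
from tensorflow.keras.layers import Conv2D

from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq

# A Conv2D layer has a single kernel attribute; layer types with zero or
# several kernel attributes trigger the Logger.error branch above.
attr_name = get_kernel_attribute_name_for_gptq(Conv2D, DEFAULT_KERAS_INFO)
print(attr_name)  # expected: 'kernel'
```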
model_compression_toolkit/gptq/common/gptq_training.py

```diff
@@ -20,6 +20,7 @@ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.core.common import Graph, Logger, BaseNode
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
 from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
 from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
 
@@ -69,26 +70,23 @@ class GPTQTrainer(ABC):
     def get_optimizer_with_param(self,
                                  flattened_trainable_weights: List[Any],
                                  flattened_bias_weights: List[Any],
-                                 trainable_quantization_parameters: List[Any]
-                                 temperature_weights: List[Any]) -> List[Any]:
+                                 trainable_quantization_parameters: List[Any]) -> List[Any]:
         """
         Create Optimizers with their trainable parameters
         Args:
             flattened_trainable_weights: list of trainable weights parameters (flattened)
             flattened_bias_weights: list of trainable bias parameters (flattened)
             trainable_quantization_parameters: list of trainable quantization parameters
-            temperature_weights: list of temperature weights variables
         Returns:
             List of Optimizer objects with parameters
         """
 
         w2train = [*flattened_trainable_weights]
-
-
-        w2train.extend(temperature_weights)
+
+        quant_params_learning = self.gptq_config.gptq_quantizer_params_override.get(QUANT_PARAM_LEARNING_STR, False)
 
         optimizer_with_param = [(self.gptq_config.optimizer, w2train)]
-        if self.gptq_config.train_bias or
+        if self.gptq_config.train_bias or quant_params_learning:
             w2train_res = []
             if self.gptq_config.train_bias:
                 if self.gptq_config.optimizer_bias is not None:
@@ -96,35 +94,42 @@ class GPTQTrainer(ABC):
                 else:
                     w2train_res.extend(flattened_bias_weights)
                     if self.gptq_config.optimizer_rest is None:
-                        Logger.error(
+                        Logger.error(  # pragma: no cover
                             "To enable bias micro training an additional optimizer is required, please define the optimizer_rest")
-            if
+            if quant_params_learning:
                 if self.gptq_config.optimizer_quantization_parameter is not None:  # Ability to override optimizer
                     optimizer_with_param.append((self.gptq_config.optimizer_quantization_parameter,
                                                  trainable_quantization_parameters))
                 else:
                     w2train_res.extend(trainable_quantization_parameters)
                     if self.gptq_config.optimizer_rest is None:
-                        Logger.error(
-                            "To enable
-
+                        Logger.error(  # pragma: no cover
+                            "To enable quantization parameters micro training an additional optimizer is required, please define the optimizer_rest")
+            if len(w2train_res) > 0:
+                # Either bias or quantization parameters are trainable but did not provide a specific optimizer,
+                # so we should use optimizer_rest to train them
+                if self.gptq_config.optimizer_rest is None:
+                    Logger.error(  # pragma: no cover
+                        "To enable bias or quantization parameters micro training an additional optimizer is required, please define the optimizer_rest")
+                optimizer_with_param.append((self.gptq_config.optimizer_rest, w2train_res))
 
         return optimizer_with_param
 
 
-    def
-
+    def compute_hessian_based_weights(self,
+                                      representative_data_gen: Callable) -> np.ndarray:
         """
-        Computes the
+        Computes the Hessian-based weights using the framework's model_grad method per batch of images.
 
         Args:
-            representative_data_gen: Dataset used for inference to compute the
+            representative_data_gen: Dataset used for inference to compute the Hessian-based weights.
 
         Returns: A vector of weights, one for each compare point,
         to be used for the loss metric weighted average computation when running GPTQ training.
         """
-        if self.gptq_config.
-            images = self._generate_images_batch(representative_data_gen,
+        if self.gptq_config.use_hessian_based_weights:
+            images = self._generate_images_batch(representative_data_gen,
+                                                 self.gptq_config.hessian_weights_config.hessians_num_samples)
 
             model_output_replacement = self._get_model_output_replacement()
 
@@ -142,17 +147,18 @@ class GPTQTrainer(ABC):
                     output_list=model_output_replacement,
                     all_outputs_indices=[],
                     alpha=0,
-                    norm_weights=self.gptq_config.norm_weights,
-                    n_iter=self.gptq_config.
+                    norm_weights=self.gptq_config.hessian_weights_config.norm_weights,
+                    n_iter=self.gptq_config.hessian_weights_config.hessians_n_iter)
                 points_apprx_jacobians_weights.append(image_ip_gradients)
-            if self.gptq_config.log_norm:
+            if self.gptq_config.hessian_weights_config.log_norm:
                 mean_jacobian_weights = np.mean(points_apprx_jacobians_weights, axis=0)
                 mean_jacobian_weights = np.where(mean_jacobian_weights != 0, mean_jacobian_weights,
                                                  np.partition(mean_jacobian_weights, 1)[1])
                 log_weights = np.log10(mean_jacobian_weights)
 
-
-
+                if self.gptq_config.hessian_weights_config.scale_log_norm:
+                    return (log_weights - np.min(log_weights)) / (np.max(log_weights) - np.min(log_weights))
+
                 return log_weights - np.min(log_weights)
             else:
                 return np.mean(points_apprx_jacobians_weights, axis=0)
@@ -204,7 +210,7 @@ class GPTQTrainer(ABC):
             Quantized graph for GPTQ fine-tuning, GPTQ graph user info
         """
         raise NotImplemented(f'{self.__class__.__name__} have to implement the '
-                             f'framework\'s GPTQ model builder method.')
+                             f'framework\'s GPTQ model builder method.')  # pragma: no cover
 
     @abstractmethod
     def train(self, representative_data_gen: Callable):
@@ -214,7 +220,7 @@ class GPTQTrainer(ABC):
             representative_data_gen: Dataset to use for inputs of the models.
         """
         raise NotImplemented(f'{self.__class__.__name__} have to implement the '
-                             f'framework\'s train method.')
+                             f'framework\'s train method.')  # pragma: no cover
 
     @abstractmethod
     def update_graph(self) -> Graph:
@@ -225,7 +231,7 @@ class GPTQTrainer(ABC):
             Updated graph after GPTQ.
         """
         raise NotImplemented(f'{self.__class__.__name__} have to implement the '
-                             f'framework\'s update_graph method.')
+                             f'framework\'s update_graph method.')  # pragma: no cover
 
     def _get_model_output_replacement(self) -> List[BaseNode]:
         """
```
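The optimizer hunk shows that quantization-parameter learning is no longer a dedicated constructor flag: the trainer now reads it from gptq_quantizer_params_override under QUANT_PARAM_LEARNING_STR. A sketch of passing overrides through the new mechanism (the LSB map reuses the removed MAX_LSBS_CHANGE_MAP values; whether a given quantizer consumes MAX_LSB_STR is an assumption):

```python
from model_compression_toolkit.gptq.common.gptq_constants import (
    MAX_LSB_STR, QUANT_PARAM_LEARNING_STR)

# Keys read back via self.gptq_config.gptq_quantizer_params_override.get(...).
quantizer_overrides = {
    QUANT_PARAM_LEARNING_STR: True,     # train thresholds/scales as well
    MAX_LSB_STR: {8: 2, 4: 1, 2: 1},    # per-bit-width LSB change map
}

# Passed at config construction time; get_optimizer_with_param then routes the
# quantization parameters to optimizer_rest, or to
# optimizer_quantization_parameter when one is explicitly provided:
# gptq_cfg = GradientPTQConfigV2(..., gptq_quantizer_params_override=quantizer_overrides)
```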
model_compression_toolkit/gptq/keras/gptq_loss.py

```diff
@@ -86,6 +86,7 @@ def mse_loss_per_tensor(y: tf.Tensor,
     _loss = tf.reduce_mean(tf.pow(tf.abs(y - x), p))
     return _loss / tf.reduce_mean(tf.pow(tf.abs(x), p)) if normalized else _loss
 
+
 def activation_mse(flp_act_list,
                    fxp_act_list,
                    p_vector=None,
@@ -116,7 +117,6 @@ def activation_mse(flp_act_list,
     return tf.reduce_mean(tf.stack(loss_values_list)), tf.reduce_mean(tf.stack(bias_loss_list))
 
 
-
 class GPTQMultipleTensorsLoss:
     def __init__(self, norm_loss: bool = False):
         self.alpha = None
```