mct-nightly 1.8.0.22042023.post414__py3-none-any.whl → 1.8.0.22052023.post408__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/METADATA +1 -1
- {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/RECORD +237 -230
- model_compression_toolkit/__init__.py +8 -31
- model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
- model_compression_toolkit/core/__init__.py +14 -0
- model_compression_toolkit/core/analyzer.py +3 -2
- model_compression_toolkit/core/common/__init__.py +0 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +1 -8
- model_compression_toolkit/core/common/fusion/layer_fusing.py +2 -2
- model_compression_toolkit/core/common/graph/base_graph.py +1 -1
- model_compression_toolkit/core/common/graph/base_node.py +57 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/memory_computation.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/model_validation.py +1 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -1
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +2 -1
- model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +3 -2
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +4 -3
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +2 -2
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +4 -4
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +66 -21
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +2 -2
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/keras/constants.py +0 -7
- model_compression_toolkit/core/keras/default_framework_info.py +2 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -10
- model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
- model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +7 -7
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
- model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +1 -1
- model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +2 -2
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +0 -6
- model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +6 -6
- model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -9
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +3 -2
- model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/runner.py +6 -6
- model_compression_toolkit/exporter/__init__.py +6 -3
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/export_serialization_format.py +20 -0
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/{tflite → keras}/fakely_quant_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/{tflite → keras}/int8_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +60 -22
- model_compression_toolkit/exporter/model_exporter/pytorch/export_serialization_format.py +20 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +54 -31
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -3
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +4 -2
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +3 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
- model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +5 -4
- model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +41 -14
- model_compression_toolkit/gptq/keras/graph_info.py +4 -0
- model_compression_toolkit/gptq/keras/quantization_facade.py +26 -19
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
- model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +11 -11
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +1 -3
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/gptq/runner.py +3 -2
- model_compression_toolkit/{exporter/model_exporter/tflite → legacy}/__init__.py +1 -1
- model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +8 -9
- model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +8 -9
- model_compression_toolkit/ptq/__init__.py +3 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +10 -11
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -7
- model_compression_toolkit/qat/__init__.py +4 -0
- model_compression_toolkit/qat/common/__init__.py +1 -2
- model_compression_toolkit/qat/common/qat_config.py +5 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +33 -27
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +31 -4
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +12 -10
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +8 -8
- model_compression_toolkit/qat/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +3 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +6 -4
- model_compression_toolkit/quantizers_infrastructure/__init__.py +2 -2
- model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +5 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +147 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +5 -5
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +1 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +3 -5
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
- model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
- model_compression_toolkit/target_platform_capabilities/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/operators.py +1 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
- model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +11 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py +1 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/layer_filter_params.py +32 -34
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +2 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +3 -24
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/target_platform_capabilities.py +3 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v2/tp_model.py +7 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v3/tp_model.py +7 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v3_lut/tp_model.py +7 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v4/tp_model.py +7 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v4_lut/tp_model.py +7 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/v5/tp_model.py +7 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +1 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +2 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +2 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +2 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +0 -73
- {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/WHEEL +0 -0
- {mct_nightly-1.8.0.22042023.post414.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
model_compression_toolkit/exporter/model_exporter/{tflite → keras}/fakely_quant_tflite_exporter.py
RENAMED
|
@@ -20,7 +20,7 @@ import keras.models
|
|
|
20
20
|
import tensorflow as tf
|
|
21
21
|
|
|
22
22
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
|
|
23
|
-
from model_compression_toolkit.
|
|
23
|
+
from model_compression_toolkit.logger import Logger
|
|
24
24
|
from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
|
|
25
25
|
|
|
26
26
|
|
|
@@ -23,7 +23,7 @@ from keras.layers import Dense, Conv2D, Reshape
|
|
|
23
23
|
from keras.models import clone_model
|
|
24
24
|
|
|
25
25
|
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
26
|
-
from model_compression_toolkit.
|
|
26
|
+
from model_compression_toolkit.logger import Logger
|
|
27
27
|
from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
|
|
28
28
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
|
|
29
29
|
constants as keras_inferable_constants
|
|
@@ -12,53 +12,91 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
from enum import Enum
|
|
16
15
|
from typing import Callable, Dict
|
|
17
16
|
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
17
|
+
from model_compression_toolkit.constants import FOUND_TF
|
|
18
|
+
from model_compression_toolkit.exporter.model_exporter.keras.export_serialization_format import \
|
|
19
|
+
KerasExportSerializationFormat
|
|
20
|
+
from model_compression_toolkit.logger import Logger
|
|
21
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.quantization_format import \
|
|
23
|
+
QuantizationFormat
|
|
25
24
|
|
|
26
25
|
if FOUND_TF:
|
|
27
26
|
import keras
|
|
28
27
|
from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
|
|
29
|
-
from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import
|
|
28
|
+
from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import \
|
|
29
|
+
FakelyQuantKerasExporter
|
|
30
|
+
from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_tflite_exporter import \
|
|
31
|
+
FakelyQuantTFLiteExporter
|
|
32
|
+
from model_compression_toolkit.exporter.model_exporter.keras.int8_tflite_exporter import INT8TFLiteExporter
|
|
33
|
+
|
|
34
|
+
supported_serialization_quantization_export_dict = {
|
|
35
|
+
KerasExportSerializationFormat.KERAS_H5: [QuantizationFormat.FAKELY_QUANT],
|
|
36
|
+
KerasExportSerializationFormat.TFLITE: [QuantizationFormat.FAKELY_QUANT, QuantizationFormat.INT8]
|
|
37
|
+
}
|
|
30
38
|
|
|
31
39
|
def keras_export_model(model: keras.models.Model,
|
|
32
40
|
save_model_path: str,
|
|
41
|
+
target_platform_capabilities: TargetPlatformCapabilities,
|
|
33
42
|
is_layer_exportable_fn: Callable = is_keras_layer_exportable,
|
|
34
|
-
|
|
43
|
+
serialization_format: KerasExportSerializationFormat =
|
|
44
|
+
KerasExportSerializationFormat.KERAS_H5) -> \
|
|
45
|
+
Dict[str, type]:
|
|
35
46
|
"""
|
|
36
|
-
Export a Keras quantized model to h5 model.
|
|
47
|
+
Export a Keras quantized model to a h5 or tflite model.
|
|
37
48
|
The model will be saved to the path in save_model_path.
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
49
|
+
keras_export_model supports the combination of QuantizationFormat.FAKELY_QUANT (where weights
|
|
50
|
+
and activations are float fakely-quantized values) and KerasExportSerializationFormat.KERAS_H5 (where the model
|
|
51
|
+
will be saved to h5 model) or the combination of KerasExportSerializationFormat.TFLITE (where the model will be
|
|
52
|
+
saved to tflite model) with QuantizationFormat.FAKELY_QUANT or QuantizationFormat.INT8 (where weights and
|
|
53
|
+
activations are represented using 8bits integers).
|
|
41
54
|
|
|
42
55
|
Args:
|
|
43
56
|
model: Model to export.
|
|
44
|
-
is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
|
|
45
|
-
mode: Mode to export the model according to.
|
|
46
57
|
save_model_path: Path to save the model.
|
|
58
|
+
target_platform_capabilities: TargetPlatformCapabilities object that describes the desired inference
|
|
59
|
+
target platform (includes quantization format).
|
|
60
|
+
is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
|
|
61
|
+
serialization_format: Format to export the model according to (by default
|
|
62
|
+
KerasExportSerializationFormat.KERAS_H5).
|
|
47
63
|
|
|
48
64
|
Returns:
|
|
49
65
|
Custom objects dictionary needed to load the model.
|
|
50
66
|
|
|
51
67
|
"""
|
|
52
68
|
|
|
53
|
-
if
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
69
|
+
if serialization_format == KerasExportSerializationFormat.KERAS_H5:
|
|
70
|
+
if target_platform_capabilities.tp_model.quantization_format == QuantizationFormat.FAKELY_QUANT:
|
|
71
|
+
exporter = FakelyQuantKerasExporter(model,
|
|
72
|
+
is_layer_exportable_fn,
|
|
73
|
+
save_model_path)
|
|
74
|
+
else:
|
|
75
|
+
Logger.critical(
|
|
76
|
+
f'Unsupported quantization {target_platform_capabilities.tp_model.quantization_format} for '
|
|
77
|
+
f'serialization {serialization_format} was used to export Keras model. Please see API for '
|
|
78
|
+
f'supported formats.') # pragma: no cover
|
|
79
|
+
|
|
80
|
+
elif serialization_format == KerasExportSerializationFormat.TFLITE:
|
|
81
|
+
if target_platform_capabilities.tp_model.quantization_format == QuantizationFormat.FAKELY_QUANT:
|
|
82
|
+
exporter = FakelyQuantTFLiteExporter(model,
|
|
83
|
+
is_layer_exportable_fn,
|
|
84
|
+
save_model_path)
|
|
85
|
+
|
|
86
|
+
elif target_platform_capabilities.tp_model.quantization_format == QuantizationFormat.INT8:
|
|
87
|
+
exporter = INT8TFLiteExporter(model,
|
|
88
|
+
is_layer_exportable_fn,
|
|
89
|
+
save_model_path)
|
|
90
|
+
else:
|
|
91
|
+
Logger.critical(
|
|
92
|
+
f'Unsupported quantization {target_platform_capabilities.tp_model.quantization_format} for '
|
|
93
|
+
f'serialization {serialization_format} was used to export Keras model. Please see API for '
|
|
94
|
+
f'supported formats.') # pragma: no cover
|
|
57
95
|
|
|
58
96
|
else:
|
|
59
97
|
Logger.critical(
|
|
60
|
-
f'Unsupported
|
|
61
|
-
f'
|
|
98
|
+
f'Unsupported serialization {serialization_format} was used to export Keras model. Please see API '
|
|
99
|
+
f'for supported formats.') # pragma: no cover
|
|
62
100
|
|
|
63
101
|
exporter.export()
|
|
64
102
|
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from enum import Enum
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PytorchExportSerializationFormat(Enum):
|
|
19
|
+
TORCHSCRIPT = 0
|
|
20
|
+
ONNX = 1
|
model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py
CHANGED
|
@@ -16,17 +16,21 @@ from typing import Callable
|
|
|
16
16
|
|
|
17
17
|
import torch.nn
|
|
18
18
|
|
|
19
|
-
from model_compression_toolkit.
|
|
19
|
+
from model_compression_toolkit.logger import Logger
|
|
20
20
|
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
|
|
21
21
|
from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
|
|
22
22
|
from packaging import version
|
|
23
23
|
|
|
24
|
+
from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
|
|
25
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import LAYER
|
|
26
|
+
|
|
24
27
|
# ONNX opset version 16 is only supported from PyTorch 1.12 onwards, so older
# PyTorch versions fall back to opset 15.
OPSET_VERSION = 15 if version.parse(torch.__version__) < version.parse("1.12") else 16
|
|
29
32
|
|
|
33
|
+
|
|
30
34
|
class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
|
|
31
35
|
"""
|
|
32
36
|
Exporter for fakely-quant PyTorch models.
|
|
@@ -70,6 +74,16 @@ class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
|
|
|
70
74
|
|
|
71
75
|
Logger.info(f"Exporting PyTorch fake quant onnx model: {self.save_model_path}")
|
|
72
76
|
|
|
77
|
+
# Replace float weight with wrapped quantized weights
|
|
78
|
+
for layer in self.model.modules():
|
|
79
|
+
if isinstance(layer, PytorchQuantizationWrapper):
|
|
80
|
+
for name in layer.weights_quantizers.keys():
|
|
81
|
+
quantized_weight = torch.nn.Parameter(layer.get_quantized_weights()[name]).detach()
|
|
82
|
+
linear_layer = getattr(layer, LAYER)
|
|
83
|
+
delattr(linear_layer, name)
|
|
84
|
+
setattr(linear_layer, name, torch.nn.Parameter(quantized_weight))
|
|
85
|
+
layer.weights_quantizers = {}
|
|
86
|
+
|
|
73
87
|
torch.onnx.export(self.model,
|
|
74
88
|
model_input,
|
|
75
89
|
self.save_model_path,
|
|
@@ -16,7 +16,7 @@ from typing import Callable
|
|
|
16
16
|
|
|
17
17
|
import torch.nn
|
|
18
18
|
|
|
19
|
-
from model_compression_toolkit.
|
|
19
|
+
from model_compression_toolkit.logger import Logger
|
|
20
20
|
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
|
|
21
21
|
from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
|
|
22
22
|
|
|
@@ -12,63 +12,86 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
from enum import Enum
|
|
16
15
|
from typing import Callable
|
|
17
16
|
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
17
|
+
from model_compression_toolkit.constants import FOUND_TORCH
|
|
18
|
+
from model_compression_toolkit.exporter.model_exporter.pytorch.export_serialization_format import \
|
|
19
|
+
PytorchExportSerializationFormat
|
|
20
|
+
from model_compression_toolkit.logger import Logger
|
|
21
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.quantization_format import \
|
|
23
|
+
QuantizationFormat
|
|
26
24
|
|
|
27
25
|
if FOUND_TORCH:
|
|
28
26
|
import torch.nn
|
|
29
|
-
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import
|
|
30
|
-
|
|
27
|
+
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
|
|
28
|
+
FakelyQuantONNXPyTorchExporter
|
|
29
|
+
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import \
|
|
30
|
+
FakelyQuantTorchScriptPyTorchExporter
|
|
31
31
|
from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
|
|
32
32
|
|
|
33
|
+
# For each serialization format, the quantization formats that can be exported with it.
supported_serialization_quantization_export_dict = {
    PytorchExportSerializationFormat.TORCHSCRIPT: [QuantizationFormat.FAKELY_QUANT],
    PytorchExportSerializationFormat.ONNX: [QuantizationFormat.FAKELY_QUANT],
}
|
|
37
|
+
|
|
33
38
|
def pytorch_export_model(model: torch.nn.Module,
                         save_model_path: str,
                         repr_dataset: Callable,
                         target_platform_capabilities: TargetPlatformCapabilities,
                         is_layer_exportable_fn: Callable = is_pytorch_layer_exportable,
                         serialization_format: PytorchExportSerializationFormat =
                         PytorchExportSerializationFormat.TORCHSCRIPT) -> None:
    """
    Export a quantized PyTorch model to a TorchScript or ONNX model and save it
    to save_model_path.
    The exporter is chosen by serialization_format
    (PytorchExportSerializationFormat.TORCHSCRIPT or PytorchExportSerializationFormat.ONNX).
    The quantization format of the target platform model must be listed for that
    serialization format in supported_serialization_quantization_export_dict
    (currently only QuantizationFormat.FAKELY_QUANT, where weights and activations
    are float fakely-quantized values).

    Args:
        model: Model to export.
        save_model_path: Path to save the model.
        repr_dataset: Representative dataset for tracing the pytorch model (mandatory for exporting it).
        target_platform_capabilities: TargetPlatformCapabilities object that describes the desired inference
            target platform (includes quantization format).
        is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
        serialization_format: Format to export the model according to (by default
            PytorchExportSerializationFormat.TORCHSCRIPT).

    """

    # Resolve the exporter class first; both supported serialization formats
    # share the same validation and construction path below.
    if serialization_format == PytorchExportSerializationFormat.TORCHSCRIPT:
        exporter_cls = FakelyQuantTorchScriptPyTorchExporter
    elif serialization_format == PytorchExportSerializationFormat.ONNX:
        exporter_cls = FakelyQuantONNXPyTorchExporter
    else:
        exporter_cls = None

    if exporter_cls is None:
        Logger.critical(
            f'Unsupported serialization {serialization_format} was used to export Pytorch model. Please see API '
            f'for supported formats.')  # pragma: no cover
    elif target_platform_capabilities.tp_model.quantization_format in \
            supported_serialization_quantization_export_dict[serialization_format]:
        exporter = exporter_cls(model,
                                is_layer_exportable_fn,
                                save_model_path,
                                repr_dataset)
    else:
        Logger.critical(
            f'Unsupported quantization {target_platform_capabilities.tp_model.quantization_format} for '
            f'serialization {serialization_format} was used to export Pytorch model. Please see API for '
            f'supported formats.')  # pragma: no cover

    # NOTE(review): `exporter` is only bound on the supported path; this presumably
    # relies on Logger.critical not returning - confirm against the Logger implementation.
    exporter.export()
|
|
74
97
|
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
CHANGED
|
@@ -17,9 +17,10 @@ from typing import Tuple
|
|
|
17
17
|
|
|
18
18
|
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
19
19
|
from model_compression_toolkit.core import common
|
|
20
|
-
from model_compression_toolkit.core.common import Graph
|
|
21
|
-
from model_compression_toolkit.
|
|
20
|
+
from model_compression_toolkit.core.common import Graph
|
|
21
|
+
from model_compression_toolkit.constants import FOUND_TF
|
|
22
22
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
|
23
|
+
from model_compression_toolkit.logger import Logger
|
|
23
24
|
|
|
24
25
|
if FOUND_TF:
|
|
25
26
|
import tensorflow as tf
|
|
@@ -34,6 +35,7 @@ if FOUND_TF:
|
|
|
34
35
|
Args:
|
|
35
36
|
n: A node of mct graph.
|
|
36
37
|
layer: A keras layer
|
|
38
|
+
include_activation_quantizers: Whether to use the wrapper for the activation quantizer or not
|
|
37
39
|
|
|
38
40
|
Returns: Wrapped layer with weights quantizers and activation quantizers
|
|
39
41
|
|
|
@@ -55,7 +57,7 @@ if FOUND_TF:
|
|
|
55
57
|
Exportable Keras model and user information.
|
|
56
58
|
"""
|
|
57
59
|
exportable_model, user_info = KerasModelBuilder(graph=graph,
|
|
58
|
-
|
|
60
|
+
wrapper=_get_wrapper).build_model()
|
|
59
61
|
exportable_model.trainable = False
|
|
60
62
|
return exportable_model, user_info
|
|
61
63
|
else:
|
|
@@ -14,8 +14,10 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Dict, Any
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.core.common import BaseNode
|
|
18
|
-
from model_compression_toolkit.
|
|
17
|
+
from model_compression_toolkit.core.common import BaseNode
|
|
18
|
+
from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.logger import Logger
|
|
19
21
|
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
20
22
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
|
|
21
23
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import get_inferable_quantizer_class
|
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
from typing import Any
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
|
+
from model_compression_toolkit.constants import FOUND_TF
|
|
20
20
|
|
|
21
21
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
|
|
22
22
|
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
CHANGED
|
@@ -16,8 +16,9 @@
|
|
|
16
16
|
|
|
17
17
|
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
18
18
|
from model_compression_toolkit.core import common
|
|
19
|
-
from model_compression_toolkit.core.common import Graph
|
|
20
|
-
from model_compression_toolkit.
|
|
19
|
+
from model_compression_toolkit.core.common import Graph
|
|
20
|
+
from model_compression_toolkit.constants import FOUND_TORCH
|
|
21
|
+
from model_compression_toolkit.logger import Logger
|
|
21
22
|
|
|
22
23
|
if FOUND_TORCH:
|
|
23
24
|
import torch
|
|
@@ -15,9 +15,10 @@
|
|
|
15
15
|
|
|
16
16
|
from typing import Dict, Any
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.core.common import BaseNode
|
|
19
|
-
from model_compression_toolkit.
|
|
18
|
+
from model_compression_toolkit.core.common import BaseNode
|
|
19
|
+
from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
|
|
20
20
|
SCALE_PER_CHANNEL, CLUSTER_CENTERS
|
|
21
|
+
from model_compression_toolkit.logger import Logger
|
|
21
22
|
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
22
23
|
from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
|
|
23
24
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
|
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Any
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.
|
|
18
|
-
from model_compression_toolkit.
|
|
17
|
+
from model_compression_toolkit.logger import Logger
|
|
18
|
+
from model_compression_toolkit.constants import FOUND_TORCH
|
|
19
19
|
|
|
20
20
|
if FOUND_TORCH:
|
|
21
21
|
from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from abc import abstractmethod
|
|
17
|
+
|
|
18
|
+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class GPTQFrameworkImplemantation(FrameworkImplementation):
    """
    Class to implement framework related methods that are used in GPTQ
    """

    @abstractmethod
    def get_gptq_trainer_obj(self):
        """
        Returns: GPTQTrainer object

        Raises:
            NotImplementedError: Always, when this base implementation is invoked;
                framework-specific subclasses have to override this method.
        """
        # NotImplemented is a non-callable constant meant for binary-operator
        # fallbacks; the exception type to raise here is NotImplementedError.
        raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                                  f'framework\'s get_gptq_trainer_obj method.')  # pragma: no cover
|
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Tuple, List
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit import FrameworkInfo
|
|
18
|
-
from model_compression_toolkit.
|
|
17
|
+
from model_compression_toolkit.core import FrameworkInfo
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
19
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
20
20
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
21
21
|
|
|
@@ -17,12 +17,13 @@ from abc import ABC, abstractmethod
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
from typing import Callable, List, Any
|
|
19
19
|
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
|
|
20
|
-
from model_compression_toolkit.core.common import Graph,
|
|
20
|
+
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
21
21
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
22
|
-
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
23
22
|
from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
|
|
23
|
+
from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
|
|
24
24
|
from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
|
|
25
25
|
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
|
|
26
|
+
from model_compression_toolkit.logger import Logger
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
class GPTQTrainer(ABC):
|
|
@@ -34,7 +35,7 @@ class GPTQTrainer(ABC):
|
|
|
34
35
|
graph_float: Graph,
|
|
35
36
|
graph_quant: Graph,
|
|
36
37
|
gptq_config: GradientPTQConfig,
|
|
37
|
-
fw_impl:
|
|
38
|
+
fw_impl: GPTQFrameworkImplemantation,
|
|
38
39
|
fw_info: FrameworkInfo):
|
|
39
40
|
"""
|
|
40
41
|
Build two models from a graph: A teacher network (float model) and a student network (quantized model).
|
|
@@ -259,7 +260,7 @@ def gptq_training(graph_float: Graph,
|
|
|
259
260
|
graph_quant: Graph,
|
|
260
261
|
gptq_config: GradientPTQConfig,
|
|
261
262
|
representative_data_gen: Callable,
|
|
262
|
-
fw_impl:
|
|
263
|
+
fw_impl: GPTQFrameworkImplemantation,
|
|
263
264
|
fw_info: FrameworkInfo) -> Graph:
|
|
264
265
|
"""
|
|
265
266
|
GPTQ training process using knowledge distillation with a teacher network (float model) and a student network (quantized model).
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Type
|
|
17
|
+
|
|
18
|
+
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
|
|
19
|
+
from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
|
|
20
|
+
from model_compression_toolkit.gptq.keras.gptq_training import KerasGPTQTrainer
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class GPTQKerasImplemantation(GPTQFrameworkImplemantation, KerasImplementation):
    """
    GPTQ framework implementation for Keras: combines the GPTQ-specific interface
    with the Keras framework implementation.
    """

    def get_gptq_trainer_obj(self) -> Type[KerasGPTQTrainer]:
        """
        Returns: KerasGPTQTrainer, the Keras object of GPTQTrainer (returned as-is, not instantiated here)
        """
        trainer_cls = KerasGPTQTrainer
        return trainer_cls
|