mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +12 -41
- model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
- model_compression_toolkit/core/__init__.py +14 -0
- model_compression_toolkit/core/analyzer.py +3 -2
- model_compression_toolkit/core/common/__init__.py +0 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +1 -8
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
- model_compression_toolkit/core/common/graph/base_graph.py +2 -2
- model_compression_toolkit/core/common/graph/base_node.py +57 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/memory_computation.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/model_validation.py +1 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -1
- model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
- model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
- model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/keras/constants.py +0 -7
- model_compression_toolkit/core/keras/default_framework_info.py +3 -3
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -10
- model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
- model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
- model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +1 -1
- model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +4 -6
- model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
- model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
- model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/runner.py +7 -7
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -104
- model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
- model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +30 -39
- model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
- model_compression_toolkit/gptq/keras/graph_info.py +8 -33
- model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
- model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
- model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
- model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
- model_compression_toolkit/gptq/runner.py +3 -2
- model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
- model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
- model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
- model_compression_toolkit/ptq/__init__.py +3 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/qat/__init__.py +4 -0
- model_compression_toolkit/qat/common/__init__.py +1 -2
- model_compression_toolkit/qat/common/qat_config.py +3 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
- model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
- model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
- model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
- model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
- /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py CHANGED
@@ -14,61 +14,69 @@
 # ==============================================================================
 from typing import Any
 
-from keras.engine.input_layer import InputLayer
-
-from model_compression_toolkit.core.common import Logger
-from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
-from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
 
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import FOUND_TF
 
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
 
-def is_keras_layer_exportable(layer: Any) -> bool:
-    """
-    Check whether a Keras layer is a valid exportable layer or not.
 
-
-
+if FOUND_TF:
+    from keras.engine.input_layer import InputLayer
+    from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
 
-
+    def is_keras_layer_exportable(layer: Any) -> bool:
+        """
         Check whether a Keras layer is a valid exportable layer or not.
-    """
-    # Keras Input layers are not wrapped
-    if isinstance(layer, InputLayer):
-        return True
 
-
-
-    Logger.error(
-        f'Exportable layer must be wrapped using KerasQuantizationWrapper, but layer {layer.name} is of type '
-        f'{type(layer)}')  # pragma: no cover
+        Args:
+            layer: Keras layer to check if considered to be valid for exporting.
 
-
-
-
-
-
+        Returns:
+            Check whether a Keras layer is a valid exportable layer or not.
+        """
+        # Keras Input layers are not wrapped
+        if isinstance(layer, InputLayer):
+            return True
 
-
-    if not …
+        valid_layer = isinstance(layer, KerasQuantizationWrapper)
+        if not valid_layer:
             Logger.error(
-        f'…
-        f'{type(…
+                f'Exportable layer must be wrapped using KerasQuantizationWrapper, but layer {layer.name} is of type '
+                f'{type(layer)}')  # pragma: no cover
 
-
-
-
-
-
+        valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
+        if not valid_weights_quantizers:
+            Logger.error(
+                f'KerasQuantizationWrapper must have a weights_quantizers but has a '
+                f'{type(layer.weights_quantizers)} object')  # pragma: no cover
+
+        for _, weights_quantizer in layer.weights_quantizers.items():
+            if not isinstance(weights_quantizer, BaseInferableQuantizer):
+                Logger.error(
+                    f'weights_quantizer must be a BaseInferableQuantizer object but has a '
+                    f'{type(weights_quantizer)} object')  # pragma: no cover
 
-
-    if not …
+        valid_activation_quantizers = isinstance(layer.activation_quantizers, list)
+        if not valid_activation_quantizers:
             Logger.error(
-        f'…
-        f'{type(activation_quantizers)} object')
+                f'KerasQuantizationWrapper must have a activation_quantizers list but has a '
+                f'{type(layer.activation_quantizers)} object')  # pragma: no cover
 
-
-
-
-
+        for activation_quantizers in layer.activation_quantizers:
+            if not isinstance(activation_quantizers, BaseInferableQuantizer):
+                Logger.error(
+                    f'activation_quantizers must be a BaseInferableQuantizer object but has a '
+                    f'{type(activation_quantizers)} object')  # pragma: no cover
 
-
+        quantizers = layer.activation_quantizers + list(layer.weights_quantizers.values())
+        is_valid_quantizers = all([isinstance(x, BaseInferableQuantizer) for x in quantizers])
+        if not is_valid_quantizers:
+            Logger.error(f'Found a quantizer that is not of type BaseInferableQuantizer')  # pragma: no cover
+
+        return True
+else:
+    def is_keras_layer_exportable(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                     'when using is_keras_layer_exportable. '
+                     'Could not find Tensorflow package.')
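The notable change above is that all TensorFlow-dependent code (including the keras.engine import) now lives under an `if FOUND_TF:` guard, so importing the package no longer fails when TensorFlow is absent. A minimal sketch of the same pattern, assuming FOUND_TF is a simple availability probe (the probing logic below is our illustration, not MCT's exact code):

    import importlib.util

    # Assumed probe: True when TensorFlow is importable in the environment.
    FOUND_TF = importlib.util.find_spec('tensorflow') is not None

    if FOUND_TF:
        def is_keras_layer_exportable(layer):
            ...  # full validation, as in the diff above
    else:
        def is_keras_layer_exportable(*args, **kwargs):
            # Fail only when the TF-dependent API is actually called,
            # not at import time.
            raise RuntimeError('Installing tensorflow is mandatory '
                               'when using is_keras_layer_exportable.')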
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py CHANGED
@@ -13,42 +13,50 @@
 # limitations under the License.
 # ==============================================================================
 
-import torch
 
 from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.…
-from model_compression_toolkit.…
-… (32 further removed lines; their content is truncated in the source diff)
+from model_compression_toolkit.constants import FOUND_TORCH
+from model_compression_toolkit.logger import Logger
+
+if FOUND_TORCH:
+    import torch
+    from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
+    from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizers import \
+        get_quantization_quantizers
+
+    def fully_quantized_wrapper(node: common.BaseNode, module: torch.nn.Module) -> qi.PytorchQuantizationWrapper:
+        """
+        A function which takes a computational graph node and a pytorch module and
+        perform the quantization wrapping
+
+        Args:
+            node: A node of mct graph.
+            module: A Pytorch module
+
+        Returns: Wrapped layer
+
+        """
+        weight_quantizers, activation_quantizers = get_quantization_quantizers(node)
+        wrapped_layer = qi.PytorchQuantizationWrapper(module, weight_quantizers, activation_quantizers)
+        return wrapped_layer
+
+
+    def get_exportable_pytorch_model(graph: Graph):
+        """
+        Convert graph to fully quantized PyTorch model.
+
+        Args:
+            graph: Graph to convert to a PyTorch model.
+
+        Returns:
+            Fully quantized PyTorch model.
+        """
+        return PyTorchModelBuilder(graph=graph,
+                                   wrapper=fully_quantized_wrapper).build_model()
+else:
+    def get_exportable_pytorch_model(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing torch is mandatory '
+                     'when using get_exportable_pytorch_model. '
+                     'Could not find PyTorch package.')
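The PyTorch builder gets the same lazy-dependency treatment via FOUND_TORCH, and wrapping plus export-model construction now live in one guarded block. A hedged usage sketch (the graph object is assumed to come out of MCT's internal quantization pipeline; the return value is whatever PyTorchModelBuilder.build_model() produces):

    from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import \
        get_exportable_pytorch_model

    # graph: a model_compression_toolkit.core.common.Graph produced by the core runner
    exportable = get_exportable_pytorch_model(graph)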
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py CHANGED
@@ -15,9 +15,11 @@
 
 from typing import Dict, Any
 
-from model_compression_toolkit.core.common import BaseNode
-from model_compression_toolkit.…
-
+from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
+    SCALE_PER_CHANNEL, CLUSTER_CENTERS
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
 from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
     get_inferable_quantizer_class
@@ -45,6 +47,15 @@ def get_weights_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
                 qi_inferable_quantizers_constants.MIN_RANGE: node_w_qc.weights_quantization_params[RANGE_MIN].flatten(),
                 qi_inferable_quantizers_constants.MAX_RANGE: node_w_qc.weights_quantization_params[RANGE_MAX].flatten(),
                 qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
+
+    elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER, QuantizationMethod.LUT_SYM_QUANTIZER]:
+        return {qi_inferable_quantizers_constants.NUM_BITS: node_w_qc.weights_n_bits,
+                qi_inferable_quantizers_constants.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS].flatten(),
+                qi_inferable_quantizers_constants.THRESHOLD: node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten(),
+                qi_inferable_quantizers_constants.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
+                qi_inferable_quantizers_constants.CHANNEL_AXIS: node_w_qc.weights_channels_axis}
+    # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
+
     else:
         Logger.critical(f'Not supported quantization method for weights inferable quantizers.')  # pragma: no cover
 
@@ -65,6 +76,15 @@ def get_activation_inferable_quantizer_kwargs(node: BaseNode) -> Dict[str, Any]:
         return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.activation_n_bits,
                 qi_inferable_quantizers_constants.MIN_RANGE: np.asarray([node_qc.activation_quantization_params[RANGE_MIN]]),
                 qi_inferable_quantizers_constants.MAX_RANGE: np.asarray([node_qc.activation_quantization_params[RANGE_MAX]])}
+
+    elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER]:
+        return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.activation_n_bits,
+                qi_inferable_quantizers_constants.CLUSTER_CENTERS: np.asarray(
+                    [node_qc.activation_quantization_params[CLUSTER_CENTERS]]),
+                qi_inferable_quantizers_constants.THRESHOLD: np.asarray(
+                    [node_qc.activation_quantization_params[THRESHOLD]]),
+                qi_inferable_quantizers_constants.SIGNED: node_qc.activation_quantization_params.get(SIGNED)}
+    # TODO: Add MULTIPLIER_N_BITS & EPS to node quantization config
     else:
         Logger.critical(f'Not supported quantization method for inferable quantizers.')  # pragma: no cover
 
@@ -111,10 +131,10 @@ def get_activations_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQu…
     node_act_qc = node.final_activation_quantization_cfg
     activation_quantization_method = node_act_qc.activation_quantization_method
 
-…
-…
-…
+    quantizer_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
+                                                       activation_quantization_method,
+                                                       BasePyTorchInferableQuantizer)
     kwargs = get_activation_inferable_quantizer_kwargs(node)
 
-    return …
+    return quantizer_for_node(**kwargs)
 
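The new elif branches extend kwargs construction to LUT (lookup-table) quantizers. Illustrative only: the rough shape of the dict the weights branch assembles. The keys below are named after the constants used in the diff; the actual key strings live in the qi inferable-quantizers constants module, and the numeric values here are made up:

    import numpy as np

    # Hypothetical kwargs for a 4-bit per-channel weights LUT quantizer.
    lut_weights_kwargs = {
        'NUM_BITS': 4,
        'CLUSTER_CENTERS': np.array([-2., 0., 1., 3.]),  # flattened LUT centers
        'THRESHOLD': np.array([0.5, 0.25, 1.0]),         # from the SCALE_PER_CHANNEL params
        'PER_CHANNEL': True,
        'CHANNEL_AXIS': 0,
    }
    # The dict is then splatted into the class resolved by
    # get_inferable_quantizer_class(QuantizationTarget.Weights, method, BasePyTorchInferableQuantizer).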
model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py CHANGED
@@ -14,24 +14,31 @@
 # ==============================================================================
 from typing import Any
 
-from model_compression_toolkit.…
-from model_compression_toolkit.…
-    BasePyTorchInferableQuantizer
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import FOUND_TORCH
 
+if FOUND_TORCH:
+    from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+    from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.pytorch.quantizers import \
+        BasePyTorchInferableQuantizer
+
+    def is_pytorch_layer_exportable(layer: Any) -> bool:
+        """
+        Check whether a torch Module is a valid exportable module or not.
 
-
-
-        Check whether a torch Module is a valid exportable module or not.
+        Args:
+            layer: PyTorch module to check if considered to be valid for exporting.
 
-
-
-
-
-
-
-
-
-
-
-
-
+        Returns:
+            Check whether a PyTorch layer is a valid exportable layer or not.
+        """
+        if isinstance(layer, PytorchQuantizationWrapper):
+            quantizers = list(layer.weights_quantizers.values())
+            quantizers.extend(layer.activation_quantizers)
+            if all([isinstance(q, BasePyTorchInferableQuantizer) for q in quantizers]):
+                return True
+        return False
+else:
+    def is_pytorch_layer_exportable(*args, **kwargs):  # pragma: no cover
+        Logger.error('Installing torch is mandatory '
+                     'when using is_pytorch_layer_exportable. '
+                     'Could not find PyTorch package.')
model_compression_toolkit/gptq/__init__.py CHANGED
@@ -12,3 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GradientPTQConfigV2
+from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization_experimental
+from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
+from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization_experimental
+from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
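These re-exports make the GPTQ entry points importable directly from the gptq subpackage. All names come straight from the diff above:

    from model_compression_toolkit.gptq import (
        GradientPTQConfigV2,
        get_keras_gptq_config,
        keras_gradient_post_training_quantization_experimental,
    )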
model_compression_toolkit/gptq/common/gptq_config.py CHANGED
@@ -16,9 +16,7 @@ from enum import Enum
 from typing import Callable, Any, Dict
 from model_compression_toolkit.core.common.defaultdict import DefaultDict
 from model_compression_toolkit.core import common
-from model_compression_toolkit.gptq.common.gptq_constants import …
-    MAX_LSB_STR
-from model_compression_toolkit.gptq.common.gptq_quantizer_config import GPTQQuantizerConfig, SoftQuantizerConfig
+from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR, MAX_LSB_STR, REG_DEFAULT
 
 
 class RoundingType(Enum):
@@ -31,30 +29,53 @@ class RoundingType(Enum):
     SoftQuantizer = 1
 
 
+class GPTQHessianWeightsConfig:
+    """
+    Configuration to use for computing the Hessian-based weights for GPTQ loss metric.
+    """
+
+    def __init__(self,
+                 hessians_num_samples: int = 16,
+                 norm_weights: bool = True,
+                 log_norm: bool = True,
+                 scale_log_norm: bool = False,
+                 hessians_n_iter: int = 50):
+
+        """
+        Initialize a GPTQHessianWeightsConfig.
+        Args:
+            hessians_num_samples (int): Number of samples to use for computing the Hessian-based weights.
+            norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
+            log_norm (bool): Whether to use log normalization to the GPTQ Hessian-based weights.
+            scale_log_norm (bool): Whether to scale the final vector of the Hessian weights.
+            hessians_n_iter (int): Number of random iterations to run Hessian approximation for GPTQ weights.
+        """
+
+        self.hessians_num_samples = hessians_num_samples
+        self.norm_weights = norm_weights
+        self.log_norm = log_norm
+        self.scale_log_norm = scale_log_norm
+        self.hessians_n_iter = hessians_n_iter
+
+
 class GradientPTQConfig:
     """
     Configuration to use for quantization with GradientPTQ (experimental).
     """
 
-    def __init__(self,
-                 n_iter: int,
+    def __init__(self, n_iter: int,
                  optimizer: Any,
                  optimizer_rest: Any = None,
                  loss: Callable = None,
                  log_function: Callable = None,
                  train_bias: bool = True,
-                 quantization_parameters_learning: bool = False,
                  rounding_type: RoundingType = RoundingType.SoftQuantizer,
-                 …
-                 eps: float = 1e-6,
-                 use_jac_based_weights: bool = True,
-                 num_samples_for_loss: int = 16,
-                 norm_weights: bool = False,
+                 use_hessian_based_weights: bool = True,
                  optimizer_quantization_parameter: Any = None,
                  optimizer_bias: Any = None,
-                 …
-                 …
-                 …
+                 regularization_factor: float = REG_DEFAULT,
+                 hessian_weights_config: GPTQHessianWeightsConfig = GPTQHessianWeightsConfig(),
+                 gptq_quantizer_params_override: Dict[str, Any] = None):
         """
         Initialize a GradientPTQConfig.
 
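The Jacobian-weighting knobs that used to sit directly on GradientPTQConfig (use_jac_based_weights, num_samples_for_loss, norm_weights) appear to be consolidated, under Hessian terminology, into the new GPTQHessianWeightsConfig. A sketch of the new construction (the optimizer is left as a None placeholder; a real tf/torch optimizer would be passed in practice):

    from model_compression_toolkit.gptq.common.gptq_config import (
        GradientPTQConfig, GPTQHessianWeightsConfig, RoundingType)

    hessian_cfg = GPTQHessianWeightsConfig(hessians_num_samples=16,
                                           norm_weights=True,
                                           log_norm=True)
    gptq_cfg = GradientPTQConfig(n_iter=5000,
                                 optimizer=None,  # placeholder; a framework optimizer in practice
                                 rounding_type=RoundingType.SoftQuantizer,
                                 use_hessian_based_weights=True,
                                 hessian_weights_config=hessian_cfg)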
@@ -67,18 +88,13 @@ class GradientPTQConfig:
             accordingly. see example in multiple_tensors_mse_loss
             log_function (Callable): Function to log information about the GPTQ process.
             train_bias (bool): Whether to update the bias during the training or not.
-            quantization_parameters_learning (bool): Whether to update the quantization param during the training or not.
             rounding_type (RoundingType): An enum that defines the rounding type.
-            …
-            eps (float): A floating point value for numeric stability.
-            use_jac_based_weights (bool): Whether to use jacobian-based weights for weighted average loss.
-            num_samples_for_loss (int): Number of samples to use for computing the jacobian-based weights.
-            norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
+            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
             optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
             optimizer_bias (Any): Optimizer to override the rest optimizer for bias.
-            …
-            …
-            …
+            regularization_factor (float): A floating point number that defines the regularization factor.
+            hessian_weights_config (GPTQHessianWeightsConfig): A configuration that include all necessary arguments to run a computation of Hessian weights for the GPTQ loss.
+            gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
 
         """
         self.n_iter = n_iter
@@ -88,68 +104,34 @@ class GradientPTQConfig:
         self.log_function = log_function
         self.train_bias = train_bias
 
-        if quantization_parameters_learning and rounding_type == RoundingType.STE:
-            common.Logger.error("Quantization parameters learning is not supported with STE rounding.")
-
-        self.quantization_parameters_learning = quantization_parameters_learning
         self.rounding_type = rounding_type
-        self.…
-        self.eps = eps
-        self.use_jac_based_weights = use_jac_based_weights
-        self.num_samples_for_loss = num_samples_for_loss
-        self.norm_weights = norm_weights
+        self.use_hessian_based_weights = use_hessian_based_weights
         self.optimizer_quantization_parameter = optimizer_quantization_parameter
         self.optimizer_bias = optimizer_bias
-        self.…
-        self.…
+        self.regularization_factor = regularization_factor
+        self.hessian_weights_config = hessian_weights_config
 
-
-
-        else:
-            common.Logger.error(f"Quantizer config of type {type(quantizer_config)} "
-                                f"is not suitable for rounding type {rounding_type}")
-
-    def _verify_quantizer_config(self, quantizer_config, rounding_type) -> bool:
-        """
-        Verifies that the given quantizer config matches the given rounding type.
-
-        Args:
-            quantizer_config: A quantizer config.
-            rounding_type: A RoundingType.
-
-        Returns: True if the quantizer config matches the rounding type, False otherwise.
-
-        """
-        if rounding_type == RoundingType.SoftQuantizer:
-            return type(quantizer_config) == SoftQuantizerConfig
-
-        # Here, we compare type() and not isinstance to exclude instance equality because of inheritance
-        return type(quantizer_config) == GPTQQuantizerConfig
+        self.gptq_quantizer_params_override = {} if gptq_quantizer_params_override is None \
+            else gptq_quantizer_params_override
 
 
 class GradientPTQConfigV2(GradientPTQConfig):
     """
     Configuration to use for quantization with GradientPTQV2 (experimental).
     """
-    def __init__(self,
-                 n_epochs: int,
+    def __init__(self, n_epochs: int,
                  optimizer: Any,
                  optimizer_rest: Any = None,
                  loss: Callable = None,
                  log_function: Callable = None,
                  train_bias: bool = True,
-                 quantization_parameters_learning: bool = False,
                  rounding_type: RoundingType = RoundingType.SoftQuantizer,
-                 …
-                 eps: float = 1e-6,
-                 use_jac_based_weights: bool = True,
-                 num_samples_for_loss: int = 16,
-                 norm_weights: bool = False,
+                 use_hessian_based_weights: bool = True,
                  optimizer_quantization_parameter: Any = None,
                  optimizer_bias: Any = None,
-                 …
-                 …
-                 …
+                 regularization_factor: float = REG_DEFAULT,
+                 hessian_weights_config: GPTQHessianWeightsConfig = GPTQHessianWeightsConfig(),
+                 gptq_quantizer_params_override: Dict[str, Any] = None):
         """
         Initialize a GradientPTQConfigV2.
 
@@ -162,18 +144,13 @@ class GradientPTQConfigV2(GradientPTQConfig):
             accordingly. see example in multiple_tensors_mse_loss
             log_function (Callable): Function to log information about the GPTQ process.
             train_bias (bool): Whether to update the bias during the training or not.
-            quantization_parameters_learning (bool): Whether to update the quantization param during the training or not.
             rounding_type (RoundingType): An enum that defines the rounding type.
-            …
-            eps (float): A floating point value for numeric stability.
-            use_jac_based_weights (bool): Whether to use jacobian-based weights for weighted average loss.
-            num_samples_for_loss (int): Number of samples to use for computing the jacobian-based weights.
-            norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
+            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
             optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
             optimizer_bias (Any): Optimizer to override the rest optimizerfor bias.
-            …
-            …
-            …
+            regularization_factor (float): A floating point number that defines the regularization factor.
+            hessian_weights_config (GPTQHessianWeightsConfig): A configuration that include all necessary arguments to run a computation of Hessian weights for the GPTQ loss.
+            gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
 
         """
 
@@ -183,18 +160,13 @@ class GradientPTQConfigV2(GradientPTQConfig):
                          loss=loss,
                          log_function=log_function,
                          train_bias=train_bias,
-                         quantization_parameters_learning=quantization_parameters_learning,
                          rounding_type=rounding_type,
-                         …
-                         eps=eps,
-                         use_jac_based_weights=use_jac_based_weights,
-                         num_samples_for_loss=num_samples_for_loss,
-                         norm_weights=norm_weights,
+                         use_hessian_based_weights=use_hessian_based_weights,
                          optimizer_quantization_parameter=optimizer_quantization_parameter,
                          optimizer_bias=optimizer_bias,
-                         …
-                         …
-                         …
+                         regularization_factor=regularization_factor,
+                         hessian_weights_config=hessian_weights_config,
+                         gptq_quantizer_params_override=gptq_quantizer_params_override)
         self.n_epochs = n_epochs
 
     @classmethod
@@ -211,22 +183,3 @@ class GradientPTQConfigV2(GradientPTQConfig):
         v1_params = config_v1.__dict__
         v1_params = {k: v for k, v in v1_params.items() if k != 'n_iter'}
         return cls(n_epochs, **v1_params)
-
-    def get_extended_quantizer_parametes(self) -> Dict[str, Any]:
-        """
-        Return a dictionary with a mapping to necessary additional parameters for initializing the GPTQ quantizer.
-
-        Returns: A dictionary with parameters for initializing a quantizer.
-
-        """
-
-        if self.rounding_type == RoundingType.SoftQuantizer:
-            return {N_BATCHES_STR: self.quantizer_config.n_batches,
-                    QUANT_PARAM_LEARNING_STR: self.quantization_parameters_learning,
-                    N_EPOCHS_STR: self.n_epochs}
-        elif self.rounding_type == RoundingType.STE:
-            return {MAX_LSB_STR: self.lsb_change_per_bit_width}
-
-        return {}
-
-
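The removed get_extended_quantizer_parametes helper is superseded by the gptq_quantizer_params_override dictionary, while the retained classmethod (its name is cut off above; from_v1 in the repository is an assumption) converts a V1 config into a V2 by swapping n_iter for n_epochs. The same conversion, restated by hand from the visible body:

    # config_v1: an existing GradientPTQConfig; n_epochs: the desired epoch count.
    v1_params = {k: v for k, v in config_v1.__dict__.items() if k != 'n_iter'}
    config_v2 = GradientPTQConfigV2(n_epochs, **v1_params)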
model_compression_toolkit/gptq/common/gptq_constants.py CHANGED
@@ -2,7 +2,6 @@
 AUXVAR = 'auxvar_tensor'
 ITERVAR = 'iteration_variable'
 SCALE_TENSOR = "scale_ptq_tensor"
-GPTQ_ITER = "gptq_iter"
 AUXSHIFT = 'shift'
 WEIGHTS_QUANTIZATION_PARAMS = 'weights_quantization_params'
 PTQ_MIN_RANGE = "min_range"
@@ -11,22 +10,16 @@ PTQ_THRESHOLD = "ptq_threshold"
 SCALE_PTQ = "scale"
 
 # Default quantizer values
-N_EPOCHS = 10000
 N_CYCLES = 4
 MIM_TEMP = 0.5
 MAX_TEMP = 1.0
 REG_DEFAULT = 0.01
-MAX_ITERATIONS_DEFAULT = 10000
 MAX_LSB_CHANGE = 1
 
 # Soft rounding arguments values
 SOFT_ROUNDING_GAMMA = -0.1
 SOFT_ROUNDING_ZETA = 1.1
-SOFT_ROUNDING_BETA = 2 / 3
 
 # GPTQ config constant
-REGULARIZATION_VALUES = "regularization_values"
-N_BATCHES_STR = 'n_batches'
 QUANT_PARAM_LEARNING_STR = 'quantization_parameter_learning'
-N_EPOCHS_STR = 'n_epochs'
 MAX_LSB_STR = 'max_lsbs_change_map'
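SOFT_ROUNDING_GAMMA and SOFT_ROUNDING_ZETA survive the cleanup; they are the stretch bounds of the AdaRound-style rectified sigmoid typically used by soft-rounding quantizers (a sketch under that assumption, not MCT's exact code):

    import numpy as np

    GAMMA, ZETA = -0.1, 1.1  # SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA

    def soft_rounding_offset(aux):
        # Map an unbounded auxiliary variable to a [0, 1] rounding offset:
        # clip(sigmoid(aux) * (zeta - gamma) + gamma, 0, 1).
        s = 1.0 / (1.0 + np.exp(-aux))
        return np.clip(s * (ZETA - GAMMA) + GAMMA, 0.0, 1.0)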
model_compression_toolkit/gptq/common/gptq_framework_implementation.py ADDED
@@ -0,0 +1,32 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from abc import abstractmethod
+
+from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+
+
+class GPTQFrameworkImplemantation(FrameworkImplementation):
+    """
+    Class to implement framework related methods that are used in GPTQ
+    """
+
+    @abstractmethod
+    def get_gptq_trainer_obj(self):
+        """
+        Returns: GPTQTrainer object
+        """
+        raise NotImplemented(f'{self.__class__.__name__} have to implement the '
+                             f'framework\'s get_gptq_trainer method.')  # pragma: no cover