mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +12 -41
- model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
- model_compression_toolkit/core/__init__.py +14 -0
- model_compression_toolkit/core/analyzer.py +3 -2
- model_compression_toolkit/core/common/__init__.py +0 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +1 -8
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
- model_compression_toolkit/core/common/graph/base_graph.py +2 -2
- model_compression_toolkit/core/common/graph/base_node.py +57 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/memory_computation.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/model_validation.py +1 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -1
- model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
- model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
- model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/keras/constants.py +0 -7
- model_compression_toolkit/core/keras/default_framework_info.py +3 -3
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -10
- model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
- model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
- model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +1 -1
- model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +4 -6
- model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
- model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
- model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/runner.py +7 -7
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -104
- model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
- model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +30 -39
- model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
- model_compression_toolkit/gptq/keras/graph_info.py +8 -33
- model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
- model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
- model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
- model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
- model_compression_toolkit/gptq/runner.py +3 -2
- model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
- model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
- model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
- model_compression_toolkit/ptq/__init__.py +3 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/qat/__init__.py +4 -0
- model_compression_toolkit/qat/common/__init__.py +1 -2
- model_compression_toolkit/qat/common/qat_config.py +3 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
- model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
- model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
- model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
- model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
- /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import Callable
|
|
16
|
+
|
|
17
|
+
from model_compression_toolkit.gptq import RoundingType, GradientPTQConfigV2, GradientPTQConfig
|
|
18
|
+
from model_compression_toolkit.gptq.keras.quantizer.soft_rounding.soft_quantizer_reg import \
|
|
19
|
+
SoftQuantizerRegularization
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen: Callable) -> Callable:
    """
    Returns a function that computes the regularization term for GPTQ training based on the given
    rounding type in the GPTQ configuration.

    Args:
        gptq_config: A GPTQ configuration (GradientPTQConfig or GradientPTQConfigV2).
        representative_data_gen: Callable returning an iterable over the dataset used for the GPTQ training.
            For soft rounding it is iterated once here (a dry run) to count the batches per epoch.

    Returns: A function for computing the regularization. If there is no regularization function defined for the given
    rounding type, then it returns a function that just returns 0.

    """
    if gptq_config.rounding_type != RoundingType.SoftQuantizer:
        # Only soft rounding defines a regularization term; other rounding types contribute nothing.
        return lambda m, e_reg: 0

    # Dry run on the representative dataset to count the number of batches per epoch.
    num_batches = sum(1 for _ in representative_data_gen())

    if isinstance(gptq_config, GradientPTQConfigV2):
        n_epochs = gptq_config.n_epochs
    else:
        # A V1 config is converted (using the batch count) to obtain the equivalent number of epochs.
        n_epochs = GradientPTQConfigV2.from_v1(n_ptq_iter=num_batches, config_v1=gptq_config).n_epochs

    # The temperature schedule is sized by the total number of expected gradient steps.
    return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import List
|
|
16
|
+
|
|
17
|
+
import tensorflow as tf
|
|
18
|
+
from keras import Model
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
21
|
+
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
|
|
22
|
+
from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class LinearTempDecay:
    """
    Linear annealing schedule for the soft quantizer regularization temperature term (the exponent b).

    The temperature is held at ``start_b`` for the first ``rel_start_decay * t_max`` steps, then
    decreases linearly to ``end_b`` at step ``t_max`` and stays at ``end_b`` afterwards.
    """

    def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
        """
        Initializes a LinearTempDecay object.

        Args:
            t_max: maximal time step; the step at which the temperature reaches end_b.
            rel_start_decay: Fraction of t_max after which the decay begins (e.g. 0.2 keeps the
                temperature constant for the first 20% of the steps).
            start_b: Starting value of the temperature term.
            end_b: Target (final) value of the temperature term.
        """

        self.t_max = t_max
        self.start_decay = rel_start_decay * t_max
        self.start_b = start_b
        self.end_b = end_b

    def __call__(self, t: int) -> float:
        """
        Linear annealing scheduler for soft quantizer regularization temperature term.

        Args:
            t: The current time step.

        Returns: Scheduled temperature.
        """

        # 1.0 while still in the constant warm-up phase, 0.0 once decay has started.
        is_before_start_decay = tf.cast(t < self.start_decay, tf.float32)

        # Progress through the decay phase: 0 at start_decay, 1 at t_max (may exceed 1 afterwards).
        rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)

        # Linear interpolation from start_b down to end_b; clamped at end_b once rel_t >= 1.
        decayed_b = self.end_b + (self.start_b - self.end_b) * tf.math.maximum(0.0, (1 - rel_t))

        return self.start_b * is_before_start_decay + (1 - is_before_start_decay) * decayed_b
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class SoftQuantizerRegularization:
    """
    Computes the soft quantizer regularization term used during GPTQ training.

    Every call evaluates the term at the current step and then advances an internal step counter,
    which drives a LinearTempDecay schedule for the exponent b.
    """

    def __init__(self, total_gradient_steps: int):
        """
        Builds the temperature schedule and the step counter.

        Args:
            total_gradient_steps: The number of gradient steps during optimization; used as the
                schedule's maximal time step.
        """
        # Temperature decay sized by the number of expected gradient steps.
        self.linear_decay = LinearTempDecay(total_gradient_steps)

        # Step counter; incremented by one at the end of every call.
        self.count_iter = tf.Variable(0.)

    def __call__(self, model: Model, entropy_reg: float):
        """
        Returns the soft quantizer regularization value for SoftRounding.

        Args:
            model: A model to be quantized with SoftRounding.
            entropy_reg: Entropy value to scale the quantizer regularization.

        Returns: Regularization value (entropy_reg times the sum of the per-layer terms).
        """
        temperature = self.linear_decay(self.count_iter.value())

        per_layer_terms: List[tf.Tensor] = []
        for wrapped_layer in model.layers:
            if not isinstance(wrapped_layer, KerasQuantizationWrapper):
                continue

            kernel_attr = get_kernel_attribute_name_for_gptq(layer_type=type(wrapped_layer.layer),
                                                             fw_info=DEFAULT_KERAS_INFO)
            soft_targets = wrapped_layer.weights_quantizers[kernel_attr].get_soft_targets()

            # 1 - |2 * (st - 0.5)| ** b: zero when a soft target sits at 0 or 1, one when it sits at 0.5.
            distance_from_half = tf.math.abs(soft_targets - .5) * 2
            per_layer_terms.append(tf.reduce_sum(1 - tf.pow(distance_from_half, temperature)))

        # sum() starts from the integer 0, so an empty model yields entropy_reg * 0.
        total_reg = sum(per_layer_terms)

        self.count_iter.assign_add(1.0)

        return entropy_reg * total_reg
|
|
@@ -16,22 +16,22 @@
|
|
|
16
16
|
import tensorflow as tf
|
|
17
17
|
import numpy as np
|
|
18
18
|
|
|
19
|
-
from model_compression_toolkit import RoundingType
|
|
19
|
+
from model_compression_toolkit.gptq import RoundingType
|
|
20
20
|
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
21
21
|
from model_compression_toolkit.core.common import max_power_of_two
|
|
22
|
-
from model_compression_toolkit.
|
|
23
|
-
from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ,
|
|
24
|
-
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
23
|
+
from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
|
|
24
|
+
SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
|
|
25
25
|
from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
|
|
26
|
-
from typing import Dict, Any
|
|
27
|
-
from model_compression_toolkit.
|
|
28
|
-
from model_compression_toolkit.core.common.logger import Logger
|
|
26
|
+
from typing import Dict, Any
|
|
27
|
+
from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
|
|
29
28
|
from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
|
|
30
29
|
from model_compression_toolkit.gptq.keras.quantizer.quant_utils import power_of_two_max, clip, calculate_delta
|
|
31
30
|
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
|
|
32
31
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
|
|
33
32
|
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
|
|
34
33
|
get_threshold_reshape_shape
|
|
34
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
def soft_rounding_symmetric_quantizer(input_tensor: tf.Tensor,
|
|
@@ -66,46 +66,6 @@ def soft_rounding_symmetric_quantizer(input_tensor: tf.Tensor,
|
|
|
66
66
|
return delta * clip(tensor_q, max_val=max_int, min_val=min_int)
|
|
67
67
|
|
|
68
68
|
|
|
69
|
-
class LinearTempDecay:
|
|
70
|
-
"""
|
|
71
|
-
Annealing process for the soft quantizer regularization temperature term.
|
|
72
|
-
"""
|
|
73
|
-
|
|
74
|
-
def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
|
|
75
|
-
"""
|
|
76
|
-
Initializes a LinearTempDecay object.
|
|
77
|
-
|
|
78
|
-
Args:
|
|
79
|
-
t_max: maximal time step.
|
|
80
|
-
rel_start_decay: Decay step size at the beginning of the process.
|
|
81
|
-
start_b: Starting value of the regularization term.
|
|
82
|
-
end_b: Target value of the regularization term.
|
|
83
|
-
"""
|
|
84
|
-
|
|
85
|
-
self.t_max = t_max
|
|
86
|
-
self.start_decay = rel_start_decay * t_max
|
|
87
|
-
self.start_b = start_b
|
|
88
|
-
self.end_b = end_b
|
|
89
|
-
|
|
90
|
-
def __call__(self, t: int) -> float:
|
|
91
|
-
"""
|
|
92
|
-
Cosine annealing scheduler for soft quantizer regularization temperature term.
|
|
93
|
-
|
|
94
|
-
Args:
|
|
95
|
-
t: The current time step.
|
|
96
|
-
|
|
97
|
-
Returns: Scheduled temperature.
|
|
98
|
-
"""
|
|
99
|
-
|
|
100
|
-
is_before_start_decay = tf.cast(t < self.start_decay, tf.float32)
|
|
101
|
-
|
|
102
|
-
rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
|
|
103
|
-
|
|
104
|
-
return self.start_b * is_before_start_decay + \
|
|
105
|
-
(1 - is_before_start_decay) * \
|
|
106
|
-
(self.end_b + (self.start_b - self.end_b) * tf.math.maximum(0.0, (1 - rel_t)))
|
|
107
|
-
|
|
108
|
-
|
|
109
69
|
@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
|
|
110
70
|
quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
|
|
111
71
|
quantizer_type=RoundingType.SoftQuantizer)
|
|
@@ -116,23 +76,15 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
|
|
|
116
76
|
|
|
117
77
|
def __init__(self,
|
|
118
78
|
quantization_config: TrainableQuantizerWeightsConfig,
|
|
119
|
-
|
|
120
|
-
quantization_parameter_learning: bool = False,
|
|
121
|
-
n_epochs: int = N_EPOCHS):
|
|
79
|
+
quantization_parameter_learning: bool = False):
|
|
122
80
|
"""
|
|
123
81
|
Initialize a SymmetricSoftRoundingGPTQ object with parameters to use
|
|
124
82
|
for the quantization.
|
|
125
83
|
|
|
126
84
|
Args:
|
|
127
85
|
quantization_config: Trainable weights quantizer config.
|
|
128
|
-
n_batches: The expected number of batches for each training epoch.
|
|
129
86
|
quantization_parameter_learning: Whether to train the quantization threshold.
|
|
130
|
-
n_epochs: Number of epochs to run training for.
|
|
131
87
|
"""
|
|
132
|
-
|
|
133
|
-
if n_batches is None:
|
|
134
|
-
Logger.error("SymmetricSoftRoundingGPTQ got an uninitialized n_batches argument.")
|
|
135
|
-
|
|
136
88
|
super().__init__(quantization_config)
|
|
137
89
|
self.num_bits = quantization_config.weights_n_bits
|
|
138
90
|
self.per_channel = quantization_config.weights_per_channel_threshold
|
|
@@ -148,32 +100,23 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
|
|
|
148
100
|
self.num_channels = len(self.threshold_values) if self.per_channel else 1
|
|
149
101
|
|
|
150
102
|
# gamma and zeta are stretch parameters for computing the rectified sigmoind function.
|
|
151
|
-
# beta is used to set the regularization term.
|
|
152
103
|
# See: https://arxiv.org/pdf/2004.10568.pdf
|
|
153
104
|
self.gamma = SOFT_ROUNDING_GAMMA
|
|
154
105
|
self.zeta = SOFT_ROUNDING_ZETA
|
|
155
|
-
self.beta = SOFT_ROUNDING_BETA
|
|
156
106
|
|
|
157
107
|
self.quantizer_parameters = {}
|
|
158
108
|
|
|
159
|
-
# Initializing the temperature decay according to the number of expected gradient steps
|
|
160
|
-
init_decay = MAX_ITERATIONS_DEFAULT if n_batches is None else n_epochs * n_batches
|
|
161
|
-
self.linear_decay = LinearTempDecay(init_decay)
|
|
162
|
-
|
|
163
109
|
def initialize_quantization(self,
|
|
164
110
|
tensor_shape: Any,
|
|
165
111
|
name: str,
|
|
166
|
-
layer: Any)
|
|
112
|
+
layer: Any):
|
|
167
113
|
"""
|
|
168
|
-
|
|
114
|
+
Add quantizer parameters to the quantizer parameters dictionary
|
|
169
115
|
|
|
170
116
|
Args:
|
|
171
117
|
tensor_shape: tensor shape of the quantized tensor.
|
|
172
118
|
name: Tensor name.
|
|
173
119
|
layer: Layer to quantize.
|
|
174
|
-
|
|
175
|
-
Returns:
|
|
176
|
-
Dictionary of parameters names to the variables.
|
|
177
120
|
"""
|
|
178
121
|
|
|
179
122
|
if self.per_channel:
|
|
@@ -183,12 +126,6 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
|
|
|
183
126
|
else:
|
|
184
127
|
reshape_shape = [self.num_channels]
|
|
185
128
|
|
|
186
|
-
ar_iter = layer.add_weight(
|
|
187
|
-
f"{name}_{GPTQ_ITER}",
|
|
188
|
-
shape=(),
|
|
189
|
-
initializer=tf.keras.initializers.Constant(0.0),
|
|
190
|
-
trainable=False)
|
|
191
|
-
|
|
192
129
|
ptq_threshold_tensor = layer.add_weight(
|
|
193
130
|
f"{name}_{PTQ_THRESHOLD}",
|
|
194
131
|
shape=reshape_shape,
|
|
@@ -212,44 +149,17 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
|
|
|
212
149
|
|
|
213
150
|
auxvar_tensor.assign(alpha)
|
|
214
151
|
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
152
|
+
# Add quantization variables
|
|
153
|
+
self.add_quantizer_variable(AUXVAR, auxvar_tensor, VariableGroup.WEIGHTS)
|
|
154
|
+
self.add_quantizer_variable(PTQ_THRESHOLD, ptq_threshold_tensor, VariableGroup.QPARAMS)
|
|
218
155
|
|
|
219
|
-
if self.quantization_parameter_learning:
|
|
156
|
+
if self.quantization_parameter_learning and not self.power_of_two:
|
|
220
157
|
scale = layer.add_weight(
|
|
221
158
|
f"{name}_{SCALE_PTQ}",
|
|
222
159
|
shape=self.num_channels,
|
|
223
160
|
initializer=tf.keras.initializers.Constant(1.0),
|
|
224
161
|
trainable=True)
|
|
225
|
-
self.
|
|
226
|
-
|
|
227
|
-
return self.quantizer_parameters
|
|
228
|
-
|
|
229
|
-
def get_quantization_variable(self) -> List[tf.Tensor]:
|
|
230
|
-
"""
|
|
231
|
-
This function return a list with the quantizer's quantization parameters variables.
|
|
232
|
-
|
|
233
|
-
Returns: A list with the quantization parameters if there are defined parameters.
|
|
234
|
-
|
|
235
|
-
"""
|
|
236
|
-
|
|
237
|
-
if self.quantization_parameter_learning and not self.power_of_two:
|
|
238
|
-
return [self.quantizer_parameters[SCALE_PTQ]]
|
|
239
|
-
else:
|
|
240
|
-
return []
|
|
241
|
-
|
|
242
|
-
def get_regularization(self) -> tf.Tensor:
|
|
243
|
-
"""
|
|
244
|
-
Computes the regularization term for the soft rounding loss.
|
|
245
|
-
|
|
246
|
-
Returns:
|
|
247
|
-
regularization term.
|
|
248
|
-
"""
|
|
249
|
-
|
|
250
|
-
st = self.get_soft_targets()
|
|
251
|
-
b = self.linear_decay(self.ar_iter.value())
|
|
252
|
-
return tf.reduce_sum(1 - tf.pow(tf.math.abs(st - .5) * 2, b))
|
|
162
|
+
self.add_quantizer_variable(SCALE_PTQ, scale, VariableGroup.QPARAMS)
|
|
253
163
|
|
|
254
164
|
def get_soft_targets(self) -> tf.Tensor:
|
|
255
165
|
"""
|
|
@@ -260,16 +170,7 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
|
|
|
260
170
|
|
|
261
171
|
"""
|
|
262
172
|
return qutils.clip(
|
|
263
|
-
tf.sigmoid(self.
|
|
264
|
-
|
|
265
|
-
def get_aux_variable(self) -> List[tf.Tensor]:
|
|
266
|
-
"""
|
|
267
|
-
This function return a list with the quantizer's quantization auxiliary variables.
|
|
268
|
-
|
|
269
|
-
Returns: A list with the quantization auxiliary variables.
|
|
270
|
-
|
|
271
|
-
"""
|
|
272
|
-
return [self.quantizer_parameters[AUXVAR]]
|
|
173
|
+
tf.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma, 1, 0)
|
|
273
174
|
|
|
274
175
|
def __call__(self,
|
|
275
176
|
inputs: tf.Tensor,
|
|
@@ -285,8 +186,14 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
|
|
|
285
186
|
The quantized tensor.
|
|
286
187
|
"""
|
|
287
188
|
|
|
288
|
-
|
|
289
|
-
|
|
189
|
+
ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
|
|
190
|
+
|
|
191
|
+
#####################################################
|
|
192
|
+
# Soft Rounding
|
|
193
|
+
#####################################################
|
|
194
|
+
aux_var = self.get_soft_targets()
|
|
195
|
+
if not training:
|
|
196
|
+
aux_var = tf.cast(tf.math.greater_equal(aux_var, 0.5), tf.float32)
|
|
290
197
|
|
|
291
198
|
if self.per_channel:
|
|
292
199
|
reshape_shape = get_threshold_reshape_shape(inputs.shape,
|
|
@@ -297,15 +204,6 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
|
|
|
297
204
|
# Calculate soft rounding targets and optimized threshold
|
|
298
205
|
##########################################################
|
|
299
206
|
ptq_threshold_tensor_hat = tf.reshape(ptq_threshold_tensor, reshape_shape)
|
|
300
|
-
aux_var = self.get_soft_targets()
|
|
301
|
-
|
|
302
|
-
#####################################################
|
|
303
|
-
# Soft Rounding
|
|
304
|
-
#####################################################
|
|
305
|
-
if training:
|
|
306
|
-
self.ar_iter.assign_add(1.0)
|
|
307
|
-
else:
|
|
308
|
-
aux_var = tf.cast(tf.math.greater_equal(aux_var, 0.5), tf.float32)
|
|
309
207
|
|
|
310
208
|
#####################################################
|
|
311
209
|
# Quantized Input
|
|
@@ -318,17 +216,22 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
|
|
|
318
216
|
power_of_two=self.power_of_two)
|
|
319
217
|
|
|
320
218
|
if self.quantization_parameter_learning and not self.power_of_two:
|
|
321
|
-
scale = tf.reshape(self.
|
|
219
|
+
scale = tf.reshape(self.get_quantizer_variable(SCALE_PTQ), reshape_shape)
|
|
322
220
|
q_tensor *= scale
|
|
323
221
|
|
|
324
|
-
return q_tensor
|
|
325
222
|
else:
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
223
|
+
q_tensor = soft_rounding_symmetric_quantizer(input_tensor=inputs,
|
|
224
|
+
auxvar_tensor=aux_var,
|
|
225
|
+
threshold_tensor=ptq_threshold_tensor.value(),
|
|
226
|
+
num_bits=self.num_bits,
|
|
227
|
+
signed=True,
|
|
228
|
+
power_of_two=self.power_of_two)
|
|
229
|
+
|
|
230
|
+
if self.quantization_parameter_learning and not self.power_of_two:
|
|
231
|
+
scale = self.get_quantizer_variable(SCALE_PTQ)
|
|
232
|
+
q_tensor *= scale
|
|
233
|
+
|
|
234
|
+
return q_tensor
|
|
332
235
|
|
|
333
236
|
def get_quant_config(self) -> Dict[str, np.ndarray]:
|
|
334
237
|
"""
|
|
@@ -340,13 +243,13 @@ class SymmetricSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
|
|
|
340
243
|
"""
|
|
341
244
|
|
|
342
245
|
if self.power_of_two:
|
|
343
|
-
old_threshold = self.
|
|
246
|
+
old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
|
|
344
247
|
old_threshold = max_power_of_two(old_threshold, MIN_THRESHOLD)
|
|
345
248
|
|
|
346
249
|
else:
|
|
347
|
-
old_threshold = self.
|
|
250
|
+
old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
|
|
348
251
|
if self.quantization_parameter_learning:
|
|
349
|
-
scale = tf.reshape(self.
|
|
252
|
+
scale = tf.reshape(self.get_quantizer_variable(SCALE_PTQ), self.threshold_shape)
|
|
350
253
|
old_threshold = old_threshold * scale
|
|
351
254
|
old_threshold = old_threshold.numpy()
|
|
352
255
|
old_threshold = old_threshold.reshape(self.threshold_shape)
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import tensorflow as tf
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from model_compression_toolkit.gptq import RoundingType
|
|
20
|
+
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
21
|
+
from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
23
|
+
from model_compression_toolkit.gptq.common.gptq_constants import \
|
|
24
|
+
SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
|
|
25
|
+
from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
|
|
26
|
+
from typing import Dict, Any
|
|
27
|
+
from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
|
|
28
|
+
from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
|
|
29
|
+
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
|
|
30
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
|
|
31
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
|
|
32
|
+
get_threshold_reshape_shape
|
|
33
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def soft_rounding_uniform_quantizer(input_tensor: tf.Tensor,
                                    auxvar_tensor: tf.Variable,
                                    min_tensor: tf.Tensor,
                                    max_tensor: tf.Tensor,
                                    num_bits: int) -> tf.Tensor:
    """
    Uniformly quantizes a tensor for GPTQ quantizers, adding a learned rounding offset.

    Args:
        input_tensor: Tensor to quantize. Its values are not changed during gptq.
        auxvar_tensor: Rounding offset added to the floored grid index (learned during gptq training).
        min_tensor: Tensor with values to compute the min threshold.
        max_tensor: Tensor with values to compute the max threshold.
        num_bits: Num of bits to use.

    Returns:
        A quantized tensor.
    """
    # Shift the range so that zero lies exactly on the quantization grid.
    range_lo, range_hi = qutils.fix_range_to_include_zero(min_tensor, max_tensor, num_bits)
    step = qutils.calculate_delta_uniform(range_lo, range_hi, num_bits)

    # Grid index: floored via qutils.ste_floor, then shifted by the rounding offset.
    grid_index = qutils.ste_floor((input_tensor - range_lo) / step) + auxvar_tensor

    # Keep the index inside the 2**num_bits representable levels before mapping back.
    top_level = 2 ** num_bits - 1
    clipped_index = qutils.ste_clip(grid_index, min_val=0, max_val=top_level)

    return step * clipped_index + range_lo
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
                quantization_method=[QuantizationMethod.UNIFORM],
                quantizer_type=RoundingType.SoftQuantizer)
class UniformSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
    """
    Trainable uniform quantizer to optimize the rounding of the quantized values using a soft quantization method.
    """

    def __init__(self,
                 quantization_config: TrainableQuantizerWeightsConfig,
                 quantization_parameter_learning: bool = False):
        """
        Initialize a UniformSoftRoundingGPTQ object with parameters to use
        for the quantization.

        Args:
            quantization_config: Trainable weight quantizer config.
            quantization_parameter_learning: Whether to train the quantization range.
                Not implemented yet for this quantizer; must be False.
        """
        super().__init__(quantization_config)
        self.num_bits = quantization_config.weights_n_bits
        self.per_channel = quantization_config.weights_per_channel_threshold

        # Initial quantization range taken from the weights quantization params.
        self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
        self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]

        self.quantization_axis = quantization_config.weights_channels_axis
        assert quantization_parameter_learning is False, \
            "Quantization parameters learning in UniformSoftRoundingGPTQ not implemented yet"
        self.quantization_parameter_learning = quantization_parameter_learning
        self.num_channels = self.min_values.shape[self.quantization_axis] if self.per_channel else 1

        # gamma and zeta are stretch parameters for computing the rectified sigmoid function.
        # See: https://arxiv.org/pdf/2004.10568.pdf
        self.gamma = SOFT_ROUNDING_GAMMA
        self.zeta = SOFT_ROUNDING_ZETA

    def initialize_quantization(self,
                                tensor_shape: Any,
                                name: str,
                                layer: Any):
        """
        Creates the layer weights used by the quantizer and registers them as quantizer variables.

        Args:
            tensor_shape: tensor shape of the quantized tensor.
            name: Tensor name (attribute name of the quantized weight on the wrapped layer).
            layer: Layer to quantize.
        """

        if self.per_channel:
            reshape_shape = get_threshold_reshape_shape(tensor_shape,
                                                        quant_axis=self.quantization_axis,
                                                        quant_axis_dim=self.num_channels)
        else:
            reshape_shape = [self.num_channels]

        min_tensor = layer.add_weight(
            f"{name}_{FQ_MIN}",
            shape=reshape_shape,
            initializer=tf.keras.initializers.Constant(1.0),
            trainable=False)
        min_tensor.assign(self.min_values.reshape(reshape_shape))

        max_tensor = layer.add_weight(
            f"{name}_{FQ_MAX}",
            shape=reshape_shape,
            initializer=tf.keras.initializers.Constant(1.0),
            trainable=False)
        max_tensor.assign(self.max_values.reshape(reshape_shape))

        # The float weight being quantized. The same tensor defines the auxiliary variable's shape and
        # its initial value, so the two cannot disagree.
        w = getattr(layer.layer, name)
        auxvar_tensor = layer.add_weight(
            f"{name}_{AUXVAR}",
            shape=list(w.shape),
            initializer=tf.keras.initializers.Constant(0.0),
            trainable=True)

        # Initialize the auxiliary variable so the initial soft targets equal the fractional part
        # of each weight on the quantization grid.
        delta = qutils.calculate_delta_uniform(min_tensor, max_tensor, self.num_bits)
        # Keyword arguments are required: qutils.clip's positional order is (x, max_val, min_val),
        # so passing (0, 2 ** num_bits - 1) positionally would clip every value to 0.
        w_clipped_normed = qutils.clip((w - min_tensor) / delta,
                                       max_val=2 ** self.num_bits - 1,
                                       min_val=0)
        rest = w_clipped_normed - tf.floor(w_clipped_normed)  # rest of rounding [0, 1)
        # Inverse of the rectified sigmoid: sigmoid(alpha) * (zeta - gamma) + gamma == rest.
        alpha = -qutils.safe_log((self.zeta - self.gamma) / (rest - self.gamma) - 1, 1e-16)
        auxvar_tensor.assign(alpha)

        # Add quantization variables
        self.add_quantizer_variable(AUXVAR, auxvar_tensor, VariableGroup.WEIGHTS)
        self.add_quantizer_variable(RANGE_MIN, min_tensor, VariableGroup.QPARAMS)
        self.add_quantizer_variable(RANGE_MAX, max_tensor, VariableGroup.QPARAMS)

    def get_soft_targets(self) -> tf.Tensor:
        """
        Computes the rectified sigmoid function for the quantization target parameters.

        Returns:
            A tensor with the soft rounding targets values, clipped to [0, 1].

        """
        return qutils.clip(
            tf.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma,
            max_val=1,
            min_val=0)

    def __call__(self,
                 inputs: tf.Tensor,
                 training: bool):
        """
        Quantize a tensor.

        Args:
            inputs: Input tensor to quantize.
            training: Whether the graph is in training mode.

        Returns:
            The quantized tensor.
        """

        min_tensor = self.get_quantizer_variable(RANGE_MIN)
        max_tensor = self.get_quantizer_variable(RANGE_MAX)

        #####################################################
        # Soft Rounding
        #####################################################
        aux_var = self.get_soft_targets()
        if not training:
            # Outside training, use hard rounding decisions.
            aux_var = tf.cast(tf.math.greater_equal(aux_var, 0.5), tf.float32)

        if self.per_channel:
            # Reshape the per-channel range so it broadcasts against the input along the channel axis.
            reshape_shape = get_threshold_reshape_shape(inputs.shape,
                                                        quant_axis=self.quantization_axis,
                                                        quant_axis_dim=-1)
            min_tensor = tf.reshape(min_tensor, reshape_shape)
            max_tensor = tf.reshape(max_tensor, reshape_shape)

        #####################################################
        # Quantized Input
        #####################################################
        return soft_rounding_uniform_quantizer(input_tensor=inputs,
                                               auxvar_tensor=aux_var,
                                               min_tensor=min_tensor,
                                               max_tensor=max_tensor,
                                               num_bits=self.num_bits)

    def get_quant_config(self) -> Dict[str, np.ndarray]:
        """
        Returns the config used to edit NodeQuantizationConfig after GPTQ retraining

        Returns:
            A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
            Keys must match NodeQuantizationConfig attributes
        """

        return {RANGE_MIN: self.get_quantizer_variable(RANGE_MIN).numpy(),
                RANGE_MAX: self.get_quantizer_variable(RANGE_MAX).numpy()}
|