mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +12 -41
- model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
- model_compression_toolkit/core/__init__.py +14 -0
- model_compression_toolkit/core/analyzer.py +3 -2
- model_compression_toolkit/core/common/__init__.py +0 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +1 -8
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
- model_compression_toolkit/core/common/graph/base_graph.py +2 -2
- model_compression_toolkit/core/common/graph/base_node.py +57 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/memory_computation.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/model_validation.py +1 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -1
- model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
- model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
- model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/keras/constants.py +0 -7
- model_compression_toolkit/core/keras/default_framework_info.py +3 -3
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -10
- model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
- model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
- model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +1 -1
- model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +4 -6
- model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
- model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
- model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/runner.py +7 -7
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -104
- model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
- model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +30 -39
- model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
- model_compression_toolkit/gptq/keras/graph_info.py +8 -33
- model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
- model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
- model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
- model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
- model_compression_toolkit/gptq/runner.py +3 -2
- model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
- model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
- model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
- model_compression_toolkit/ptq/__init__.py +3 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/qat/__init__.py +4 -0
- model_compression_toolkit/qat/common/__init__.py +1 -2
- model_compression_toolkit/qat/common/qat_config.py +3 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
- model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
- model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
- model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
- model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
- /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import List
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
import numpy as np
|
|
19
|
+
from torch import nn
|
|
20
|
+
|
|
21
|
+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
|
22
|
+
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
|
|
23
|
+
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
|
|
24
|
+
from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class LinearTempDecay:
|
|
28
|
+
"""
|
|
29
|
+
Annealing process for the soft quantizer regularization temperature term.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
|
|
33
|
+
"""
|
|
34
|
+
Initializes a LinearTempDecay object.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
t_max: maximal time step.
|
|
38
|
+
rel_start_decay: Decay step size at the beginning of the process.
|
|
39
|
+
start_b: Starting value of the regularization term.
|
|
40
|
+
end_b: Target value of the regularization term.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
self.t_max = t_max
|
|
44
|
+
self.start_decay = rel_start_decay * t_max
|
|
45
|
+
self.start_b = start_b
|
|
46
|
+
self.end_b = end_b
|
|
47
|
+
|
|
48
|
+
def __call__(self, t: float) -> float:
|
|
49
|
+
"""
|
|
50
|
+
Cosine annealing scheduler for soft quantizer regularization temperature term.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
t: The current time step.
|
|
54
|
+
|
|
55
|
+
Returns: Scheduled temperature.
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
is_before_start_decay = (t < self.start_decay)
|
|
59
|
+
|
|
60
|
+
rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
|
|
61
|
+
|
|
62
|
+
return self.start_b * is_before_start_decay + \
|
|
63
|
+
(1 - is_before_start_decay) * \
|
|
64
|
+
(self.end_b + (self.start_b - self.end_b) * torch.maximum(to_torch_tensor(np.array([0.0])),
|
|
65
|
+
to_torch_tensor(np.array((1 - rel_t)))))
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class SoftQuantizerRegularization:
|
|
69
|
+
"""
|
|
70
|
+
A class to handle the computation of soft quantizer regularization for GPTQ training.
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
def __init__(self, total_gradient_steps: int):
|
|
74
|
+
"""
|
|
75
|
+
Initializes the regularization computation object with a LinearDecay object.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
total_gradient_steps: The number of gradient steps during optimization.
|
|
79
|
+
"""
|
|
80
|
+
|
|
81
|
+
# Initializing the temperature decay according to the number of expected gradient steps
|
|
82
|
+
self.linear_decay = LinearTempDecay(total_gradient_steps)
|
|
83
|
+
|
|
84
|
+
self.count_iter = 0
|
|
85
|
+
|
|
86
|
+
def __call__(self, model: nn.Module, entropy_reg: float):
|
|
87
|
+
"""
|
|
88
|
+
Returns the soft quantizer regularization value for SoftRounding.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
model: A model to be quantized with SoftRounding.
|
|
92
|
+
entropy_reg: Entropy value to scale the quantizer regularization.
|
|
93
|
+
|
|
94
|
+
Returns: Regularization value.
|
|
95
|
+
"""
|
|
96
|
+
|
|
97
|
+
soft_reg_aux: List[torch.Tensor] = []
|
|
98
|
+
b = self.linear_decay(self.count_iter)
|
|
99
|
+
for layer in model.modules():
|
|
100
|
+
if isinstance(layer, PytorchQuantizationWrapper):
|
|
101
|
+
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
|
|
102
|
+
fw_info=DEFAULT_PYTORCH_INFO)
|
|
103
|
+
|
|
104
|
+
st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
|
|
105
|
+
soft_reg_aux.append((1 - torch.pow(torch.abs(st - .5) * 2, b)).sum())
|
|
106
|
+
|
|
107
|
+
reg = 0
|
|
108
|
+
|
|
109
|
+
for sq in soft_reg_aux:
|
|
110
|
+
reg += sq
|
|
111
|
+
|
|
112
|
+
self.count_iter += 1
|
|
113
|
+
|
|
114
|
+
return entropy_reg * reg
|
|
@@ -14,24 +14,25 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
import torch
|
|
16
16
|
import torch.nn as nn
|
|
17
|
-
from typing import
|
|
17
|
+
from typing import Dict
|
|
18
18
|
import numpy as np
|
|
19
19
|
|
|
20
|
-
from model_compression_toolkit.core.common import
|
|
20
|
+
from model_compression_toolkit.core.common import max_power_of_two
|
|
21
21
|
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
22
|
-
from model_compression_toolkit.
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
23
23
|
from model_compression_toolkit.gptq.common.gptq_config import RoundingType
|
|
24
24
|
from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
|
|
25
25
|
BasePytorchGPTQTrainableQuantizer
|
|
26
26
|
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
|
|
27
27
|
from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
|
|
28
|
-
from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ,
|
|
29
|
-
|
|
30
|
-
from model_compression_toolkit.
|
|
28
|
+
from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
|
|
29
|
+
SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
|
|
30
|
+
from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
|
|
31
31
|
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
|
|
32
32
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
|
|
33
33
|
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
|
|
34
34
|
get_threshold_reshape_shape
|
|
35
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
35
36
|
|
|
36
37
|
|
|
37
38
|
def soft_rounding_symmetric_quantizer(input_tensor: torch.Tensor,
|
|
@@ -67,46 +68,6 @@ def soft_rounding_symmetric_quantizer(input_tensor: torch.Tensor,
|
|
|
67
68
|
max_val=int_threshold - 1)
|
|
68
69
|
|
|
69
70
|
|
|
70
|
-
class LinearTempDecay:
|
|
71
|
-
"""
|
|
72
|
-
Annealing process for the soft quantizer regularization temperature term.
|
|
73
|
-
"""
|
|
74
|
-
|
|
75
|
-
def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
|
|
76
|
-
"""
|
|
77
|
-
Initializes a LinearTempDecay object.
|
|
78
|
-
|
|
79
|
-
Args:
|
|
80
|
-
t_max: maximal time step.
|
|
81
|
-
rel_start_decay: Decay step size at the beginning of the process.
|
|
82
|
-
start_b: Starting value of the regularization term.
|
|
83
|
-
end_b: Target value of the regularization term.
|
|
84
|
-
"""
|
|
85
|
-
|
|
86
|
-
self.t_max = t_max
|
|
87
|
-
self.start_decay = rel_start_decay * t_max
|
|
88
|
-
self.start_b = start_b
|
|
89
|
-
self.end_b = end_b
|
|
90
|
-
|
|
91
|
-
def __call__(self, t: nn.Parameter) -> float:
|
|
92
|
-
"""
|
|
93
|
-
Cosine annealing scheduler for soft quantizer regularization temperature term.
|
|
94
|
-
|
|
95
|
-
Args:
|
|
96
|
-
t: The current time step.
|
|
97
|
-
|
|
98
|
-
Returns: Scheduled temperature.
|
|
99
|
-
"""
|
|
100
|
-
|
|
101
|
-
is_before_start_decay = (t < self.start_decay).to(torch.float32)
|
|
102
|
-
|
|
103
|
-
rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
|
|
104
|
-
|
|
105
|
-
return self.start_b * is_before_start_decay + \
|
|
106
|
-
(1 - is_before_start_decay) * \
|
|
107
|
-
(self.end_b + (self.start_b - self.end_b) * torch.maximum(to_torch_tensor(np.array([0.0])), (1 - rel_t)))
|
|
108
|
-
|
|
109
|
-
|
|
110
71
|
@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
|
|
111
72
|
quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
|
|
112
73
|
quantizer_type=RoundingType.SoftQuantizer)
|
|
@@ -117,22 +78,15 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
|
|
|
117
78
|
|
|
118
79
|
def __init__(self,
|
|
119
80
|
quantization_config: TrainableQuantizerWeightsConfig,
|
|
120
|
-
|
|
121
|
-
quantization_parameter_learning: bool = False,
|
|
122
|
-
n_epochs: int = N_EPOCHS):
|
|
81
|
+
quantization_parameter_learning: bool = False):
|
|
123
82
|
"""
|
|
124
83
|
Construct a Pytorch model that utilize a fake weight quantizer of soft-quantizer for symmetric quantizer.
|
|
125
84
|
|
|
126
85
|
Args:
|
|
127
86
|
quantization_config: Trainable weights quantizer config.
|
|
128
|
-
n_batches (int): number of batches in representative dataset
|
|
129
87
|
quantization_parameter_learning (Bool): Whether to learn the threshold or not
|
|
130
|
-
n_epochs (int): number of epochs the representative dataset is run during fine-tuning
|
|
131
88
|
"""
|
|
132
89
|
|
|
133
|
-
if n_batches is None:
|
|
134
|
-
Logger.error("SymmetricSoftRoundingGPTQ got an uninitialized n_batches argument.")
|
|
135
|
-
|
|
136
90
|
super().__init__(quantization_config)
|
|
137
91
|
self.num_bits = quantization_config.weights_n_bits
|
|
138
92
|
self.per_channel = quantization_config.weights_per_channel_threshold
|
|
@@ -147,35 +101,24 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
|
|
|
147
101
|
self.quantization_parameter_learning = quantization_parameter_learning
|
|
148
102
|
|
|
149
103
|
# gamma and zeta are stretch parameters for computing the rectified sigmoind function.
|
|
150
|
-
# beta is used to set the regularization term.
|
|
151
104
|
# See: https://arxiv.org/pdf/2004.10568.pdf
|
|
152
105
|
self.gamma = SOFT_ROUNDING_GAMMA
|
|
153
106
|
self.zeta = SOFT_ROUNDING_ZETA
|
|
154
|
-
self.beta = SOFT_ROUNDING_BETA
|
|
155
107
|
|
|
156
108
|
self.quantizer_parameters = {}
|
|
157
109
|
|
|
158
|
-
# Initializing the temperature decay according to the number of expected gradient steps
|
|
159
|
-
num_iterations = MAX_ITERATIONS_DEFAULT if n_batches is None else n_epochs * n_batches
|
|
160
|
-
self.linear_decay = LinearTempDecay(num_iterations)
|
|
161
|
-
|
|
162
110
|
def initialize_quantization(self,
|
|
163
111
|
tensor_shape: torch.Size,
|
|
164
112
|
name: str,
|
|
165
|
-
layer: qi.PytorchQuantizationWrapper)
|
|
113
|
+
layer: qi.PytorchQuantizationWrapper):
|
|
166
114
|
"""
|
|
167
|
-
|
|
115
|
+
Add quantizer parameters to the quantizer parameters dictionary
|
|
168
116
|
|
|
169
117
|
Args:
|
|
170
118
|
tensor_shape: tensor shape of the quantized tensor.
|
|
171
119
|
name: Tensor name.
|
|
172
120
|
layer: Layer to quantize.
|
|
173
|
-
|
|
174
|
-
Returns:
|
|
175
|
-
Dictionary of parameters names to the variables.
|
|
176
121
|
"""
|
|
177
|
-
layer.register_parameter(f"{name}_{GPTQ_ITER}",
|
|
178
|
-
nn.Parameter(to_torch_tensor(np.array([0])), requires_grad=False))
|
|
179
122
|
|
|
180
123
|
if self.per_channel:
|
|
181
124
|
threshold_tensor = to_torch_tensor(self.threshold_values)
|
|
@@ -195,31 +138,18 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
|
|
|
195
138
|
layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(alpha, requires_grad=True))
|
|
196
139
|
|
|
197
140
|
# save the quantizer added parameters for later calculations
|
|
198
|
-
self.
|
|
199
|
-
|
|
200
|
-
GPTQ_ITER: layer.get_parameter(f"{name}_{GPTQ_ITER}")}
|
|
141
|
+
self.add_quantizer_variable(PTQ_THRESHOLD, layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"), VariableGroup.QPARAMS)
|
|
142
|
+
self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
|
|
201
143
|
|
|
202
144
|
if self.quantization_parameter_learning:
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
def get_regularization(self) -> torch.Tensor:
|
|
212
|
-
"""
|
|
213
|
-
Computes the regularization term for the soft rounding loss.
|
|
214
|
-
|
|
215
|
-
Returns:
|
|
216
|
-
regularization term.
|
|
217
|
-
"""
|
|
218
|
-
|
|
219
|
-
st = self.get_soft_targets()
|
|
220
|
-
ar_iter = self.quantizer_parameters[GPTQ_ITER]
|
|
221
|
-
b = self.linear_decay(ar_iter)
|
|
222
|
-
return (1 - torch.pow(torch.abs(st - .5) * 2, b)).sum()
|
|
145
|
+
if self.per_channel:
|
|
146
|
+
layer.register_parameter(f"{name}_{SCALE_PTQ}",
|
|
147
|
+
nn.Parameter(to_torch_tensor(torch.ones_like(torch.Tensor(self.threshold_values))),
|
|
148
|
+
requires_grad=True))
|
|
149
|
+
else:
|
|
150
|
+
layer.register_parameter(f"{name}_{SCALE_PTQ}",
|
|
151
|
+
nn.Parameter(to_torch_tensor((torch.tensor([1.0], requires_grad=True)))))
|
|
152
|
+
self.add_quantizer_variable(SCALE_PTQ, layer.get_parameter(f"{name}_{SCALE_PTQ}"), VariableGroup.QPARAMS)
|
|
223
153
|
|
|
224
154
|
def get_soft_targets(self) -> torch.Tensor:
|
|
225
155
|
"""
|
|
@@ -229,28 +159,9 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
|
|
|
229
159
|
A tensor with the soft rounding targets values.
|
|
230
160
|
|
|
231
161
|
"""
|
|
232
|
-
scaled_sigmoid = torch.sigmoid(self.
|
|
162
|
+
scaled_sigmoid = torch.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma
|
|
233
163
|
return torch.clip(scaled_sigmoid, min=0, max=1)
|
|
234
164
|
|
|
235
|
-
def get_aux_variable(self) -> List[torch.Tensor]:
|
|
236
|
-
"""
|
|
237
|
-
This function return a list with the quantizer's quantization auxiliary variables.
|
|
238
|
-
|
|
239
|
-
Returns: A list with the quantization auxiliary variables.
|
|
240
|
-
"""
|
|
241
|
-
return [self.quantizer_parameters.get(AUXVAR)]
|
|
242
|
-
|
|
243
|
-
def get_quantization_variable(self) -> List[torch.Tensor]:
|
|
244
|
-
"""
|
|
245
|
-
This function return a list with the quantizer's quantization parameters variables.
|
|
246
|
-
|
|
247
|
-
Returns: A list with the quantization parameters.
|
|
248
|
-
"""
|
|
249
|
-
if self.quantization_parameter_learning and not self.power_of_two:
|
|
250
|
-
return [self.quantizer_parameters[SCALE_PTQ]]
|
|
251
|
-
else:
|
|
252
|
-
return []
|
|
253
|
-
|
|
254
165
|
def get_quant_config(self) -> Dict[str, np.ndarray]:
|
|
255
166
|
"""
|
|
256
167
|
Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
|
|
@@ -260,12 +171,13 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
|
|
|
260
171
|
Keys must match NodeQuantizationConfig attributes
|
|
261
172
|
|
|
262
173
|
"""
|
|
263
|
-
old_threshold = torch_tensor_to_numpy(self.
|
|
174
|
+
old_threshold = torch_tensor_to_numpy(self.get_quantizer_variable(PTQ_THRESHOLD))
|
|
175
|
+
old_threshold = np.resize(old_threshold, self.threshold_shape)
|
|
264
176
|
if self.power_of_two:
|
|
265
177
|
old_threshold = max_power_of_two(old_threshold, MIN_THRESHOLD)
|
|
266
178
|
else:
|
|
267
179
|
if self.quantization_parameter_learning:
|
|
268
|
-
scale = torch.reshape(self.
|
|
180
|
+
scale = torch.reshape(self.get_quantizer_variable(SCALE_PTQ), self.threshold_shape)
|
|
269
181
|
old_threshold = old_threshold * torch_tensor_to_numpy(scale)
|
|
270
182
|
old_threshold = old_threshold.reshape(self.threshold_shape)
|
|
271
183
|
return {THRESHOLD: old_threshold}
|
|
@@ -283,17 +195,14 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
|
|
|
283
195
|
Returns:
|
|
284
196
|
quantized tensor
|
|
285
197
|
"""
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
ptq_threshold_tensor = self.quantizer_parameters[PTQ_THRESHOLD]
|
|
198
|
+
auxvar = self.get_quantizer_variable(AUXVAR)
|
|
199
|
+
ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
|
|
289
200
|
|
|
290
201
|
#####################################################
|
|
291
202
|
# Soft Rounding
|
|
292
203
|
#####################################################
|
|
293
204
|
aux_var = self.get_soft_targets()
|
|
294
|
-
if training:
|
|
295
|
-
ar_iter.set_(ar_iter + 1)
|
|
296
|
-
else:
|
|
205
|
+
if not training:
|
|
297
206
|
aux_var = (aux_var >= 0.5).to(auxvar.dtype)
|
|
298
207
|
|
|
299
208
|
if self.per_channel:
|
|
@@ -317,7 +226,7 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
|
|
|
317
226
|
power_of_two=self.power_of_two)
|
|
318
227
|
|
|
319
228
|
if self.quantization_parameter_learning and not self.power_of_two:
|
|
320
|
-
scale = torch.reshape(self.
|
|
229
|
+
scale = torch.reshape(self.get_quantizer_variable(SCALE_PTQ), reshape_shape)
|
|
321
230
|
q_tensor *= scale
|
|
322
231
|
|
|
323
232
|
else:
|
|
@@ -328,4 +237,8 @@ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
|
|
|
328
237
|
signed=True,
|
|
329
238
|
power_of_two=self.power_of_two)
|
|
330
239
|
|
|
240
|
+
if self.quantization_parameter_learning and not self.power_of_two:
|
|
241
|
+
scale = self.get_quantizer_variable(SCALE_PTQ)
|
|
242
|
+
q_tensor *= scale
|
|
243
|
+
|
|
331
244
|
return q_tensor
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
from typing import Dict
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
21
|
+
from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
23
|
+
from model_compression_toolkit.gptq.common.gptq_config import RoundingType
|
|
24
|
+
from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
|
|
25
|
+
BasePytorchGPTQTrainableQuantizer
|
|
26
|
+
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
|
|
27
|
+
from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
|
|
28
|
+
from model_compression_toolkit.gptq.common.gptq_constants import SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
|
|
29
|
+
from model_compression_toolkit.gptq.pytorch.quantizer.quant_utils import fix_range_to_include_zero
|
|
30
|
+
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
|
|
31
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
|
|
32
|
+
mark_quantizer
|
|
33
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import \
|
|
34
|
+
VariableGroup
|
|
35
|
+
from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def soft_rounding_unifrom_quantizer(input_tensor: torch.Tensor,
|
|
39
|
+
auxvar_tensor: torch.Tensor,
|
|
40
|
+
min_range: torch.Tensor,
|
|
41
|
+
max_range: torch.Tensor,
|
|
42
|
+
num_bits: int) -> torch.Tensor:
|
|
43
|
+
"""
|
|
44
|
+
Quantize a tensor uniformly for GPTQ quantizers.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
|
|
48
|
+
auxvar_tensor: Tensor that manifests the bit shift of the quantized weights due to gptq training.
|
|
49
|
+
min_range: Tensor with min values to compute the delta grid.
|
|
50
|
+
max_range: Tensor with max values to compute the delta grid.
|
|
51
|
+
num_bits: Num of bits to use.
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
A quantized tensor.
|
|
55
|
+
"""
|
|
56
|
+
# adjusts the quantization range so the quantization grid includes zero.
|
|
57
|
+
min_range, max_range = fix_range_to_include_zero(min_range, max_range, num_bits)
|
|
58
|
+
delta = qutils.calculate_delta_uniform(min_range, max_range, num_bits)
|
|
59
|
+
input_tensor_int = qutils.ste_floor((input_tensor - min_range) / delta)
|
|
60
|
+
tensor_q = input_tensor_int + auxvar_tensor
|
|
61
|
+
return delta * qutils.ste_clip(tensor_q,
|
|
62
|
+
min_val=0,
|
|
63
|
+
max_val=2 ** num_bits - 1) + min_range
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
|
|
67
|
+
quantization_method=[QuantizationMethod.UNIFORM],
|
|
68
|
+
quantizer_type=RoundingType.SoftQuantizer)
|
|
69
|
+
class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
|
|
70
|
+
"""
|
|
71
|
+
Trainable uniform quantizer to optimize the rounding of the quantized values using a soft quantization method.
|
|
72
|
+
"""
|
|
73
|
+
|
|
74
|
+
def __init__(self,
|
|
75
|
+
quantization_config: TrainableQuantizerWeightsConfig,
|
|
76
|
+
quantization_parameter_learning: bool = False):
|
|
77
|
+
"""
|
|
78
|
+
Construct a Pytorch model that utilize a fake weight quantizer of soft-quantizer for symmetric quantizer.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
quantization_config: Trainable weights quantizer config.
|
|
82
|
+
quantization_parameter_learning (Bool): Whether to learn the min/max ranges or not
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
super().__init__(quantization_config)
|
|
86
|
+
self.num_bits = quantization_config.weights_n_bits
|
|
87
|
+
self.per_channel = quantization_config.weights_per_channel_threshold
|
|
88
|
+
|
|
89
|
+
self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
|
|
90
|
+
self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]
|
|
91
|
+
|
|
92
|
+
self.quantization_axis = quantization_config.weights_channels_axis
|
|
93
|
+
self.quantization_parameter_learning = quantization_parameter_learning
|
|
94
|
+
|
|
95
|
+
# gamma and zeta are stretch parameters for computing the rectified sigmoid function.
|
|
96
|
+
# See: https://arxiv.org/pdf/2004.10568.pdf
|
|
97
|
+
self.gamma = SOFT_ROUNDING_GAMMA
|
|
98
|
+
self.zeta = SOFT_ROUNDING_ZETA
|
|
99
|
+
|
|
100
|
+
def initialize_quantization(self,
|
|
101
|
+
tensor_shape: torch.Size,
|
|
102
|
+
name: str,
|
|
103
|
+
layer: qi.PytorchQuantizationWrapper):
|
|
104
|
+
"""
|
|
105
|
+
Add quantizer parameters to the quantizer parameters dictionary
|
|
106
|
+
|
|
107
|
+
Args:
|
|
108
|
+
tensor_shape: tensor shape of the quantized tensor.
|
|
109
|
+
name: Tensor name.
|
|
110
|
+
layer: Layer to quantize.
|
|
111
|
+
"""
|
|
112
|
+
|
|
113
|
+
# Add min and max variables to layer.
|
|
114
|
+
if self.per_channel:
|
|
115
|
+
min_values = to_torch_tensor(self.min_values)
|
|
116
|
+
max_values = to_torch_tensor(self.max_values)
|
|
117
|
+
else:
|
|
118
|
+
min_values = torch.tensor(self.min_values)
|
|
119
|
+
max_values = torch.tensor(self.max_values)
|
|
120
|
+
|
|
121
|
+
layer.register_parameter(name+"_"+FQ_MIN, nn.Parameter(min_values, requires_grad=self.quantization_parameter_learning))
|
|
122
|
+
layer.register_parameter(name+"_"+FQ_MAX, nn.Parameter(max_values, requires_grad=self.quantization_parameter_learning))
|
|
123
|
+
|
|
124
|
+
w = layer.layer.weight
|
|
125
|
+
delta = qutils.calculate_delta_uniform(min_values, max_values, self.num_bits)
|
|
126
|
+
w_clipped_normed = torch.clip((w - min_values) / delta, 0, 2 ** self.num_bits - 1)
|
|
127
|
+
rest = w_clipped_normed - torch.floor(w_clipped_normed) # rest of rounding [0, 1)
|
|
128
|
+
alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1) # => sigmoid(alpha) = rest
|
|
129
|
+
layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(alpha, requires_grad=True))
|
|
130
|
+
|
|
131
|
+
# Save the quantizer parameters
|
|
132
|
+
self.add_quantizer_variable(FQ_MIN, layer.get_parameter(name+"_"+FQ_MIN), VariableGroup.QPARAMS)
|
|
133
|
+
self.add_quantizer_variable(FQ_MAX, layer.get_parameter(name+"_"+FQ_MAX), VariableGroup.QPARAMS)
|
|
134
|
+
self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
|
|
135
|
+
|
|
136
|
+
def get_soft_targets(self) -> torch.Tensor:
|
|
137
|
+
"""
|
|
138
|
+
Computes the rectified sigmoid function for the quantization target parameters.
|
|
139
|
+
|
|
140
|
+
Returns:
|
|
141
|
+
A tensor with the soft rounding targets values.
|
|
142
|
+
|
|
143
|
+
"""
|
|
144
|
+
scaled_sigmoid = torch.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma
|
|
145
|
+
return torch.clip(scaled_sigmoid, min=0, max=1)
|
|
146
|
+
|
|
147
|
+
def get_quant_config(self) -> Dict[str, np.ndarray]:
|
|
148
|
+
"""
|
|
149
|
+
Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
|
|
150
|
+
|
|
151
|
+
Returns:
|
|
152
|
+
A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
|
|
153
|
+
Keys must match NodeQuantizationConfig attributes
|
|
154
|
+
|
|
155
|
+
"""
|
|
156
|
+
min_values = torch_tensor_to_numpy(self.get_quantizer_variable(FQ_MIN))
|
|
157
|
+
max_values = torch_tensor_to_numpy(self.get_quantizer_variable(FQ_MAX))
|
|
158
|
+
return {RANGE_MIN: min_values,
|
|
159
|
+
RANGE_MAX: max_values}
|
|
160
|
+
|
|
161
|
+
def __call__(self,
|
|
162
|
+
inputs: nn.Parameter,
|
|
163
|
+
training: bool) -> torch.Tensor:
|
|
164
|
+
"""
|
|
165
|
+
Quantize a tensor.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
inputs: Input tensor to quantize.
|
|
169
|
+
training: whether in training mode or not
|
|
170
|
+
|
|
171
|
+
Returns:
|
|
172
|
+
quantized tensor
|
|
173
|
+
"""
|
|
174
|
+
auxvar = self.get_quantizer_variable(AUXVAR)
|
|
175
|
+
min_range = self.get_quantizer_variable(FQ_MIN)
|
|
176
|
+
max_range = self.get_quantizer_variable(FQ_MAX)
|
|
177
|
+
|
|
178
|
+
#####################################################
|
|
179
|
+
# Soft Rounding
|
|
180
|
+
#####################################################
|
|
181
|
+
aux_var = self.get_soft_targets()
|
|
182
|
+
if not training:
|
|
183
|
+
aux_var = (aux_var >= 0.5).to(auxvar.dtype)
|
|
184
|
+
|
|
185
|
+
#####################################################
|
|
186
|
+
# Quantized Input
|
|
187
|
+
#####################################################
|
|
188
|
+
q_tensor = soft_rounding_unifrom_quantizer(input_tensor=inputs,
|
|
189
|
+
auxvar_tensor=aux_var,
|
|
190
|
+
min_range=min_range,
|
|
191
|
+
max_range=max_range,
|
|
192
|
+
num_bits=self.num_bits)
|
|
193
|
+
|
|
194
|
+
return q_tensor
|
|
@@ -14,23 +14,23 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
import torch
|
|
16
16
|
import torch.nn as nn
|
|
17
|
-
from typing import
|
|
17
|
+
from typing import Dict
|
|
18
18
|
import numpy as np
|
|
19
19
|
from model_compression_toolkit.core.common.defaultdict import DefaultDict
|
|
20
20
|
|
|
21
21
|
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
22
|
-
from model_compression_toolkit.
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
23
23
|
from model_compression_toolkit.gptq.common.gptq_config import RoundingType
|
|
24
24
|
from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
|
|
25
25
|
BasePytorchGPTQTrainableQuantizer
|
|
26
26
|
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
|
|
27
27
|
from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
|
|
28
28
|
from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD, MAX_LSB_CHANGE
|
|
29
|
-
from model_compression_toolkit.
|
|
29
|
+
from model_compression_toolkit.constants import THRESHOLD
|
|
30
30
|
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
|
|
31
31
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
|
|
32
32
|
mark_quantizer
|
|
33
|
-
|
|
33
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
34
34
|
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
|
|
35
35
|
get_threshold_reshape_shape
|
|
36
36
|
|
|
@@ -104,23 +104,19 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
|
|
|
104
104
|
self.quantization_axis = quantization_config.weights_channels_axis
|
|
105
105
|
self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
|
|
106
106
|
self.max_lsbs_change = max_lsbs_change_map.get(self.num_bits)
|
|
107
|
-
self.quantizer_parameters = {}
|
|
108
107
|
|
|
109
108
|
|
|
110
109
|
def initialize_quantization(self,
|
|
111
110
|
tensor_shape: torch.Size,
|
|
112
111
|
name: str,
|
|
113
|
-
layer: qi.PytorchQuantizationWrapper)
|
|
112
|
+
layer: qi.PytorchQuantizationWrapper):
|
|
114
113
|
"""
|
|
115
|
-
|
|
114
|
+
Add quantizer parameters to the quantizer parameters dictionary
|
|
116
115
|
|
|
117
116
|
Args:
|
|
118
117
|
tensor_shape: tensor shape of the quantized tensor.
|
|
119
118
|
name: Tensor name.
|
|
120
119
|
layer: Layer to quantize.
|
|
121
|
-
|
|
122
|
-
Returns:
|
|
123
|
-
Dictionary of parameters names to the variables.
|
|
124
120
|
"""
|
|
125
121
|
|
|
126
122
|
layer.register_parameter(f"{name}_{PTQ_THRESHOLD}",
|
|
@@ -131,27 +127,9 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
|
|
|
131
127
|
requires_grad=True))
|
|
132
128
|
|
|
133
129
|
# save the quantizer added parameters for later calculations
|
|
134
|
-
self.
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
return self.quantizer_parameters
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
def get_aux_variable(self) -> List[torch.Tensor]:
|
|
141
|
-
"""
|
|
142
|
-
This function return a list with the quantizer's quantization auxiliary variables.
|
|
143
|
-
|
|
144
|
-
Returns: A list with the quantization auxiliary variables.
|
|
145
|
-
"""
|
|
146
|
-
return [self.quantizer_parameters.get(AUXVAR)]
|
|
130
|
+
self.add_quantizer_variable(PTQ_THRESHOLD, layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"), VariableGroup.QPARAMS)
|
|
131
|
+
self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
|
|
147
132
|
|
|
148
|
-
def get_quantization_variable(self) -> List[torch.Tensor]:
|
|
149
|
-
"""
|
|
150
|
-
This function return a list with the quantizer's quantization parameters variables.
|
|
151
|
-
|
|
152
|
-
Returns: A list with the quantization parameters.
|
|
153
|
-
"""
|
|
154
|
-
return [self.quantizer_parameters.get(PTQ_THRESHOLD)]
|
|
155
133
|
|
|
156
134
|
def get_quant_config(self) -> Dict[str, np.ndarray]:
|
|
157
135
|
"""
|
|
@@ -162,7 +140,7 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
|
|
|
162
140
|
Keys must match NodeQuantizationConfig attributes
|
|
163
141
|
|
|
164
142
|
"""
|
|
165
|
-
old_threshold = self.
|
|
143
|
+
old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
|
|
166
144
|
return {THRESHOLD: torch_tensor_to_numpy(old_threshold).reshape(self.threshold_shape)}
|
|
167
145
|
|
|
168
146
|
def __call__(self,
|
|
@@ -178,8 +156,8 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
|
|
|
178
156
|
Returns:
|
|
179
157
|
quantized tensor
|
|
180
158
|
"""
|
|
181
|
-
auxvar = self.
|
|
182
|
-
ptq_threshold_tensor = self.
|
|
159
|
+
auxvar = self.get_quantizer_variable(AUXVAR)
|
|
160
|
+
ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
|
|
183
161
|
|
|
184
162
|
if self.per_channel:
|
|
185
163
|
reshape_shape = get_threshold_reshape_shape(inputs.shape,
|