mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +12 -41
- model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
- model_compression_toolkit/core/__init__.py +14 -0
- model_compression_toolkit/core/analyzer.py +3 -2
- model_compression_toolkit/core/common/__init__.py +0 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +1 -8
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
- model_compression_toolkit/core/common/graph/base_graph.py +2 -2
- model_compression_toolkit/core/common/graph/base_node.py +57 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/memory_computation.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/model_validation.py +1 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -1
- model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
- model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
- model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/keras/constants.py +0 -7
- model_compression_toolkit/core/keras/default_framework_info.py +3 -3
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -10
- model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
- model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
- model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +1 -1
- model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +4 -6
- model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
- model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
- model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/runner.py +7 -7
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -104
- model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
- model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +30 -39
- model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
- model_compression_toolkit/gptq/keras/graph_info.py +8 -33
- model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
- model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
- model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
- model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
- model_compression_toolkit/gptq/runner.py +3 -2
- model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
- model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
- model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
- model_compression_toolkit/ptq/__init__.py +3 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/qat/__init__.py +4 -0
- model_compression_toolkit/qat/common/__init__.py +1 -2
- model_compression_toolkit/qat/common/qat_config.py +3 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
- model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
- model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
- model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
- model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
- /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Tuple, List
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit import FrameworkInfo
|
|
18
|
-
from model_compression_toolkit.
|
|
17
|
+
from model_compression_toolkit.core import FrameworkInfo
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
19
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
20
20
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
21
21
|
|
|
@@ -16,12 +16,14 @@ import copy
|
|
|
16
16
|
from abc import ABC, abstractmethod
|
|
17
17
|
import numpy as np
|
|
18
18
|
from typing import Callable, List, Any
|
|
19
|
-
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
|
|
20
|
-
from model_compression_toolkit.core.common import Graph,
|
|
19
|
+
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
|
|
20
|
+
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
21
21
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
22
22
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
23
|
+
from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
|
|
23
24
|
from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
|
|
24
25
|
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
|
|
26
|
+
from model_compression_toolkit.logger import Logger
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
class GPTQTrainer(ABC):
|
|
@@ -34,8 +36,7 @@ class GPTQTrainer(ABC):
|
|
|
34
36
|
graph_quant: Graph,
|
|
35
37
|
gptq_config: GradientPTQConfig,
|
|
36
38
|
fw_impl: FrameworkImplementation,
|
|
37
|
-
fw_info: FrameworkInfo
|
|
38
|
-
representative_data_gen: Callable):
|
|
39
|
+
fw_info: FrameworkInfo):
|
|
39
40
|
"""
|
|
40
41
|
Build two models from a graph: A teacher network (float model) and a student network (quantized model).
|
|
41
42
|
Use the dataset generator to pass images through the teacher and student networks to get intermediate
|
|
@@ -48,7 +49,6 @@ class GPTQTrainer(ABC):
|
|
|
48
49
|
gptq_config: GradientPTQConfig with parameters about the tuning process.
|
|
49
50
|
fw_impl: Framework implementation
|
|
50
51
|
fw_info: Framework information
|
|
51
|
-
representative_data_gen: Dataset to use for inputs of the models.
|
|
52
52
|
"""
|
|
53
53
|
self.graph_float = copy.deepcopy(graph_float)
|
|
54
54
|
self.graph_quant = copy.deepcopy(graph_quant)
|
|
@@ -66,10 +66,6 @@ class GPTQTrainer(ABC):
|
|
|
66
66
|
append2output=self.compare_points,
|
|
67
67
|
fw_info=self.fw_info)
|
|
68
68
|
|
|
69
|
-
if self.gptq_config.rounding_type == RoundingType.SoftQuantizer:
|
|
70
|
-
# dry run on the representative dataset to count number of batches
|
|
71
|
-
self.count_num_batches_for_training(representative_data_gen)
|
|
72
|
-
|
|
73
69
|
self.fxp_model, self.gptq_user_info = self.build_gptq_model()
|
|
74
70
|
|
|
75
71
|
def get_optimizer_with_param(self,
|
|
@@ -88,8 +84,10 @@ class GPTQTrainer(ABC):
|
|
|
88
84
|
|
|
89
85
|
w2train = [*flattened_trainable_weights]
|
|
90
86
|
|
|
87
|
+
quant_params_learning = self.gptq_config.gptq_quantizer_params_override.get(QUANT_PARAM_LEARNING_STR, False)
|
|
88
|
+
|
|
91
89
|
optimizer_with_param = [(self.gptq_config.optimizer, w2train)]
|
|
92
|
-
if self.gptq_config.train_bias or
|
|
90
|
+
if self.gptq_config.train_bias or quant_params_learning:
|
|
93
91
|
w2train_res = []
|
|
94
92
|
if self.gptq_config.train_bias:
|
|
95
93
|
if self.gptq_config.optimizer_bias is not None:
|
|
@@ -99,7 +97,7 @@ class GPTQTrainer(ABC):
|
|
|
99
97
|
if self.gptq_config.optimizer_rest is None:
|
|
100
98
|
Logger.error( # pragma: no cover
|
|
101
99
|
"To enable bias micro training an additional optimizer is required, please define the optimizer_rest")
|
|
102
|
-
if
|
|
100
|
+
if quant_params_learning:
|
|
103
101
|
if self.gptq_config.optimizer_quantization_parameter is not None: # Ability to override optimizer
|
|
104
102
|
optimizer_with_param.append((self.gptq_config.optimizer_quantization_parameter,
|
|
105
103
|
trainable_quantization_parameters))
|
|
@@ -107,25 +105,32 @@ class GPTQTrainer(ABC):
|
|
|
107
105
|
w2train_res.extend(trainable_quantization_parameters)
|
|
108
106
|
if self.gptq_config.optimizer_rest is None:
|
|
109
107
|
Logger.error( # pragma: no cover
|
|
110
|
-
"To enable
|
|
111
|
-
|
|
108
|
+
"To enable quantization parameters micro training an additional optimizer is required, please define the optimizer_rest")
|
|
109
|
+
if len(w2train_res) > 0:
|
|
110
|
+
# Either bias or quantization parameters are trainable but did not provide a specific optimizer,
|
|
111
|
+
# so we should use optimizer_rest to train them
|
|
112
|
+
if self.gptq_config.optimizer_rest is None:
|
|
113
|
+
Logger.error( # pragma: no cover
|
|
114
|
+
"To enable bias or quantization parameters micro training an additional optimizer is required, please define the optimizer_rest")
|
|
115
|
+
optimizer_with_param.append((self.gptq_config.optimizer_rest, w2train_res))
|
|
112
116
|
|
|
113
117
|
return optimizer_with_param
|
|
114
118
|
|
|
115
119
|
|
|
116
|
-
def
|
|
117
|
-
|
|
120
|
+
def compute_hessian_based_weights(self,
|
|
121
|
+
representative_data_gen: Callable) -> np.ndarray:
|
|
118
122
|
"""
|
|
119
|
-
Computes the
|
|
123
|
+
Computes the Hessian-based weights using the framework's model_grad method per batch of images.
|
|
120
124
|
|
|
121
125
|
Args:
|
|
122
|
-
representative_data_gen: Dataset used for inference to compute the
|
|
126
|
+
representative_data_gen: Dataset used for inference to compute the Hessian-based weights.
|
|
123
127
|
|
|
124
128
|
Returns: A vector of weights, one for each compare point,
|
|
125
129
|
to be used for the loss metric weighted average computation when running GPTQ training.
|
|
126
130
|
"""
|
|
127
|
-
if self.gptq_config.
|
|
128
|
-
images = self._generate_images_batch(representative_data_gen,
|
|
131
|
+
if self.gptq_config.use_hessian_based_weights:
|
|
132
|
+
images = self._generate_images_batch(representative_data_gen,
|
|
133
|
+
self.gptq_config.hessian_weights_config.hessians_num_samples)
|
|
129
134
|
|
|
130
135
|
model_output_replacement = self._get_model_output_replacement()
|
|
131
136
|
|
|
@@ -143,17 +148,18 @@ class GPTQTrainer(ABC):
|
|
|
143
148
|
output_list=model_output_replacement,
|
|
144
149
|
all_outputs_indices=[],
|
|
145
150
|
alpha=0,
|
|
146
|
-
norm_weights=self.gptq_config.norm_weights,
|
|
147
|
-
n_iter=self.gptq_config.
|
|
151
|
+
norm_weights=self.gptq_config.hessian_weights_config.norm_weights,
|
|
152
|
+
n_iter=self.gptq_config.hessian_weights_config.hessians_n_iter)
|
|
148
153
|
points_apprx_jacobians_weights.append(image_ip_gradients)
|
|
149
|
-
if self.gptq_config.log_norm:
|
|
154
|
+
if self.gptq_config.hessian_weights_config.log_norm:
|
|
150
155
|
mean_jacobian_weights = np.mean(points_apprx_jacobians_weights, axis=0)
|
|
151
156
|
mean_jacobian_weights = np.where(mean_jacobian_weights != 0, mean_jacobian_weights,
|
|
152
157
|
np.partition(mean_jacobian_weights, 1)[1])
|
|
153
158
|
log_weights = np.log10(mean_jacobian_weights)
|
|
154
159
|
|
|
155
|
-
|
|
156
|
-
|
|
160
|
+
if self.gptq_config.hessian_weights_config.scale_log_norm:
|
|
161
|
+
return (log_weights - np.min(log_weights)) / (np.max(log_weights) - np.min(log_weights))
|
|
162
|
+
|
|
157
163
|
return log_weights - np.min(log_weights)
|
|
158
164
|
else:
|
|
159
165
|
return np.mean(points_apprx_jacobians_weights, axis=0)
|
|
@@ -249,21 +255,6 @@ class GPTQTrainer(ABC):
|
|
|
249
255
|
replacement_outputs.append(prev_node)
|
|
250
256
|
return replacement_outputs
|
|
251
257
|
|
|
252
|
-
def count_num_batches_for_training(self, representative_data_gen):
|
|
253
|
-
"""
|
|
254
|
-
Runs a "dry-run" of the representative dataset to count the number of batches for each training epoch.
|
|
255
|
-
|
|
256
|
-
Args:
|
|
257
|
-
representative_data_gen: A callable method to retrieve images from Dataset.
|
|
258
|
-
|
|
259
|
-
Returns: The number of batches for each training epoch.
|
|
260
|
-
|
|
261
|
-
"""
|
|
262
|
-
num_batches = 0
|
|
263
|
-
for _ in representative_data_gen():
|
|
264
|
-
num_batches += 1
|
|
265
|
-
self.gptq_config.quantizer_config.set_num_batches(num_batches)
|
|
266
|
-
|
|
267
258
|
|
|
268
259
|
def gptq_training(graph_float: Graph,
|
|
269
260
|
graph_quant: Graph,
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Type
|
|
17
|
+
|
|
18
|
+
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
|
|
19
|
+
from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
|
|
20
|
+
from model_compression_toolkit.gptq.keras.gptq_training import KerasGPTQTrainer
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class GPTQKerasImplemantation(GPTQFrameworkImplemantation, KerasImplementation):
|
|
24
|
+
|
|
25
|
+
def get_gptq_trainer_obj(self) -> Type[KerasGPTQTrainer]:
|
|
26
|
+
"""
|
|
27
|
+
Returns: Keras object of GPTQTrainer
|
|
28
|
+
"""
|
|
29
|
+
return KerasGPTQTrainer
|
|
@@ -12,7 +12,6 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
from functools import partial
|
|
16
15
|
from typing import Callable, List, Tuple, Union
|
|
17
16
|
|
|
18
17
|
import tensorflow as tf
|
|
@@ -23,11 +22,11 @@ from tqdm import tqdm
|
|
|
23
22
|
# As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
|
|
24
23
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
|
25
24
|
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
|
|
26
|
-
from model_compression_toolkit.gptq.common.gptq_constants import REGULARIZATION_VALUES
|
|
27
25
|
from packaging import version
|
|
28
26
|
|
|
29
27
|
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
|
|
30
28
|
from model_compression_toolkit.gptq.keras.quantizer.quantization_builder import quantization_builder
|
|
29
|
+
from model_compression_toolkit.logger import Logger
|
|
31
30
|
from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
|
|
32
31
|
|
|
33
32
|
if version.parse(tf.__version__) < version.parse("2.6"):
|
|
@@ -37,15 +36,15 @@ else:
|
|
|
37
36
|
|
|
38
37
|
from model_compression_toolkit.core import common
|
|
39
38
|
from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
|
|
40
|
-
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
|
|
39
|
+
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
|
|
41
40
|
from model_compression_toolkit.core.common import Graph
|
|
42
|
-
from model_compression_toolkit.gptq.keras.graph_info import get_weights_for_loss,
|
|
43
|
-
|
|
41
|
+
from model_compression_toolkit.gptq.keras.graph_info import get_weights_for_loss, get_gptq_trainable_parameters
|
|
42
|
+
from model_compression_toolkit.gptq.keras.quantizer.regularization_factory import get_regularization
|
|
44
43
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
45
44
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
46
45
|
import numpy as np
|
|
47
46
|
import copy
|
|
48
|
-
from model_compression_toolkit.core.keras.constants import BIAS, USE_BIAS
|
|
47
|
+
from model_compression_toolkit.core.keras.constants import BIAS, USE_BIAS
|
|
49
48
|
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
50
49
|
|
|
51
50
|
|
|
@@ -79,13 +78,12 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
|
79
78
|
graph_quant,
|
|
80
79
|
gptq_config,
|
|
81
80
|
fw_impl,
|
|
82
|
-
fw_info
|
|
83
|
-
representative_data_gen)
|
|
81
|
+
fw_info)
|
|
84
82
|
|
|
85
83
|
self.loss_list = []
|
|
86
84
|
self.input_scale = 1
|
|
87
85
|
|
|
88
|
-
trainable_weights, bias_weights, trainable_threshold
|
|
86
|
+
trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters(
|
|
89
87
|
self.fxp_model,
|
|
90
88
|
fw_info,
|
|
91
89
|
add_bias=gptq_config.train_bias)
|
|
@@ -108,11 +106,13 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
|
108
106
|
[len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
|
|
109
107
|
|
|
110
108
|
if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
|
|
111
|
-
|
|
109
|
+
Logger.error("Input scale mismatch between float and GPTQ networks") # pragma: no cover
|
|
112
110
|
else:
|
|
113
111
|
self.input_scale = self.gptq_user_info.input_scale
|
|
114
112
|
|
|
115
|
-
self.weights_for_average_loss = self.
|
|
113
|
+
self.weights_for_average_loss = self.compute_hessian_based_weights(representative_data_gen)
|
|
114
|
+
|
|
115
|
+
self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
|
|
116
116
|
|
|
117
117
|
def _is_gptq_applicable(self,
|
|
118
118
|
node: common.BaseNode) -> bool:
|
|
@@ -127,7 +127,7 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
|
127
127
|
"""
|
|
128
128
|
|
|
129
129
|
if node.is_weights_quantization_enabled() and not self.fw_info.is_kernel_op(node.type):
|
|
130
|
-
|
|
130
|
+
Logger.error(f"GPTQ Error: Quantizing node {node.name} of type {node.type} "
|
|
131
131
|
f"without a kernel isn't supported")
|
|
132
132
|
return node.is_weights_quantization_enabled()
|
|
133
133
|
|
|
@@ -195,9 +195,7 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
|
195
195
|
self.compare_points_std,
|
|
196
196
|
self.weights_for_average_loss)
|
|
197
197
|
|
|
198
|
-
reg_value = self.gptq_config.
|
|
199
|
-
self.fxp_model,
|
|
200
|
-
**{REGULARIZATION_VALUES: self._get_quantizer_regularization_values(self.gptq_config.rounding_type)})
|
|
198
|
+
reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)
|
|
201
199
|
|
|
202
200
|
loss_value += reg_value
|
|
203
201
|
|
|
@@ -283,7 +281,7 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
|
283
281
|
self.gptq_config.log_function(loss_value_step, grads[0], in_optimizer_with_param[0][-1],
|
|
284
282
|
self.compare_points)
|
|
285
283
|
self.loss_list.append(loss_value_step.numpy())
|
|
286
|
-
|
|
284
|
+
Logger.debug(f'last loss value: {self.loss_list[-1]}')
|
|
287
285
|
|
|
288
286
|
def update_graph(self):
|
|
289
287
|
"""
|
|
@@ -300,7 +298,7 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
|
300
298
|
if len(node) == 0 and isinstance(layer.layer, TensorFlowOpLayer):
|
|
301
299
|
node = graph.find_node_by_name('_'.join(layer.layer.name.split('_')[3:]))
|
|
302
300
|
if len(node) != 1:
|
|
303
|
-
|
|
301
|
+
Logger.error(f"Can't update GPTQ graph due to missing layer named: {layer.layer.name}")
|
|
304
302
|
node = node[0]
|
|
305
303
|
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
|
|
306
304
|
fw_info=self.fw_info)
|
|
@@ -319,18 +317,3 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
|
319
317
|
node.set_weights_by_keys(BIAS, new_bias)
|
|
320
318
|
|
|
321
319
|
return graph
|
|
322
|
-
|
|
323
|
-
def _get_quantizer_regularization_values(self, rounding_type: RoundingType) -> List[tf.Tensor]:
|
|
324
|
-
"""
|
|
325
|
-
Mapping between a rounding type to its matching regularization method.
|
|
326
|
-
|
|
327
|
-
Args:
|
|
328
|
-
rounding_type: GPTQ rounding type.
|
|
329
|
-
|
|
330
|
-
Returns: A regularization computation method.
|
|
331
|
-
|
|
332
|
-
"""
|
|
333
|
-
if rounding_type == RoundingType.SoftQuantizer:
|
|
334
|
-
return get_soft_rounding_reg(self.fxp_model)
|
|
335
|
-
else:
|
|
336
|
-
return []
|
|
@@ -13,23 +13,21 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
|
|
17
16
|
import tensorflow as tf
|
|
18
17
|
from typing import Tuple, List
|
|
19
|
-
|
|
20
18
|
from model_compression_toolkit.core.keras.constants import USE_BIAS
|
|
21
19
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
22
20
|
from tensorflow.keras.models import Model
|
|
23
|
-
|
|
24
21
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
25
22
|
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
|
|
26
23
|
from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
|
|
24
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
27
25
|
|
|
28
26
|
|
|
29
27
|
def get_gptq_trainable_parameters(fxp_model: Model,
|
|
30
28
|
fw_info: FrameworkInfo,
|
|
31
29
|
add_bias: bool = False) -> (
|
|
32
|
-
List[tf.Variable], List[tf.Variable], List[tf.Variable]
|
|
30
|
+
List[tf.Variable], List[tf.Variable], List[tf.Variable]):
|
|
33
31
|
"""
|
|
34
32
|
Get trainable parameters from all layers in a model
|
|
35
33
|
|
|
@@ -45,16 +43,17 @@ def get_gptq_trainable_parameters(fxp_model: Model,
|
|
|
45
43
|
trainable_weights: List[tf.Tensor] = []
|
|
46
44
|
trainable_threshold: List[tf.Tensor] = []
|
|
47
45
|
bias_weights: List[List[tf.Tensor]] = []
|
|
48
|
-
temperature_weights: List[tf.Tensor] = []
|
|
49
46
|
|
|
50
47
|
for layer in fxp_model.layers:
|
|
51
48
|
if isinstance(layer, KerasQuantizationWrapper):
|
|
52
49
|
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
|
|
53
50
|
fw_info=DEFAULT_KERAS_INFO)
|
|
54
51
|
|
|
55
|
-
# collect trainable weights per
|
|
56
|
-
|
|
57
|
-
|
|
52
|
+
# collect trainable weights per quantizer
|
|
53
|
+
quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
|
|
54
|
+
quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
|
|
55
|
+
trainable_weights.append(quantizer_trainable_weights)
|
|
56
|
+
trainable_threshold.extend(quantizer_trainable_threshold)
|
|
58
57
|
|
|
59
58
|
if add_bias:
|
|
60
59
|
kernel_ops_attrs = fw_info.kernel_ops_attributes_mapping.get(type(layer.layer))
|
|
@@ -62,10 +61,8 @@ def get_gptq_trainable_parameters(fxp_model: Model,
|
|
|
62
61
|
and layer.layer.get_config().get(USE_BIAS)
|
|
63
62
|
if use_bias is not None and use_bias:
|
|
64
63
|
bias_weights.append([layer.layer.bias])
|
|
65
|
-
trainable_weights.append(layer_trainable_weights)
|
|
66
|
-
trainable_threshold.extend(layer_trainable_threshold)
|
|
67
64
|
|
|
68
|
-
return trainable_weights, bias_weights, trainable_threshold
|
|
65
|
+
return trainable_weights, bias_weights, trainable_threshold
|
|
69
66
|
|
|
70
67
|
|
|
71
68
|
def get_weights_for_loss(fxp_model: Model) -> Tuple[List[list], List[list]]:
|
|
@@ -95,25 +92,3 @@ def get_weights_for_loss(fxp_model: Model) -> Tuple[List[list], List[list]]:
|
|
|
95
92
|
fxp_weights_list.append(_layer_fxp_weights)
|
|
96
93
|
|
|
97
94
|
return flp_weights_list, fxp_weights_list
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
# TODO: this function need to move to location that is relevant only for soft quantizer -
|
|
101
|
-
# once deciding how to handle GPTQ quantizers regularization.
|
|
102
|
-
def get_soft_rounding_reg(fxp_model: Model) -> List[tf.Tensor]:
|
|
103
|
-
"""
|
|
104
|
-
This function returns the soft quantizer regularization values for SoftRounding.
|
|
105
|
-
|
|
106
|
-
Args:
|
|
107
|
-
fxp_model: A model to be quantized with SoftRounding.
|
|
108
|
-
|
|
109
|
-
Returns: A list of tensors.
|
|
110
|
-
"""
|
|
111
|
-
|
|
112
|
-
soft_reg_aux: List[tf.Tensor] = []
|
|
113
|
-
for layer in fxp_model.layers:
|
|
114
|
-
if isinstance(layer, KerasQuantizationWrapper):
|
|
115
|
-
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
|
|
116
|
-
fw_info=DEFAULT_KERAS_INFO)
|
|
117
|
-
|
|
118
|
-
soft_reg_aux.append(layer.weights_quantizers[kernel_attribute].get_regularization())
|
|
119
|
-
return soft_reg_aux
|
|
@@ -16,21 +16,19 @@
|
|
|
16
16
|
from typing import Callable, Tuple
|
|
17
17
|
from packaging import version
|
|
18
18
|
|
|
19
|
-
from model_compression_toolkit.
|
|
20
|
-
from model_compression_toolkit.
|
|
21
|
-
from model_compression_toolkit.core.common.constants import TENSORFLOW
|
|
19
|
+
from model_compression_toolkit.logger import Logger
|
|
20
|
+
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
|
|
22
21
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
|
23
22
|
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
|
|
24
23
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
25
24
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
26
|
-
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import
|
|
27
|
-
|
|
28
|
-
from model_compression_toolkit import CoreConfig
|
|
25
|
+
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
|
|
26
|
+
from model_compression_toolkit.core import CoreConfig
|
|
29
27
|
from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
|
|
30
28
|
from model_compression_toolkit.gptq.runner import gptq_runner
|
|
31
29
|
from model_compression_toolkit.core.exporter import export_model
|
|
32
30
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
|
33
|
-
from model_compression_toolkit.
|
|
31
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
|
34
32
|
|
|
35
33
|
LR_DEFAULT = 0.15
|
|
36
34
|
LR_REST_DEFAULT = 1e-4
|
|
@@ -38,14 +36,14 @@ LR_BIAS_DEFAULT = 1e-4
|
|
|
38
36
|
LR_QUANTIZATION_PARAM_DEFAULT = 1e-3
|
|
39
37
|
GPTQ_MOMENTUM = 0.9
|
|
40
38
|
|
|
41
|
-
if
|
|
39
|
+
if FOUND_TF:
|
|
42
40
|
import tensorflow as tf
|
|
43
41
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
44
|
-
from model_compression_toolkit.
|
|
42
|
+
from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
|
|
45
43
|
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
|
|
46
44
|
from tensorflow.keras.models import Model
|
|
47
45
|
from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss
|
|
48
|
-
from model_compression_toolkit.
|
|
46
|
+
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
|
49
47
|
from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
|
|
50
48
|
from model_compression_toolkit import get_target_platform_capabilities
|
|
51
49
|
|
|
@@ -62,7 +60,8 @@ if common.constants.FOUND_TF:
|
|
|
62
60
|
optimizer: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=LR_DEFAULT),
|
|
63
61
|
optimizer_rest: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=LR_REST_DEFAULT),
|
|
64
62
|
loss: Callable = GPTQMultipleTensorsLoss(),
|
|
65
|
-
log_function: Callable = None
|
|
63
|
+
log_function: Callable = None,
|
|
64
|
+
use_hessian_based_weights: bool = True) -> GradientPTQConfigV2:
|
|
66
65
|
"""
|
|
67
66
|
Create a GradientPTQConfigV2 instance for Keras models.
|
|
68
67
|
|
|
@@ -72,6 +71,7 @@ if common.constants.FOUND_TF:
|
|
|
72
71
|
optimizer_rest (OptimizerV2): Keras optimizer to use for fine-tuning of the bias variable.
|
|
73
72
|
loss (Callable): loss to use during fine-tuning. should accept 4 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors, the 3rd is a list of quantized weights and the 4th is a list of float weights.
|
|
74
73
|
log_function (Callable): Function to log information about the gptq process.
|
|
74
|
+
use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
|
|
75
75
|
|
|
76
76
|
returns:
|
|
77
77
|
a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq.
|
|
@@ -85,24 +85,25 @@ if common.constants.FOUND_TF:
|
|
|
85
85
|
|
|
86
86
|
Create a GradientPTQConfigV2 to run for 5 epochs:
|
|
87
87
|
|
|
88
|
-
>>> gptq_conf = mct.get_keras_gptq_config(n_epochs=5)
|
|
88
|
+
>>> gptq_conf = mct.gptq.get_keras_gptq_config(n_epochs=5)
|
|
89
89
|
|
|
90
90
|
Other Tensorflow optimizers can be passed:
|
|
91
91
|
|
|
92
|
-
>>> gptq_conf = mct.get_keras_gptq_config(n_epochs=3, optimizer=tf.keras.optimizers.Nadam())
|
|
92
|
+
>>> gptq_conf = mct.gptq.get_keras_gptq_config(n_epochs=3, optimizer=tf.keras.optimizers.Nadam())
|
|
93
93
|
|
|
94
94
|
The configuration can be passed to :func:`~model_compression_toolkit.keras_post_training_quantization` in order to quantize a keras model using gptq.
|
|
95
95
|
|
|
96
96
|
"""
|
|
97
|
-
bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
|
|
97
|
+
bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
|
|
98
|
+
momentum=GPTQ_MOMENTUM)
|
|
98
99
|
return GradientPTQConfigV2(n_epochs,
|
|
99
100
|
optimizer,
|
|
100
101
|
optimizer_rest=optimizer_rest,
|
|
101
102
|
loss=loss,
|
|
102
103
|
log_function=log_function,
|
|
103
104
|
train_bias=True,
|
|
104
|
-
|
|
105
|
-
|
|
105
|
+
optimizer_bias=bias_optimizer,
|
|
106
|
+
use_hessian_based_weights=use_hessian_based_weights)
|
|
106
107
|
|
|
107
108
|
|
|
108
109
|
def keras_gradient_post_training_quantization_experimental(in_model: Model,
|
|
@@ -164,28 +165,28 @@ if common.constants.FOUND_TF:
|
|
|
164
165
|
|
|
165
166
|
Create an MCT core config, containing the quantization configuration:
|
|
166
167
|
|
|
167
|
-
>>> config = mct.CoreConfig()
|
|
168
|
+
>>> config = mct.core.CoreConfig()
|
|
168
169
|
|
|
169
170
|
If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model
|
|
170
171
|
with different bitwidths for different layers.
|
|
171
172
|
The candidates bitwidth for quantization should be defined in the target platform model:
|
|
172
173
|
|
|
173
|
-
>>> config = mct.CoreConfig(mixed_precision_config=mct.MixedPrecisionQuantizationConfigV2(num_of_images=1))
|
|
174
|
+
>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1))
|
|
174
175
|
|
|
175
176
|
For mixed-precision set a target KPI object:
|
|
176
177
|
Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
|
|
177
178
|
that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
|
|
178
179
|
while the bias will not):
|
|
179
180
|
|
|
180
|
-
>>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
|
|
181
|
+
>>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
|
|
181
182
|
|
|
182
183
|
Create GPTQ config:
|
|
183
184
|
|
|
184
|
-
>>> gptq_config = mct.get_keras_gptq_config(n_epochs=1)
|
|
185
|
+
>>> gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=1)
|
|
185
186
|
|
|
186
187
|
Pass the model with the representative dataset generator to get a quantized model:
|
|
187
188
|
|
|
188
|
-
>>> quantized_model, quantization_info = mct.keras_gradient_post_training_quantization_experimental(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config)
|
|
189
|
+
>>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization_experimental(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config)
|
|
189
190
|
|
|
190
191
|
"""
|
|
191
192
|
KerasModelValidation(model=in_model,
|
|
@@ -193,15 +194,15 @@ if common.constants.FOUND_TF:
|
|
|
193
194
|
|
|
194
195
|
if core_config.mixed_precision_enable:
|
|
195
196
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
|
|
196
|
-
|
|
197
|
+
Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
197
198
|
"MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
|
|
198
199
|
"API, or pass a valid mixed precision configuration.") # pragma: no cover
|
|
199
200
|
|
|
200
|
-
|
|
201
|
+
Logger.info("Using experimental mixed-precision quantization. "
|
|
201
202
|
"If you encounter an issue please file a bug.")
|
|
202
203
|
tb_w = _init_tensorboard_writer(fw_info)
|
|
203
204
|
|
|
204
|
-
fw_impl =
|
|
205
|
+
fw_impl = GPTQKerasImplemantation()
|
|
205
206
|
|
|
206
207
|
tg, bit_widths_config = core_runner(in_model=in_model,
|
|
207
208
|
representative_data_gen=representative_data_gen,
|
|
@@ -15,3 +15,4 @@
|
|
|
15
15
|
|
|
16
16
|
import model_compression_toolkit.gptq.keras.quantizer.ste_rounding.symmetric_ste
|
|
17
17
|
import model_compression_toolkit.gptq.keras.quantizer.soft_rounding.symmetric_soft_quantizer
|
|
18
|
+
import model_compression_toolkit.gptq.keras.quantizer.soft_rounding.uniform_soft_quantizer
|
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
from abc import abstractmethod
|
|
16
16
|
from typing import Union, Dict, List
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
|
+
from model_compression_toolkit.constants import FOUND_TF
|
|
20
20
|
from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
|
|
21
21
|
|
|
22
22
|
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
|
|
@@ -45,7 +45,6 @@ if FOUND_TF:
|
|
|
45
45
|
|
|
46
46
|
super().__init__(quantization_config)
|
|
47
47
|
|
|
48
|
-
self.quantizer_parameters = None
|
|
49
48
|
|
|
50
49
|
def update_layer_quantization_params(self, layer: KerasQuantizationWrapper
|
|
51
50
|
) -> (Dict[str, tf.Tensor], Dict[str, Dict], Dict):
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
import tensorflow as tf
|
|
17
|
-
from model_compression_toolkit.
|
|
17
|
+
from model_compression_toolkit.constants import MIN_THRESHOLD
|
|
18
18
|
from typing import Tuple
|
|
19
19
|
|
|
20
20
|
|
|
@@ -26,6 +26,14 @@ def ste_ceil(x: tf.Tensor) -> tf.Tensor:
|
|
|
26
26
|
return error + x
|
|
27
27
|
|
|
28
28
|
|
|
29
|
+
def ste_floor(x: tf.Tensor) -> tf.Tensor:
|
|
30
|
+
"""
|
|
31
|
+
Return the floor values of a tensor.
|
|
32
|
+
"""
|
|
33
|
+
error = tf.stop_gradient(tf.math.floor(x) - x)
|
|
34
|
+
return error + x
|
|
35
|
+
|
|
36
|
+
|
|
29
37
|
def safe_log(x: tf.Tensor, eps: float) -> tf.Tensor:
|
|
30
38
|
"""
|
|
31
39
|
Computes log function of x unless x is smaller than some small value, so the log function would not fail.
|
|
@@ -72,6 +80,15 @@ def calculate_delta(max_tensor: tf.Tensor,
|
|
|
72
80
|
return max_tensor / (2 ** (num_bits - int(signed)))
|
|
73
81
|
|
|
74
82
|
|
|
83
|
+
def calculate_delta_uniform(min_tensor: tf.Tensor,
|
|
84
|
+
max_tensor: tf.Tensor,
|
|
85
|
+
num_bits: int) -> tf.Tensor:
|
|
86
|
+
"""
|
|
87
|
+
Compute the step size for the uniform quantization.
|
|
88
|
+
"""
|
|
89
|
+
return (max_tensor-min_tensor) / (2 ** num_bits - 1)
|
|
90
|
+
|
|
91
|
+
|
|
75
92
|
def ste_clip(x: [tf.Tensor, tf.Variable], max_val=1, min_val=None) -> tf.Tensor:
|
|
76
93
|
"""
|
|
77
94
|
clip a variable between fixed values such that min_val<=output<=max_val
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Dict, List, Tuple
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit import GradientPTQConfigV2
|
|
17
|
+
from model_compression_toolkit.gptq import GradientPTQConfigV2
|
|
18
18
|
from model_compression_toolkit.core import common
|
|
19
19
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
20
20
|
from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
|
|
@@ -61,7 +61,7 @@ def quantization_builder(n: common.BaseNode,
|
|
|
61
61
|
fw_info=DEFAULT_KERAS_INFO)
|
|
62
62
|
|
|
63
63
|
weights_quantizers.update({kernel_attribute: quantizer_class(get_trainable_quantizer_weights_config(n),
|
|
64
|
-
**gptq_config.
|
|
64
|
+
**gptq_config.gptq_quantizer_params_override)})
|
|
65
65
|
|
|
66
66
|
activation_quantizers = []
|
|
67
67
|
if n.is_activation_quantization_enabled():
|