mct-nightly 1.8.0.8042023.post345__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +4 -3
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +285 -277
- model_compression_toolkit/__init__.py +9 -32
- model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
- model_compression_toolkit/core/__init__.py +14 -0
- model_compression_toolkit/core/analyzer.py +3 -2
- model_compression_toolkit/core/common/__init__.py +0 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +1 -8
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
- model_compression_toolkit/core/common/graph/base_graph.py +2 -2
- model_compression_toolkit/core/common/graph/base_node.py +57 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/memory_computation.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/model_validation.py +1 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -1
- model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
- model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
- model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +2 -2
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/keras/constants.py +0 -7
- model_compression_toolkit/core/keras/default_framework_info.py +3 -3
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -10
- model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
- model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
- model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +1 -1
- model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +2 -2
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +0 -6
- model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
- model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
- model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/runner.py +7 -7
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +3 -2
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +5 -3
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +4 -3
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
- model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +2 -1
- model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +5 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +27 -20
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -5
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +2 -2
- model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
- model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -13
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +8 -3
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -2
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +9 -11
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +2 -2
- model_compression_toolkit/gptq/runner.py +3 -2
- model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +11 -12
- model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
- model_compression_toolkit/ptq/__init__.py +3 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/qat/__init__.py +4 -0
- model_compression_toolkit/qat/common/__init__.py +1 -2
- model_compression_toolkit/qat/common/qat_config.py +3 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +13 -11
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +9 -9
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +4 -3
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +7 -5
- model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -6
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
- model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
- model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
- model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +0 -0
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
- /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
- /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
model_compression_toolkit/gptq/keras/quantization_facade.py

@@ -16,21 +16,19 @@
 from typing import Callable, Tuple
 from packaging import version
 
-from model_compression_toolkit.
-from model_compression_toolkit.
-from model_compression_toolkit.core.common.constants import TENSORFLOW
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import
-
-from model_compression_toolkit import CoreConfig
+from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
 from model_compression_toolkit.gptq.runner import gptq_runner
 from model_compression_toolkit.core.exporter import export_model
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization
-from model_compression_toolkit.
+from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 
 LR_DEFAULT = 0.15
 LR_REST_DEFAULT = 1e-4
@@ -38,14 +36,14 @@ LR_BIAS_DEFAULT = 1e-4
 LR_QUANTIZATION_PARAM_DEFAULT = 1e-3
 GPTQ_MOMENTUM = 0.9
 
-if common.constants.FOUND_TF:
+if FOUND_TF:
     import tensorflow as tf
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
-    from model_compression_toolkit.
+    from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
    from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from tensorflow.keras.models import Model
     from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss
-    from model_compression_toolkit.
+    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
     from model_compression_toolkit import get_target_platform_capabilities
 
@@ -62,7 +60,8 @@ if common.constants.FOUND_TF:
                               optimizer: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=LR_DEFAULT),
                               optimizer_rest: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=LR_REST_DEFAULT),
                               loss: Callable = GPTQMultipleTensorsLoss(),
-                              log_function: Callable = None
+                              log_function: Callable = None,
+                              use_hessian_based_weights: bool = True) -> GradientPTQConfigV2:
         """
         Create a GradientPTQConfigV2 instance for Keras models.
 
@@ -72,6 +71,7 @@ if common.constants.FOUND_TF:
             optimizer_rest (OptimizerV2): Keras optimizer to use for fine-tuning of the bias variable.
             loss (Callable): loss to use during fine-tuning. should accept 4 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors, the 3rd is a list of quantized weights and the 4th is a list of float weights.
             log_function (Callable): Function to log information about the gptq process.
+            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
 
         returns:
             a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq.
@@ -94,9 +94,16 @@ if common.constants.FOUND_TF:
         The configuration can be passed to :func:`~model_compression_toolkit.keras_post_training_quantization` in order to quantize a keras model using gptq.
 
         """
-        bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
-
-
+        bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
+                                                 momentum=GPTQ_MOMENTUM)
+        return GradientPTQConfigV2(n_epochs,
+                                   optimizer,
+                                   optimizer_rest=optimizer_rest,
+                                   loss=loss,
+                                   log_function=log_function,
+                                   train_bias=True,
+                                   optimizer_bias=bias_optimizer,
+                                   use_hessian_based_weights=use_hessian_based_weights)
 
 
     def keras_gradient_post_training_quantization_experimental(in_model: Model,
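
Note: the net effect of the hunks above is that the Keras GPTQ config factory gains a `use_hessian_based_weights` flag (default `True`) threaded into the returned `GradientPTQConfigV2`. A minimal usage sketch, assuming the factory is the package's `get_keras_gptq_config` facade re-exported at top level (verify against this release's `__init__.py`):

    import model_compression_toolkit as mct

    # Default: Hessian-based weighting of the per-tensor comparison loss.
    gptq_cfg = mct.get_keras_gptq_config(n_epochs=50)

    # Opt out to fall back to an unweighted average over the compared tensors.
    gptq_cfg_plain = mct.get_keras_gptq_config(n_epochs=50,
                                               use_hessian_based_weights=False)
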
@@ -158,20 +165,20 @@ if common.constants.FOUND_TF:
 
         Create an MCT core config, containing the quantization configuration:
 
-        >>> config = mct.CoreConfig()
+        >>> config = mct.core.CoreConfig()
 
         If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model
         with different bitwidths for different layers.
         The candidates bitwidth for quantization should be defined in the target platform model:
 
-        >>> config = mct.CoreConfig(mixed_precision_config=mct.MixedPrecisionQuantizationConfigV2(num_of_images=1))
+        >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1))
 
         For mixed-precision set a target KPI object:
         Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
         that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
         while the bias will not):
 
-        >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
+        >>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
 
         Create GPTQ config:
 
@@ -187,15 +194,15 @@ if common.constants.FOUND_TF:
 
         if core_config.mixed_precision_enable:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
-
+                Logger.error("Given quantization config to mixed-precision facade is not of type "
                              "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
                              "API, or pass a valid mixed precision configuration.")  # pragma: no cover
 
-
+            Logger.info("Using experimental mixed-precision quantization. "
                         "If you encounter an issue please file a bug.")
         tb_w = _init_tensorboard_writer(fw_info)
 
-        fw_impl =
+        fw_impl = GPTQKerasImplemantation()
 
         tg, bit_widths_config = core_runner(in_model=in_model,
                                             representative_data_gen=representative_data_gen,
model_compression_toolkit/gptq/keras/quantizer/__init__.py

@@ -15,3 +15,4 @@
 
 import model_compression_toolkit.gptq.keras.quantizer.ste_rounding.symmetric_ste
 import model_compression_toolkit.gptq.keras.quantizer.soft_rounding.symmetric_soft_quantizer
+import model_compression_toolkit.gptq.keras.quantizer.soft_rounding.uniform_soft_quantizer
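
Note: this `__init__.py` import exists for its side effect. Loading `uniform_soft_quantizer` executes the `@mark_quantizer` decorator (visible in the new file further down), which is what makes the class discoverable by quantization target, method, and rounding type. A stripped-down sketch of the registration-by-import pattern, with hypothetical names (`_REGISTRY`, simplified decorator arguments):

    # Hypothetical miniature of decorator-based quantizer registration.
    _REGISTRY = {}

    def mark_quantizer(method, rounding):
        def wrap(cls):
            _REGISTRY[(method, rounding)] = cls  # registration happens at import time
            return cls
        return wrap

    @mark_quantizer("UNIFORM", "SoftQuantizer")
    class UniformSoftRoundingGPTQ:
        pass

    # A later lookup resolves the class without referencing it directly:
    quantizer_cls = _REGISTRY[("UNIFORM", "SoftQuantizer")]
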
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py

@@ -15,8 +15,8 @@
 from abc import abstractmethod
 from typing import Union, Dict, List
 
-from model_compression_toolkit.
-from model_compression_toolkit.
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import FOUND_TF
 from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
 
 from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py

@@ -14,7 +14,7 @@
 # ==============================================================================
 
 import tensorflow as tf
-from model_compression_toolkit.
+from model_compression_toolkit.constants import MIN_THRESHOLD
 from typing import Tuple
 
 
@@ -26,6 +26,14 @@ def ste_ceil(x: tf.Tensor) -> tf.Tensor:
     return error + x
 
 
+def ste_floor(x: tf.Tensor) -> tf.Tensor:
+    """
+    Return the floor values of a tensor.
+    """
+    error = tf.stop_gradient(tf.math.floor(x) - x)
+    return error + x
+
+
 def safe_log(x: tf.Tensor, eps: float) -> tf.Tensor:
     """
     Computes log function of x unless x is smaller than some small value, so the log function would not fail.
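
Note: `ste_floor` uses the same straight-through-estimator (STE) trick as the existing `ste_ceil`: numerically, `stop_gradient(floor(x) - x) + x == floor(x)`, but for gradients only the trailing `x` is visible, so the backward pass treats the non-differentiable floor as the identity. A standalone check of that behavior (not from the diff):

    import tensorflow as tf

    def ste_floor(x):
        return tf.stop_gradient(tf.math.floor(x) - x) + x

    x = tf.Variable([0.3, 1.7])
    with tf.GradientTape() as tape:
        y = tf.reduce_sum(ste_floor(x))
    print(y.numpy())                     # 1.0 = floor(0.3) + floor(1.7)
    print(tape.gradient(y, x).numpy())   # [1. 1.], the gradient of the identity
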
@@ -72,6 +80,15 @@ def calculate_delta(max_tensor: tf.Tensor,
     return max_tensor / (2 ** (num_bits - int(signed)))
 
 
+def calculate_delta_uniform(min_tensor: tf.Tensor,
+                            max_tensor: tf.Tensor,
+                            num_bits: int) -> tf.Tensor:
+    """
+    Compute the step size for the uniform quantization.
+    """
+    return (max_tensor - min_tensor) / (2 ** num_bits - 1)
+
+
 def ste_clip(x: [tf.Tensor, tf.Variable], max_val=1, min_val=None) -> tf.Tensor:
     """
     clip a variable between fixed values such that min_val<=output<=max_val
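
Note: `calculate_delta_uniform` divides the range [min, max] into 2**num_bits - 1 equal steps, so all 2**num_bits grid points, including both endpoints, are representable. A worked example (not from the diff):

    # Uniform step size for an 8-bit quantizer over [-1, 1].
    num_bits = 8
    mn, mx = -1.0, 1.0
    delta = (mx - mn) / (2 ** num_bits - 1)   # 2/255, about 0.00784
    w = 0.5
    q = round((w - mn) / delta) * delta + mn  # about 0.49804, nearest of 256 grid points
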
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py

@@ -77,7 +77,7 @@ class SoftQuantizerRegularization:
         # Initializing the temperature decay according to the number of expected gradient steps
         self.linear_decay = LinearTempDecay(total_gradient_steps)
 
-        self.count_iter = 0
+        self.count_iter = tf.Variable(0.)
 
 
     def __call__(self, model: Model, entropy_reg: float):
@@ -90,16 +90,14 @@ class SoftQuantizerRegularization:
 
         Returns: Regularization value.
         """
-
         soft_reg_aux: List[tf.Tensor] = []
+        b = self.linear_decay(self.count_iter.value())
         for layer in model.layers:
             if isinstance(layer, KerasQuantizationWrapper):
                 kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
                                                                       fw_info=DEFAULT_KERAS_INFO)
 
                 st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
-                b = self.linear_decay(self.count_iter)
-
                 soft_reg_aux.append(tf.reduce_sum(1 - tf.pow(tf.math.abs(st - .5) * 2, b)))
 
         reg = 0
@@ -107,6 +105,6 @@ class SoftQuantizerRegularization:
         for sq in soft_reg_aux:
             reg += sq
 
-        self.count_iter
+        self.count_iter.assign_add(1.0)
 
         return entropy_reg * reg
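
Note: promoting `count_iter` from a Python int to a `tf.Variable` (bumped via `assign_add`) is the kind of change needed when this code runs inside a `tf.function`-compiled training step: a captured Python int is frozen into the traced graph at its first value, while a `tf.Variable` is read at run time, so the temperature decay actually advances. A standalone sketch of the difference (hypothetical, not from the package):

    import tensorflow as tf

    py_count = 0                 # captured as a constant at trace time
    tf_count = tf.Variable(0.)   # read fresh on every executed step

    @tf.function
    def step():
        tf_count.assign_add(1.0)
        return tf.constant(float(py_count)), tf_count.value()

    for _ in range(3):
        frozen, live = step()
    print(frozen.numpy(), live.numpy())  # 0.0 3.0
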
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py

@@ -19,12 +19,12 @@ import numpy as np
 from model_compression_toolkit.gptq import RoundingType
 from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core.common import max_power_of_two
-from model_compression_toolkit.
+from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
     SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
 from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
 from typing import Dict, Any
-from model_compression_toolkit.
+from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
 from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
 from model_compression_toolkit.gptq.keras.quantizer.quant_utils import power_of_two_max, clip, calculate_delta
 from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py (new file)

@@ -0,0 +1,224 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import tensorflow as tf
+import numpy as np
+
+from model_compression_toolkit.gptq import RoundingType
+from model_compression_toolkit import quantizers_infrastructure as qi
+from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
+from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from model_compression_toolkit.gptq.common.gptq_constants import \
+    SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
+from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
+from typing import Dict, Any
+from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
+from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
+from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+    get_threshold_reshape_shape
+from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+
+
+def soft_rounding_uniform_quantizer(input_tensor: tf.Tensor,
+                                    auxvar_tensor: tf.Variable,
+                                    min_tensor: tf.Tensor,
+                                    max_tensor: tf.Tensor,
+                                    num_bits: int) -> tf.Tensor:
+    """
+    Quantize a tensor uniformly for GPTQ quantizers.
+
+    Args:
+        input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
+        auxvar_tensor: Tensor that manifests the bit shift of the quantized weights due to gptq training.
+        min_tensor: Tensor with values to compute the min threshold.
+        max_tensor: Tensor with values to compute the max threshold.
+        num_bits: Num of bits to use.
+
+    Returns:
+        A quantized tensor.
+    """
+    # adjusts the quantization range so the quantization grid includes zero.
+    min_range, max_range = qutils.fix_range_to_include_zero(min_tensor, max_tensor, num_bits)
+    delta = qutils.calculate_delta_uniform(min_range, max_range, num_bits)
+    input_tensor_int = qutils.ste_floor((input_tensor - min_range) / delta)
+    tensor_q = input_tensor_int + auxvar_tensor
+    return delta * qutils.ste_clip(tensor_q,
+                                   min_val=0,
+                                   max_val=2 ** num_bits - 1) + min_range
+
+
+@mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+                quantization_method=[QuantizationMethod.UNIFORM],
+                quantizer_type=RoundingType.SoftQuantizer)
+class UniformSoftRoundingGPTQ(BaseKerasGPTQTrainableQuantizer):
+    """
+    Trainable uniform quantizer to optimize the rounding of the quantized values using a soft quantization method.
+    """
+
+    def __init__(self,
+                 quantization_config: TrainableQuantizerWeightsConfig,
+                 quantization_parameter_learning: bool = False):
+        """
+        Initialize a UniformSoftRoundingGPTQ object with parameters to use
+        for the quantization.
+
+        Args:
+            quantization_config: Trainable weight quantizer config.
+            quantization_parameter_learning: Whether to train the quantization threshold.
+        """
+        super().__init__(quantization_config)
+        self.num_bits = quantization_config.weights_n_bits
+        self.per_channel = quantization_config.weights_per_channel_threshold
+
+        self.min_values = quantization_config.weights_quantization_params[RANGE_MIN]
+        self.max_values = quantization_config.weights_quantization_params[RANGE_MAX]
+
+        self.quantization_axis = quantization_config.weights_channels_axis
+        assert quantization_parameter_learning is False, \
+            "Quantization parameters learning in UniformSoftRoundingGPTQ not implemented yet"
+        self.quantization_parameter_learning = quantization_parameter_learning
+        self.num_channels = self.min_values.shape[self.quantization_axis] if self.per_channel else 1
+
+        # gamma and zeta are stretch parameters for computing the rectified sigmoid function.
+        # See: https://arxiv.org/pdf/2004.10568.pdf
+        self.gamma = SOFT_ROUNDING_GAMMA
+        self.zeta = SOFT_ROUNDING_ZETA
+
+    def initialize_quantization(self,
+                                tensor_shape: Any,
+                                name: str,
+                                layer: Any):
+        """
+        Add quantizer parameters to the quantizer parameters dictionary
+
+        Args:
+            tensor_shape: tensor shape of the quantized tensor.
+            name: Tensor name.
+            layer: Layer to quantize.
+        """
+
+        if self.per_channel:
+            reshape_shape = get_threshold_reshape_shape(tensor_shape,
+                                                        quant_axis=self.quantization_axis,
+                                                        quant_axis_dim=self.num_channels)
+        else:
+            reshape_shape = [self.num_channels]
+
+        min_tensor = layer.add_weight(
+            f"{name}_{FQ_MIN}",
+            shape=reshape_shape,
+            initializer=tf.keras.initializers.Constant(1.0),
+            trainable=False)
+        min_tensor.assign(self.min_values.reshape(reshape_shape))
+
+        max_tensor = layer.add_weight(
+            f"{name}_{FQ_MAX}",
+            shape=reshape_shape,
+            initializer=tf.keras.initializers.Constant(1.0),
+            trainable=False)
+        max_tensor.assign(self.max_values.reshape(reshape_shape))
+
+        w = getattr(layer.layer, name)
+        auxvar_tensor = layer.add_weight(
+            f"{name}_{AUXVAR}",
+            shape=list(w.shape),
+            initializer=tf.keras.initializers.Constant(0.0),
+            trainable=True)
+
+        w = layer.layer.depthwise_kernel if isinstance(layer.layer, (tf.keras.layers.DepthwiseConv2D,
+                                                                     tf.keras.layers.DepthwiseConv1D)) \
+            else layer.layer.kernel
+        delta = qutils.calculate_delta_uniform(min_tensor, max_tensor, self.num_bits)
+        w_clipped_normed = qutils.clip((w - min_tensor) / delta, 0, 2 ** self.num_bits - 1)
+        rest = w_clipped_normed - tf.floor(w_clipped_normed)  # rest of rounding [0, 1)
+        alpha = -qutils.safe_log((self.zeta - self.gamma) / (rest - self.gamma) - 1, 1e-16)  # => sigmoid(alpha) = rest
+        auxvar_tensor.assign(alpha)
+
+        # Add quantization variables
+        self.add_quantizer_variable(AUXVAR, auxvar_tensor, VariableGroup.WEIGHTS)
+        self.add_quantizer_variable(RANGE_MIN, min_tensor, VariableGroup.QPARAMS)
+        self.add_quantizer_variable(RANGE_MAX, max_tensor, VariableGroup.QPARAMS)
+
+    def get_soft_targets(self) -> tf.Tensor:
+        """
+        Computes the rectified sigmoid function for the quantization target parameters.
+
+        Returns:
+            A tensor with the soft rounding targets values.
+
+        """
+        return qutils.clip(
+            tf.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma, 1, 0)
+
+    def __call__(self,
+                 inputs: tf.Tensor,
+                 training: bool):
+        """
+        Quantize a tensor.
+
+        Args:
+            inputs: Input tensor to quantize.
+            training: Whether the graph is in training mode.
+
+        Returns:
+            The quantized tensor.
+        """
+
+        min_tensor = self.get_quantizer_variable(RANGE_MIN)
+        max_tensor = self.get_quantizer_variable(RANGE_MAX)
+
+        #####################################################
+        # Soft Rounding
+        #####################################################
+        aux_var = self.get_soft_targets()
+        if not training:
+            aux_var = tf.cast(tf.math.greater_equal(aux_var, 0.5), tf.float32)
+
+        if self.per_channel:
+            reshape_shape = get_threshold_reshape_shape(inputs.shape,
+                                                        quant_axis=self.quantization_axis,
+                                                        quant_axis_dim=-1)
+
+            #####################################################
+            # Quantized Input
+            #####################################################
+            q_tensor = soft_rounding_uniform_quantizer(input_tensor=inputs,
+                                                       auxvar_tensor=aux_var,
+                                                       min_tensor=tf.reshape(min_tensor, reshape_shape),
+                                                       max_tensor=tf.reshape(max_tensor, reshape_shape),
+                                                       num_bits=self.num_bits)
+
+        else:
+            q_tensor = soft_rounding_uniform_quantizer(input_tensor=inputs,
+                                                       auxvar_tensor=aux_var,
+                                                       min_tensor=min_tensor,
+                                                       max_tensor=max_tensor,
+                                                       num_bits=self.num_bits)
+
+        return q_tensor
+
+    def get_quant_config(self) -> Dict[str, np.ndarray]:
+        """
+        Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
+
+        Returns:
+            A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
+            Keys must match NodeQuantizationConfig attributes
+        """
+
+        return {RANGE_MIN: self.get_quantizer_variable(RANGE_MIN).numpy(),
+                RANGE_MAX: self.get_quantizer_variable(RANGE_MAX).numpy()}
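
Note: the new quantizer adapts AdaRound-style soft rounding (the paper linked in the code) to a uniform min/max grid. Each weight's rounding decision is a trainable logit (`AUXVAR`), squashed through the rectified sigmoid h(alpha) = clip(sigmoid(alpha) * (zeta - gamma) + gamma, 0, 1), and `ste_floor(w/delta) + h(alpha)` replaces hard rounding during training; at inference `h` is thresholded at 0.5. A numeric illustration of the rectified sigmoid, using the stretch constants from the AdaRound paper (gamma = -0.1, zeta = 1.1; the values in `gptq_constants` are assumed to match):

    import numpy as np

    gamma, zeta = -0.1, 1.1

    def rectified_sigmoid(alpha):
        s = 1.0 / (1.0 + np.exp(-alpha))
        return np.clip(s * (zeta - gamma) + gamma, 0.0, 1.0)

    # Unlike a plain sigmoid, this saturates exactly at 0 and 1, so at
    # convergence the soft decision becomes a hard 0/1 rounding offset.
    print(rectified_sigmoid(np.array([-10.0, 0.0, 10.0])))  # [0.  0.5 1. ]
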
model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py

@@ -20,10 +20,10 @@ import tensorflow as tf
 
 from model_compression_toolkit.gptq import RoundingType
 from model_compression_toolkit import quantizers_infrastructure as qi
-from model_compression_toolkit.
+from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD
 from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
-from model_compression_toolkit.
+from model_compression_toolkit.constants import THRESHOLD
 from model_compression_toolkit.core.common.defaultdict import DefaultDict
 from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
 from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py (new file)

@@ -0,0 +1,29 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Type
+
+from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
+from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
+from model_compression_toolkit.gptq.pytorch.gptq_training import PytorchGPTQTrainer
+
+
+class GPTQPytorchImplemantation(GPTQFrameworkImplemantation, PytorchImplementation):
+
+    def get_gptq_trainer_obj(self) -> Type[PytorchGPTQTrainer]:
+        """
+        Returns: Pytorch object of GPTQTrainer
+        """
+        return PytorchGPTQTrainer
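
Note: `GPTQKerasImplemantation` and `GPTQPytorchImplemantation` (the "Implemantation" spelling comes from the package itself) follow a mixin pattern: the abstract `GPTQFrameworkImplemantation` declares the GPTQ-specific hook `get_gptq_trainer_obj`, and each framework class combines it with the existing core implementation through multiple inheritance, so the GPTQ runner can treat `fw_impl` as both. A reduced sketch of the shape, with hypothetical stand-in classes:

    from abc import ABC, abstractmethod

    class FrameworkImplementation:             # stand-in for PytorchImplementation
        def to_tensor(self, x):
            return x

    class GPTQFrameworkImplemantation(ABC):    # GPTQ-only interface
        @abstractmethod
        def get_gptq_trainer_obj(self):
            """Return the framework-specific GPTQ trainer class."""

    class GPTQImpl(GPTQFrameworkImplemantation, FrameworkImplementation):
        def get_gptq_trainer_obj(self):
            return object                      # the real class returns PytorchGPTQTrainer

    fw_impl = GPTQImpl()   # usable anywhere a core implementation is expected
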
model_compression_toolkit/gptq/pytorch/gptq_training.py

@@ -19,7 +19,7 @@ from torch.nn import Module
 from tqdm import tqdm
 import copy
 import torch
-from model_compression_toolkit.
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
model_compression_toolkit/gptq/pytorch/quantization_facade.py

@@ -14,18 +14,18 @@
 # ==============================================================================
 from typing import Callable
 from model_compression_toolkit.core import common
-from model_compression_toolkit.
-from model_compression_toolkit.
-from model_compression_toolkit.
+from model_compression_toolkit.constants import FOUND_TORCH
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import PYTORCH
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
-from model_compression_toolkit.
+from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
 from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
 from model_compression_toolkit.gptq.runner import gptq_runner
 from model_compression_toolkit.core.exporter import export_model
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization
-from model_compression_toolkit import CoreConfig
+from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfigV2
 
@@ -36,8 +36,8 @@ LR_QUANTIZATION_PARAM_DEFAULT = 1e-4
 
 if FOUND_TORCH:
     from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
-    from model_compression_toolkit.
-    from model_compression_toolkit.
+    from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
+    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss
     from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
     import torch
@@ -118,7 +118,7 @@ if FOUND_TORCH:
         core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
         gptq_config (GradientPTQConfigV2): Configuration for using gptq (e.g. optimizer).
         gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
-        target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
+        target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
         new_experimental_exporter (bool): Whether exporting the quantized model using new exporter or not (in progress. Avoiding it for now is recommended).
 
     Returns:
@@ -142,7 +142,7 @@ if FOUND_TORCH:
 
     Create MCT core configurations with number of calibration iterations set to 1:
 
-    >>> config = mct.CoreConfig()
+    >>> config = mct.core.CoreConfig()
 
     Pass the module, the representative dataset generator and the configuration (optional) to get a quantized module
 
@@ -152,16 +152,16 @@ if FOUND_TORCH:
 
         if core_config.mixed_precision_enable:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
-
+                Logger.error("Given quantization config to mixed-precision facade is not of type "
                              "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
                              "API, or pass a valid mixed precision configuration.")  # pragma: no cover
 
-
+            Logger.info("Using experimental mixed-precision quantization. "
                         "If you encounter an issue please file a bug.")
 
         tb_w = _init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
 
-        fw_impl =
+        fw_impl = GPTQPytorchImplemantation()
 
         # ---------------------- #
         # Core Runner
@@ -192,7 +192,7 @@ if FOUND_TORCH:
             Logger.warning('Using new experimental exported models. '
                            'Please do not use unless you are familiar with what you are doing')
 
-            return
+            return get_exportable_pytorch_model(graph_gptq)
 
         return export_model(graph_gptq,
                             DEFAULT_PYTORCH_INFO,
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py

@@ -13,10 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 from abc import abstractmethod
-from typing import Union, Dict
+from typing import Union, Dict
 
-from model_compression_toolkit.
-from model_compression_toolkit.
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import FOUND_TORCH
 from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
 
 from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py

@@ -14,9 +14,7 @@
 # ==============================================================================
 from typing import Union, Tuple
 import torch
-from
-from model_compression_toolkit.core.common.constants import MIN_THRESHOLD
-from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
+from model_compression_toolkit.constants import MIN_THRESHOLD
 
 
 def power_of_two_max(max_tensor: torch.Tensor) -> torch.Tensor:
@@ -51,6 +49,13 @@ def ste_ceil(x: torch.Tensor) -> torch.Tensor:
     return (torch.ceil(x) - x).detach() + x
 
 
+def ste_floor(x: torch.Tensor) -> torch.Tensor:
+    """
+    Return the floor values of a tensor.
+    """
+    return (torch.floor(x) - x).detach() + x
+
+
 def ste_round(x: torch.Tensor) -> torch.Tensor:
     """
     Calculate the rounded values of a tensor
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py

@@ -95,14 +95,13 @@ class SoftQuantizerRegularization:
         """
 
         soft_reg_aux: List[torch.Tensor] = []
+        b = self.linear_decay(self.count_iter)
         for layer in model.modules():
             if isinstance(layer, PytorchQuantizationWrapper):
                 kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
                                                                       fw_info=DEFAULT_PYTORCH_INFO)
 
                 st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
-                b = self.linear_decay(self.count_iter)
-
                 soft_reg_aux.append((1 - torch.pow(torch.abs(st - .5) * 2, b)).sum())
 
         reg = 0
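
Note: the PyTorch regularizer receives the same restructuring as its Keras counterpart above: `b = self.linear_decay(self.count_iter)` depends only on the iteration counter, not on the layer, so it is hoisted out of the per-layer loop and computed once per call. The summed term `1 - (2|st - 0.5|)**b` is the AdaRound rounding regularizer; a standalone illustration of how annealing `b` downward sharpens the push toward binary targets:

    import torch

    h = torch.tensor([0.1, 0.5, 0.9])   # soft rounding targets in [0, 1]
    for b in (20.0, 2.0):               # temperature anneals downward during training
        per_weight = 1 - torch.pow(torch.abs(h - 0.5) * 2, b)
        # b=20: penalty is ~1 for anything not already at 0/1 (nearly flat early on);
        # b=2:  penalty decays smoothly toward 0/1, forcing targets to binarize.
        print(b, [round(v, 3) for v in per_weight.tolist()])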