mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +12 -41
- model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
- model_compression_toolkit/core/__init__.py +14 -0
- model_compression_toolkit/core/analyzer.py +3 -2
- model_compression_toolkit/core/common/__init__.py +0 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +1 -8
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
- model_compression_toolkit/core/common/graph/base_graph.py +2 -2
- model_compression_toolkit/core/common/graph/base_node.py +57 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/memory_computation.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/model_validation.py +1 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -1
- model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
- model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
- model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/keras/constants.py +0 -7
- model_compression_toolkit/core/keras/default_framework_info.py +3 -3
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -10
- model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
- model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
- model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +1 -1
- model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +4 -6
- model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
- model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
- model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/runner.py +7 -7
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -104
- model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
- model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +30 -39
- model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
- model_compression_toolkit/gptq/keras/graph_info.py +8 -33
- model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
- model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
- model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
- model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
- model_compression_toolkit/gptq/runner.py +3 -2
- model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
- model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
- model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
- model_compression_toolkit/ptq/__init__.py +3 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/qat/__init__.py +4 -0
- model_compression_toolkit/qat/common/__init__.py +1 -2
- model_compression_toolkit/qat/common/qat_config.py +3 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
- model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
- model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
- model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
- model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
- /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
|
@@ -13,23 +13,24 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from typing import Dict, Any
|
|
16
|
+
from typing import Dict, Any
|
|
17
17
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
import tensorflow as tf
|
|
20
20
|
|
|
21
|
-
from model_compression_toolkit import RoundingType
|
|
21
|
+
from model_compression_toolkit.gptq import RoundingType
|
|
22
22
|
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
23
|
-
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
24
|
-
from model_compression_toolkit.gptq.common.gptq_constants import
|
|
23
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
24
|
+
from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD
|
|
25
25
|
from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
|
|
26
|
-
from model_compression_toolkit.
|
|
26
|
+
from model_compression_toolkit.constants import THRESHOLD
|
|
27
27
|
from model_compression_toolkit.core.common.defaultdict import DefaultDict
|
|
28
28
|
from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer
|
|
29
29
|
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
|
|
30
30
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
|
|
31
31
|
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
|
|
32
32
|
get_threshold_reshape_shape
|
|
33
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
33
34
|
|
|
34
35
|
|
|
35
36
|
def pertubation_symmetric_quantizer(input_tensor: tf.Tensor,
|
|
@@ -96,30 +97,20 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
|
|
|
96
97
|
self.quantization_axis = quantization_config.weights_channels_axis
|
|
97
98
|
self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
|
|
98
99
|
self.max_lsbs_change = max_lsbs_change_map.get(self.num_bits)
|
|
99
|
-
self.quantizer_parameters = {}
|
|
100
100
|
|
|
101
101
|
def initialize_quantization(self,
|
|
102
102
|
tensor_shape: Any,
|
|
103
103
|
name: str,
|
|
104
|
-
layer: Any)
|
|
104
|
+
layer: Any):
|
|
105
105
|
"""
|
|
106
|
-
|
|
106
|
+
Add quantizer parameters to the quantizer parameters dictionary
|
|
107
107
|
|
|
108
108
|
Args:
|
|
109
109
|
tensor_shape: tensor shape of the quantized tensor.
|
|
110
110
|
name: Tensor name.
|
|
111
111
|
layer: Layer to quantize.
|
|
112
|
-
|
|
113
|
-
Returns:
|
|
114
|
-
Dictionary of parameters names to the variables.
|
|
115
112
|
"""
|
|
116
113
|
|
|
117
|
-
ar_iter = layer.add_weight(
|
|
118
|
-
f"{name}_{GPTQ_ITER}",
|
|
119
|
-
shape=(),
|
|
120
|
-
initializer=tf.keras.initializers.Constant(0.0),
|
|
121
|
-
trainable=False)
|
|
122
|
-
|
|
123
114
|
ptq_threshold_tensor = layer.add_weight(
|
|
124
115
|
f"{name}_{PTQ_THRESHOLD}",
|
|
125
116
|
shape=len(self.threshold_values) if self.per_channel else (),
|
|
@@ -135,10 +126,8 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
|
|
|
135
126
|
trainable=True)
|
|
136
127
|
|
|
137
128
|
# save the quantizer added parameters for later calculations
|
|
138
|
-
self.
|
|
139
|
-
|
|
140
|
-
GPTQ_ITER: ar_iter}
|
|
141
|
-
return self.quantizer_parameters
|
|
129
|
+
self.add_quantizer_variable(PTQ_THRESHOLD, ptq_threshold_tensor, VariableGroup.QPARAMS)
|
|
130
|
+
self.add_quantizer_variable(AUXVAR, auxvar_tensor, VariableGroup.WEIGHTS)
|
|
142
131
|
|
|
143
132
|
def __call__(self,
|
|
144
133
|
inputs: tf.Tensor,
|
|
@@ -154,8 +143,8 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
|
|
|
154
143
|
The quantized tensor.
|
|
155
144
|
"""
|
|
156
145
|
|
|
157
|
-
auxvar = self.
|
|
158
|
-
ptq_threshold_tensor = self.
|
|
146
|
+
auxvar = self.get_quantizer_variable(AUXVAR)
|
|
147
|
+
ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
|
|
159
148
|
|
|
160
149
|
if self.per_channel:
|
|
161
150
|
reshape_shape = get_threshold_reshape_shape(inputs.shape,
|
|
@@ -178,25 +167,6 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
|
|
|
178
167
|
signed=True,
|
|
179
168
|
power_of_two=self.power_of_two)
|
|
180
169
|
|
|
181
|
-
def get_aux_variable(self) -> List[tf.Tensor]:
|
|
182
|
-
"""
|
|
183
|
-
This function return a list with the quantizer's quantization auxiliary variables.
|
|
184
|
-
|
|
185
|
-
Returns: A list with the quantization auxiliary variables.
|
|
186
|
-
|
|
187
|
-
"""
|
|
188
|
-
|
|
189
|
-
return [self.quantizer_parameters[AUXVAR]]
|
|
190
|
-
|
|
191
|
-
def get_quantization_variable(self) -> List[tf.Tensor]:
|
|
192
|
-
"""
|
|
193
|
-
This function return a list with the quantizer's quantization parameters variables.
|
|
194
|
-
|
|
195
|
-
Returns: A list with the quantization parameters.
|
|
196
|
-
|
|
197
|
-
"""
|
|
198
|
-
|
|
199
|
-
return [self.quantizer_parameters[PTQ_THRESHOLD]]
|
|
200
170
|
|
|
201
171
|
def get_quant_config(self) -> Dict[str, np.ndarray]:
|
|
202
172
|
"""
|
|
@@ -207,5 +177,5 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
|
|
|
207
177
|
Keys must match NodeQuantizationConfig attributes
|
|
208
178
|
|
|
209
179
|
"""
|
|
210
|
-
old_threshold = self.
|
|
180
|
+
old_threshold = self.get_quantizer_variable(PTQ_THRESHOLD)
|
|
211
181
|
return {THRESHOLD: old_threshold.numpy().reshape(self.threshold_shape)}
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Type
|
|
17
|
+
|
|
18
|
+
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
|
|
19
|
+
from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
|
|
20
|
+
from model_compression_toolkit.gptq.pytorch.gptq_training import PytorchGPTQTrainer
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class GPTQPytorchImplemantation(GPTQFrameworkImplemantation, PytorchImplementation):
|
|
24
|
+
|
|
25
|
+
def get_gptq_trainer_obj(self) -> Type[PytorchGPTQTrainer]:
|
|
26
|
+
"""
|
|
27
|
+
Returns: Pytorch object of GPTQTrainer
|
|
28
|
+
"""
|
|
29
|
+
return PytorchGPTQTrainer
|
|
@@ -19,21 +19,21 @@ from torch.nn import Module
|
|
|
19
19
|
from tqdm import tqdm
|
|
20
20
|
import copy
|
|
21
21
|
import torch
|
|
22
|
-
from model_compression_toolkit.
|
|
22
|
+
from model_compression_toolkit.logger import Logger
|
|
23
23
|
from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
|
|
24
24
|
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
|
|
25
25
|
from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
|
|
26
|
-
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
|
|
26
|
+
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
|
|
27
27
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
28
28
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
29
29
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
30
30
|
from model_compression_toolkit.core.pytorch.constants import BIAS
|
|
31
31
|
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_model, torch_tensor_to_numpy
|
|
32
32
|
from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable_parameters, \
|
|
33
|
-
get_weights_for_loss
|
|
33
|
+
get_weights_for_loss
|
|
34
34
|
from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
|
|
35
|
-
from model_compression_toolkit.gptq.common.gptq_constants import REGULARIZATION_VALUES
|
|
36
35
|
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
36
|
+
from model_compression_toolkit.gptq.pytorch.quantizer.regularization_factory import get_regularization
|
|
37
37
|
from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
|
|
38
38
|
|
|
39
39
|
|
|
@@ -63,7 +63,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
|
|
|
63
63
|
fw_info: Framework information
|
|
64
64
|
representative_data_gen: Dataset to use for inputs of the models.
|
|
65
65
|
"""
|
|
66
|
-
super().__init__(graph_float, graph_quant, gptq_config, fw_impl, fw_info
|
|
66
|
+
super().__init__(graph_float, graph_quant, gptq_config, fw_impl, fw_info)
|
|
67
67
|
self.loss_list = []
|
|
68
68
|
self.input_scale = 1
|
|
69
69
|
if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
|
|
@@ -71,7 +71,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
|
|
|
71
71
|
else:
|
|
72
72
|
self.input_scale = self.gptq_user_info.input_scale
|
|
73
73
|
|
|
74
|
-
trainable_weights, trainable_bias, trainable_threshold
|
|
74
|
+
trainable_weights, trainable_bias, trainable_threshold = get_gptq_trainable_parameters(
|
|
75
75
|
self.fxp_model,
|
|
76
76
|
add_bias=self.gptq_config.train_bias)
|
|
77
77
|
|
|
@@ -86,7 +86,9 @@ class PytorchGPTQTrainer(GPTQTrainer):
|
|
|
86
86
|
trainable_bias,
|
|
87
87
|
trainable_threshold)
|
|
88
88
|
|
|
89
|
-
self.weights_for_average_loss = to_torch_tensor(self.
|
|
89
|
+
self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights(representative_data_gen))
|
|
90
|
+
|
|
91
|
+
self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
|
|
90
92
|
|
|
91
93
|
def _is_gptq_applicable(self,
|
|
92
94
|
node: BaseNode) -> bool:
|
|
@@ -184,9 +186,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
|
|
|
184
186
|
self.compare_points_std,
|
|
185
187
|
self.weights_for_average_loss)
|
|
186
188
|
|
|
187
|
-
reg_value = self.gptq_config.
|
|
188
|
-
self.fxp_model,
|
|
189
|
-
**{REGULARIZATION_VALUES: self._get_quantizer_regularization_values(self.gptq_config.rounding_type)})
|
|
189
|
+
reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)
|
|
190
190
|
|
|
191
191
|
loss_value += reg_value
|
|
192
192
|
|
|
@@ -272,18 +272,3 @@ class PytorchGPTQTrainer(GPTQTrainer):
|
|
|
272
272
|
if hasattr(layer.layer, BIAS):
|
|
273
273
|
bias = getattr(layer.layer, BIAS)
|
|
274
274
|
bias.requires_grad = self.gptq_config.train_bias
|
|
275
|
-
|
|
276
|
-
def _get_quantizer_regularization_values(self, rounding_type: RoundingType) -> List[torch.Tensor]:
|
|
277
|
-
"""
|
|
278
|
-
Mapping between a rounding type to its matching regularization method.
|
|
279
|
-
|
|
280
|
-
Args:
|
|
281
|
-
rounding_type: GPTQ rounding type.
|
|
282
|
-
|
|
283
|
-
Returns: A regularization computation method.
|
|
284
|
-
|
|
285
|
-
"""
|
|
286
|
-
if rounding_type == RoundingType.SoftQuantizer:
|
|
287
|
-
return get_soft_rounding_reg(self.fxp_model)
|
|
288
|
-
else:
|
|
289
|
-
return []
|
|
@@ -15,11 +15,11 @@
|
|
|
15
15
|
import torch
|
|
16
16
|
import torch.nn as nn
|
|
17
17
|
from typing import List
|
|
18
|
-
|
|
19
18
|
from model_compression_toolkit.core.pytorch.constants import BIAS
|
|
20
19
|
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
|
21
20
|
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
|
|
22
21
|
from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
|
|
22
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
def get_gptq_trainable_parameters(fxp_model: nn.Module,
|
|
@@ -39,21 +39,23 @@ def get_gptq_trainable_parameters(fxp_model: nn.Module,
|
|
|
39
39
|
trainable_aux_weights = nn.ParameterList()
|
|
40
40
|
trainable_threshold = nn.ParameterList()
|
|
41
41
|
trainable_bias = nn.ParameterList()
|
|
42
|
-
trainable_temperature = nn.ParameterList()
|
|
43
42
|
|
|
44
43
|
for layer in fxp_model.modules():
|
|
45
44
|
if isinstance(layer, PytorchQuantizationWrapper):
|
|
46
45
|
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
|
|
47
46
|
fw_info=DEFAULT_PYTORCH_INFO)
|
|
48
47
|
|
|
49
|
-
|
|
50
|
-
|
|
48
|
+
# collect trainable weights per quantizer
|
|
49
|
+
quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
|
|
50
|
+
quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
|
|
51
|
+
trainable_aux_weights.extend(quantizer_trainable_weights)
|
|
52
|
+
trainable_threshold.extend(quantizer_trainable_threshold)
|
|
51
53
|
|
|
52
54
|
if add_bias and hasattr(layer.layer, BIAS):
|
|
53
55
|
bias = getattr(layer.layer, BIAS)
|
|
54
56
|
trainable_bias.append(bias)
|
|
55
57
|
|
|
56
|
-
return trainable_aux_weights, trainable_bias, trainable_threshold
|
|
58
|
+
return trainable_aux_weights, trainable_bias, trainable_threshold
|
|
57
59
|
|
|
58
60
|
|
|
59
61
|
def get_weights_for_loss(fxp_model: nn.Module) -> [List[nn.Parameter], List[torch.Tensor]]:
|
|
@@ -77,25 +79,3 @@ def get_weights_for_loss(fxp_model: nn.Module) -> [List[nn.Parameter], List[torc
|
|
|
77
79
|
fxp_weights_list.append(quantizer(training=False, inputs=quantizer_vars))
|
|
78
80
|
|
|
79
81
|
return flp_weights_list, fxp_weights_list
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
# TODO: this function need to move to location that is relevant only for soft quantizer -
|
|
83
|
-
# once deciding how to handle GPTQ quantizers regularization.
|
|
84
|
-
def get_soft_rounding_reg(fxp_model: nn.Module) -> List[torch.Tensor]:
|
|
85
|
-
"""
|
|
86
|
-
This function returns the soft quantizer regularization values for SoftRounding.
|
|
87
|
-
|
|
88
|
-
Args:
|
|
89
|
-
fxp_model: A model to be quantized with SoftRounding.
|
|
90
|
-
|
|
91
|
-
Returns: A list of tensors.
|
|
92
|
-
"""
|
|
93
|
-
|
|
94
|
-
soft_reg_aux: List[torch.Tensor] = []
|
|
95
|
-
for layer in fxp_model.modules():
|
|
96
|
-
if isinstance(layer, PytorchQuantizationWrapper):
|
|
97
|
-
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
|
|
98
|
-
fw_info=DEFAULT_PYTORCH_INFO)
|
|
99
|
-
|
|
100
|
-
soft_reg_aux.append(layer.weights_quantizers[kernel_attribute].get_regularization())
|
|
101
|
-
return soft_reg_aux
|
|
@@ -14,17 +14,18 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Callable
|
|
16
16
|
from model_compression_toolkit.core import common
|
|
17
|
-
from model_compression_toolkit.
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
20
|
-
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
|
|
21
|
-
from model_compression_toolkit.
|
|
17
|
+
from model_compression_toolkit.constants import FOUND_TORCH
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
|
+
from model_compression_toolkit.constants import PYTORCH
|
|
20
|
+
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
|
|
21
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
|
22
22
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
23
23
|
from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
|
|
24
|
+
from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
|
|
24
25
|
from model_compression_toolkit.gptq.runner import gptq_runner
|
|
25
26
|
from model_compression_toolkit.core.exporter import export_model
|
|
26
27
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
|
27
|
-
from model_compression_toolkit import CoreConfig
|
|
28
|
+
from model_compression_toolkit.core import CoreConfig
|
|
28
29
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
29
30
|
MixedPrecisionQuantizationConfigV2
|
|
30
31
|
|
|
@@ -35,8 +36,8 @@ LR_QUANTIZATION_PARAM_DEFAULT = 1e-4
|
|
|
35
36
|
|
|
36
37
|
if FOUND_TORCH:
|
|
37
38
|
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
|
38
|
-
from model_compression_toolkit.
|
|
39
|
-
from model_compression_toolkit.
|
|
39
|
+
from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
|
|
40
|
+
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
|
40
41
|
from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss
|
|
41
42
|
from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
|
|
42
43
|
import torch
|
|
@@ -71,33 +72,19 @@ if FOUND_TORCH:
|
|
|
71
72
|
Import MCT and Create a GradientPTQConfigV2 to run for 5 epochs:
|
|
72
73
|
|
|
73
74
|
>>> import model_compression_toolkit as mct
|
|
74
|
-
>>> gptq_conf = mct.get_pytorch_gptq_config(n_epochs=5)
|
|
75
|
+
>>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
|
|
75
76
|
|
|
76
77
|
Other PyTorch optimizers can be passed with dummy params:
|
|
77
78
|
|
|
78
79
|
>>> import torch
|
|
79
|
-
>>> gptq_conf = mct.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)]))
|
|
80
|
+
>>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)]))
|
|
80
81
|
|
|
81
82
|
The configuration can be passed to :func:`~model_compression_toolkit.pytorch_post_training_quantization` in order to quantize a pytorch model using gptq.
|
|
82
83
|
|
|
83
84
|
"""
|
|
84
|
-
bias_optimizer =
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
# - change default quantization_parameters_learning to True.
|
|
88
|
-
# - remove explicit rounding_type and quantizer_config (and let it use the default GradientPTQConfig).
|
|
89
|
-
return GradientPTQConfigV2(n_epochs,
|
|
90
|
-
optimizer,
|
|
91
|
-
optimizer_rest=optimizer_rest,
|
|
92
|
-
loss=loss,
|
|
93
|
-
log_function=log_function,
|
|
94
|
-
train_bias=True,
|
|
95
|
-
optimizer_quantization_parameter=optimizer_quantization_parameter,
|
|
96
|
-
optimizer_bias=bias_optimizer,
|
|
97
|
-
rounding_type=RoundingType.STE,
|
|
98
|
-
quantizer_config=GPTQQuantizerConfig(),
|
|
99
|
-
quantization_parameters_learning=False,
|
|
100
|
-
)
|
|
85
|
+
bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
|
|
86
|
+
return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
|
|
87
|
+
log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer)
|
|
101
88
|
|
|
102
89
|
|
|
103
90
|
def pytorch_gradient_post_training_quantization_experimental(model: Module,
|
|
@@ -131,7 +118,7 @@ if FOUND_TORCH:
|
|
|
131
118
|
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
|
|
132
119
|
gptq_config (GradientPTQConfigV2): Configuration for using gptq (e.g. optimizer).
|
|
133
120
|
gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
|
|
134
|
-
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
|
|
121
|
+
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
|
|
135
122
|
new_experimental_exporter (bool): Whether exporting the quantized model using new exporter or not (in progress. Avoiding it for now is recommended).
|
|
136
123
|
|
|
137
124
|
Returns:
|
|
@@ -155,26 +142,26 @@ if FOUND_TORCH:
|
|
|
155
142
|
|
|
156
143
|
Create MCT core configurations with number of calibration iterations set to 1:
|
|
157
144
|
|
|
158
|
-
>>> config = mct.CoreConfig()
|
|
145
|
+
>>> config = mct.core.CoreConfig()
|
|
159
146
|
|
|
160
147
|
Pass the module, the representative dataset generator and the configuration (optional) to get a quantized module
|
|
161
148
|
|
|
162
|
-
>>> quantized_module, quantization_info = mct.pytorch_gradient_post_training_quantization_experimental(module, repr_datagen, core_config=config, gptq_config=gptq_conf)
|
|
149
|
+
>>> quantized_module, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization_experimental(module, repr_datagen, core_config=config, gptq_config=gptq_conf)
|
|
163
150
|
|
|
164
151
|
"""
|
|
165
152
|
|
|
166
153
|
if core_config.mixed_precision_enable:
|
|
167
154
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
|
|
168
|
-
|
|
155
|
+
Logger.error("Given quantization config to mixed-precision facade is not of type "
|
|
169
156
|
"MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
|
|
170
157
|
"API, or pass a valid mixed precision configuration.") # pragma: no cover
|
|
171
158
|
|
|
172
|
-
|
|
159
|
+
Logger.info("Using experimental mixed-precision quantization. "
|
|
173
160
|
"If you encounter an issue please file a bug.")
|
|
174
161
|
|
|
175
162
|
tb_w = _init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
|
|
176
163
|
|
|
177
|
-
fw_impl =
|
|
164
|
+
fw_impl = GPTQPytorchImplemantation()
|
|
178
165
|
|
|
179
166
|
# ---------------------- #
|
|
180
167
|
# Core Runner
|
|
@@ -205,7 +192,7 @@ if FOUND_TORCH:
|
|
|
205
192
|
Logger.warning('Using new experimental exported models. '
|
|
206
193
|
'Please do not use unless you are familiar with what you are doing')
|
|
207
194
|
|
|
208
|
-
return
|
|
195
|
+
return get_exportable_pytorch_model(graph_gptq)
|
|
209
196
|
|
|
210
197
|
return export_model(graph_gptq,
|
|
211
198
|
DEFAULT_PYTORCH_INFO,
|
|
@@ -15,3 +15,4 @@
|
|
|
15
15
|
|
|
16
16
|
import model_compression_toolkit.gptq.pytorch.quantizer.ste_rounding.symmetric_ste
|
|
17
17
|
import model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.symmetric_soft_quantizer
|
|
18
|
+
import model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.uniform_soft_quantizer
|
|
@@ -13,10 +13,10 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from abc import abstractmethod
|
|
16
|
-
from typing import Union, Dict
|
|
16
|
+
from typing import Union, Dict
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
|
+
from model_compression_toolkit.constants import FOUND_TORCH
|
|
20
20
|
from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
|
|
21
21
|
|
|
22
22
|
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
|
|
@@ -71,26 +71,6 @@ if FOUND_TORCH:
|
|
|
71
71
|
|
|
72
72
|
return weights, quant_config, {}
|
|
73
73
|
|
|
74
|
-
def get_aux_variable(self) -> List[Tensor]:
|
|
75
|
-
"""
|
|
76
|
-
This function return a list with the quantizer's quantization auxiliary variables.
|
|
77
|
-
|
|
78
|
-
Returns: A list with the quantization auxiliary variables.
|
|
79
|
-
|
|
80
|
-
"""
|
|
81
|
-
|
|
82
|
-
return [] # pragma: no cover
|
|
83
|
-
|
|
84
|
-
def get_quantization_variable(self) -> List[Tensor]:
|
|
85
|
-
"""
|
|
86
|
-
This function return a list with the quantizer's quantization parameters variables.
|
|
87
|
-
|
|
88
|
-
Returns: A list with the quantization parameters.
|
|
89
|
-
|
|
90
|
-
"""
|
|
91
|
-
|
|
92
|
-
return [] # pragma: no cover
|
|
93
|
-
|
|
94
74
|
@abstractmethod
|
|
95
75
|
def get_quant_config(self):
|
|
96
76
|
"""
|
|
@@ -14,9 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Union, Tuple
|
|
16
16
|
import torch
|
|
17
|
-
from
|
|
18
|
-
from model_compression_toolkit.core.common.constants import MIN_THRESHOLD
|
|
19
|
-
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
|
|
17
|
+
from model_compression_toolkit.constants import MIN_THRESHOLD
|
|
20
18
|
|
|
21
19
|
|
|
22
20
|
def power_of_two_max(max_tensor: torch.Tensor) -> torch.Tensor:
|
|
@@ -30,11 +28,20 @@ def calculate_delta(max_tensor: torch.Tensor,
|
|
|
30
28
|
num_bits: int,
|
|
31
29
|
signed: bool) -> torch.Tensor:
|
|
32
30
|
"""
|
|
33
|
-
Compute the step size for the quantization.
|
|
31
|
+
Compute the step size for the symmetric quantization.
|
|
34
32
|
"""
|
|
35
33
|
return max_tensor / (2 ** (num_bits - int(signed)))
|
|
36
34
|
|
|
37
35
|
|
|
36
|
+
def calculate_delta_uniform(min_tensor: torch.Tensor,
|
|
37
|
+
max_tensor: torch.Tensor,
|
|
38
|
+
num_bits: int) -> torch.Tensor:
|
|
39
|
+
"""
|
|
40
|
+
Compute the step size for the uniform quantization.
|
|
41
|
+
"""
|
|
42
|
+
return (max_tensor-min_tensor) / (2 ** num_bits - 1)
|
|
43
|
+
|
|
44
|
+
|
|
38
45
|
def ste_ceil(x: torch.Tensor) -> torch.Tensor:
|
|
39
46
|
"""
|
|
40
47
|
Return the ceil values of a tensor.
|
|
@@ -42,6 +49,13 @@ def ste_ceil(x: torch.Tensor) -> torch.Tensor:
|
|
|
42
49
|
return (torch.ceil(x) - x).detach() + x
|
|
43
50
|
|
|
44
51
|
|
|
52
|
+
def ste_floor(x: torch.Tensor) -> torch.Tensor:
|
|
53
|
+
"""
|
|
54
|
+
Return the floor values of a tensor.
|
|
55
|
+
"""
|
|
56
|
+
return (torch.floor(x) - x).detach() + x
|
|
57
|
+
|
|
58
|
+
|
|
45
59
|
def ste_round(x: torch.Tensor) -> torch.Tensor:
|
|
46
60
|
"""
|
|
47
61
|
Calculate the rounded values of a tensor
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import List, Dict, Tuple
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit import GradientPTQConfigV2
|
|
17
|
+
from model_compression_toolkit.gptq import GradientPTQConfigV2
|
|
18
18
|
from model_compression_toolkit.core import common
|
|
19
19
|
from model_compression_toolkit.core.pytorch.constants import KERNEL
|
|
20
20
|
from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \
|
|
@@ -59,7 +59,7 @@ def quantization_builder(n: common.BaseNode,
|
|
|
59
59
|
quant_method=quant_method,
|
|
60
60
|
quantizer_base_class=BasePytorchGPTQTrainableQuantizer)
|
|
61
61
|
weights_quantizers.update({KERNEL: quantizer_class(get_trainable_quantizer_weights_config(n),
|
|
62
|
-
**gptq_config.
|
|
62
|
+
**gptq_config.gptq_quantizer_params_override)})
|
|
63
63
|
activation_quantizers = []
|
|
64
64
|
if n.is_activation_quantization_enabled():
|
|
65
65
|
quant_method = n.final_activation_quantization_cfg.activation_quantization_method
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import Callable
|
|
16
|
+
|
|
17
|
+
from model_compression_toolkit.gptq import RoundingType, GradientPTQConfigV2, GradientPTQConfig
|
|
18
|
+
from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import \
|
|
19
|
+
SoftQuantizerRegularization
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen: Callable) -> Callable:
|
|
23
|
+
"""
|
|
24
|
+
Returns a function that computes the regularization term for GPTQ training based on the given
|
|
25
|
+
rounding type in the GPTQ configuration.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
gptq_config: A GPTQ configuration.
|
|
29
|
+
representative_data_gen: Dataset used for the GPTQ training.
|
|
30
|
+
|
|
31
|
+
Returns: A function for computing the regularization. If there is no regularization function defined for the given
|
|
32
|
+
rounding type, then it returns a function that just returns 0.
|
|
33
|
+
|
|
34
|
+
"""
|
|
35
|
+
if gptq_config.rounding_type == RoundingType.SoftQuantizer:
|
|
36
|
+
# dry run on the representative dataset to count number of batches
|
|
37
|
+
num_batches = 0
|
|
38
|
+
for _ in representative_data_gen():
|
|
39
|
+
num_batches += 1
|
|
40
|
+
|
|
41
|
+
n_epochs = GradientPTQConfigV2.from_v1(n_ptq_iter=num_batches, config_v1=gptq_config).n_epochs if \
|
|
42
|
+
not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
|
|
43
|
+
return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
|
|
44
|
+
else:
|
|
45
|
+
return lambda m, e_reg: 0
|