mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +12 -41
- model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
- model_compression_toolkit/core/__init__.py +14 -0
- model_compression_toolkit/core/analyzer.py +3 -2
- model_compression_toolkit/core/common/__init__.py +0 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +1 -8
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
- model_compression_toolkit/core/common/graph/base_graph.py +2 -2
- model_compression_toolkit/core/common/graph/base_node.py +57 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/memory_computation.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/model_validation.py +1 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -1
- model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
- model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
- model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/keras/constants.py +0 -7
- model_compression_toolkit/core/keras/default_framework_info.py +3 -3
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -10
- model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
- model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
- model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +1 -1
- model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +4 -6
- model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
- model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
- model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/runner.py +7 -7
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -104
- model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
- model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +30 -39
- model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
- model_compression_toolkit/gptq/keras/graph_info.py +8 -33
- model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
- model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
- model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
- model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
- model_compression_toolkit/gptq/runner.py +3 -2
- model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
- model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
- model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
- model_compression_toolkit/ptq/__init__.py +3 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/qat/__init__.py +4 -0
- model_compression_toolkit/qat/common/__init__.py +1 -2
- model_compression_toolkit/qat/common/qat_config.py +3 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
- model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
- model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
- model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
- model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
- /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
|
@@ -12,12 +12,13 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from
|
|
15
|
+
from abc import abstractmethod
|
|
16
|
+
from enum import Enum
|
|
17
|
+
from typing import Union, List, Any
|
|
17
18
|
from inspect import signature
|
|
18
19
|
|
|
19
20
|
from model_compression_toolkit.core import common
|
|
20
|
-
from model_compression_toolkit.
|
|
21
|
+
from model_compression_toolkit.logger import Logger
|
|
21
22
|
|
|
22
23
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer, \
|
|
23
24
|
QuantizationTarget
|
|
@@ -27,6 +28,19 @@ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructur
|
|
|
27
28
|
QUANTIZATION_TARGET
|
|
28
29
|
|
|
29
30
|
|
|
31
|
+
VAR = 'var'
|
|
32
|
+
GROUP = 'group'
|
|
33
|
+
|
|
34
|
+
class VariableGroup(Enum):
|
|
35
|
+
"""
|
|
36
|
+
An enum for choosing trainable variable group
|
|
37
|
+
0. WEIGHTS
|
|
38
|
+
1. QPARAMS
|
|
39
|
+
"""
|
|
40
|
+
WEIGHTS = 0
|
|
41
|
+
QPARAMS = 1
|
|
42
|
+
|
|
43
|
+
|
|
30
44
|
class BaseTrainableQuantizer(BaseInferableQuantizer):
|
|
31
45
|
def __init__(self,
|
|
32
46
|
quantization_config: Union[TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig]):
|
|
@@ -41,9 +55,9 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
|
|
|
41
55
|
for i, (k, v) in enumerate(self.get_sig().parameters.items()):
|
|
42
56
|
if i == 0:
|
|
43
57
|
if v.annotation not in [TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]:
|
|
44
|
-
|
|
58
|
+
Logger.error(f"First parameter must be either TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig") # pragma: no cover
|
|
45
59
|
elif v.default is v.empty:
|
|
46
|
-
|
|
60
|
+
Logger.error(f"Parameter {k} doesn't have a default value") # pragma: no cover
|
|
47
61
|
|
|
48
62
|
super(BaseTrainableQuantizer, self).__init__()
|
|
49
63
|
self.quantization_config = quantization_config
|
|
@@ -59,17 +73,19 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
|
|
|
59
73
|
if static_quantization_target == QuantizationTarget.Weights:
|
|
60
74
|
self.validate_weights()
|
|
61
75
|
if self.quantization_config.weights_quantization_method not in static_quantization_method:
|
|
62
|
-
|
|
76
|
+
Logger.error(
|
|
63
77
|
f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.weights_quantization_method}')
|
|
64
78
|
elif static_quantization_target == QuantizationTarget.Activation:
|
|
65
79
|
self.validate_activation()
|
|
66
80
|
if self.quantization_config.activation_quantization_method not in static_quantization_method:
|
|
67
|
-
|
|
81
|
+
Logger.error(
|
|
68
82
|
f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.activation_quantization_method}')
|
|
69
83
|
else:
|
|
70
|
-
|
|
84
|
+
Logger.error(
|
|
71
85
|
f'Unknown Quantization Part:{static_quantization_target}') # pragma: no cover
|
|
72
86
|
|
|
87
|
+
self.quantizer_parameters = {}
|
|
88
|
+
|
|
73
89
|
@classmethod
|
|
74
90
|
def get_sig(cls):
|
|
75
91
|
return signature(cls)
|
|
@@ -129,7 +145,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
|
|
|
129
145
|
|
|
130
146
|
"""
|
|
131
147
|
if self.activation_quantization() or not self.weights_quantization():
|
|
132
|
-
|
|
148
|
+
Logger.error(f'Expect weight quantization got activation')
|
|
133
149
|
|
|
134
150
|
def validate_activation(self) -> None:
|
|
135
151
|
"""
|
|
@@ -137,7 +153,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
|
|
|
137
153
|
|
|
138
154
|
"""
|
|
139
155
|
if not self.activation_quantization() or self.weights_quantization():
|
|
140
|
-
|
|
156
|
+
Logger.error(f'Expect activation quantization got weight')
|
|
141
157
|
|
|
142
158
|
def convert2inferable(self) -> BaseInferableQuantizer:
|
|
143
159
|
"""
|
|
@@ -147,3 +163,38 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
|
|
|
147
163
|
BaseInferableQuantizer object.
|
|
148
164
|
"""
|
|
149
165
|
raise NotImplemented # pragma: no cover
|
|
166
|
+
|
|
167
|
+
def add_quantizer_variable(self, name: str, variable: Any, group: VariableGroup = VariableGroup.WEIGHTS):
|
|
168
|
+
"""
|
|
169
|
+
Add a quantizer variable to quantizer_parameters dictionary
|
|
170
|
+
"""
|
|
171
|
+
self.quantizer_parameters.update({name: {VAR: variable, GROUP: group}})
|
|
172
|
+
|
|
173
|
+
def get_quantizer_variable(self, name: str) -> Any:
|
|
174
|
+
"""
|
|
175
|
+
Get a quantizer variable by name
|
|
176
|
+
|
|
177
|
+
Args:
|
|
178
|
+
name: variable name
|
|
179
|
+
|
|
180
|
+
Returns:
|
|
181
|
+
trainable variable
|
|
182
|
+
"""
|
|
183
|
+
if name in self.quantizer_parameters:
|
|
184
|
+
return self.quantizer_parameters[name][VAR]
|
|
185
|
+
else:
|
|
186
|
+
Logger.error(f'Variable {name} is not exist in quantizers parameters!') # pragma: no cover
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@abstractmethod
|
|
190
|
+
def get_trainable_variables(self, group: VariableGroup) -> List[Any]:
|
|
191
|
+
"""
|
|
192
|
+
Get trainable parameters with specific group from quantizer
|
|
193
|
+
|
|
194
|
+
Args:
|
|
195
|
+
group: Enum of variable group
|
|
196
|
+
|
|
197
|
+
Returns:
|
|
198
|
+
List of trainable variables
|
|
199
|
+
"""
|
|
200
|
+
raise NotImplemented # pragma: no cover
|
|
@@ -13,7 +13,8 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import List
|
|
16
|
-
from model_compression_toolkit.core.common import BaseNode
|
|
16
|
+
from model_compression_toolkit.core.common import BaseNode
|
|
17
|
+
from model_compression_toolkit.logger import Logger
|
|
17
18
|
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
|
|
18
19
|
TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig, TrainableQuantizerCandidateConfig
|
|
19
20
|
|
|
@@ -12,11 +12,10 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
from typing import Union
|
|
15
|
+
from typing import Union, Any
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit import
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
17
|
+
from model_compression_toolkit.logger import Logger
|
|
18
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
20
19
|
from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
|
|
21
20
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants \
|
|
22
21
|
import QUANTIZATION_TARGET, QUANTIZATION_METHOD, QUANTIZER_TYPE
|
|
@@ -25,7 +24,7 @@ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructur
|
|
|
25
24
|
|
|
26
25
|
|
|
27
26
|
def get_trainable_quantizer_class(quant_target: QuantizationTarget,
|
|
28
|
-
quantizer_type: Union[
|
|
27
|
+
quantizer_type: Union[Any, Any],
|
|
29
28
|
quant_method: QuantizationMethod,
|
|
30
29
|
quantizer_base_class: type) -> type:
|
|
31
30
|
"""
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from abc import ABC
|
|
16
16
|
from typing import Dict, List
|
|
17
|
-
from model_compression_toolkit.
|
|
17
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class TrainableQuantizerCandidateConfig:
|
|
@@ -12,12 +12,12 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
from typing import Dict, Any, Union
|
|
15
|
+
from typing import Dict, Any, Union, List
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
|
|
20
|
-
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer
|
|
17
|
+
from model_compression_toolkit.logger import Logger
|
|
18
|
+
from model_compression_toolkit.constants import FOUND_TF
|
|
19
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
20
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
|
|
21
21
|
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
|
|
22
22
|
TrainableQuantizerActivationConfig
|
|
23
23
|
|
|
@@ -25,7 +25,7 @@ if FOUND_TF:
|
|
|
25
25
|
QUANTIZATION_CONFIG = 'quantization_config'
|
|
26
26
|
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.keras.config_serialization import config_serialization, \
|
|
27
27
|
config_deserialization
|
|
28
|
-
|
|
28
|
+
import tensorflow as tf
|
|
29
29
|
|
|
30
30
|
class BaseKerasTrainableQuantizer(BaseTrainableQuantizer):
|
|
31
31
|
def __init__(self,
|
|
@@ -61,6 +61,24 @@ if FOUND_TF:
|
|
|
61
61
|
# Note that a quantizer only receive quantization config and the rest of define hardcoded inside the speficie quantizer.
|
|
62
62
|
return cls(quantization_config=quantization_config)
|
|
63
63
|
|
|
64
|
+
def get_trainable_variables(self, group: VariableGroup) -> List[tf.Tensor]:
|
|
65
|
+
"""
|
|
66
|
+
Get trainable parameters with specific group from quantizer
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
group: Enum of variable group
|
|
70
|
+
|
|
71
|
+
Returns:
|
|
72
|
+
List of trainable variables
|
|
73
|
+
"""
|
|
74
|
+
quantizer_trainable = []
|
|
75
|
+
for name, parameter_dict in self.quantizer_parameters.items():
|
|
76
|
+
quantizer_parameter, parameter_group = parameter_dict[VAR], parameter_dict[GROUP]
|
|
77
|
+
if quantizer_parameter.trainable and parameter_group == group:
|
|
78
|
+
quantizer_trainable.append(quantizer_parameter)
|
|
79
|
+
return quantizer_trainable
|
|
80
|
+
|
|
81
|
+
|
|
64
82
|
else:
|
|
65
83
|
class BaseKerasTrainableQuantizer(BaseTrainableQuantizer):
|
|
66
84
|
def __init__(self,
|
|
@@ -17,7 +17,7 @@ import copy
|
|
|
17
17
|
from typing import Any, Union
|
|
18
18
|
from enum import Enum
|
|
19
19
|
|
|
20
|
-
from model_compression_toolkit.
|
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
21
21
|
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.trainable_quantizer_config import \
|
|
22
22
|
TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig
|
|
23
23
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common import constants as C
|
|
@@ -12,17 +12,20 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
from typing import Union
|
|
15
|
+
from typing import Union, List
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
|
|
20
|
-
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer
|
|
17
|
+
from model_compression_toolkit.logger import Logger
|
|
18
|
+
from model_compression_toolkit.constants import FOUND_TORCH
|
|
19
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
|
20
|
+
from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer, VAR, GROUP
|
|
21
21
|
from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
|
|
22
22
|
TrainableQuantizerActivationConfig
|
|
23
23
|
|
|
24
|
+
|
|
24
25
|
if FOUND_TORCH:
|
|
25
26
|
|
|
27
|
+
import torch
|
|
28
|
+
|
|
26
29
|
class BasePytorchTrainableQuantizer(BaseTrainableQuantizer):
|
|
27
30
|
def __init__(self,
|
|
28
31
|
quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
|
|
@@ -35,6 +38,24 @@ if FOUND_TORCH:
|
|
|
35
38
|
"""
|
|
36
39
|
super().__init__(quantization_config)
|
|
37
40
|
|
|
41
|
+
|
|
42
|
+
def get_trainable_variables(self, group: VariableGroup) -> List[torch.Tensor]:
|
|
43
|
+
"""
|
|
44
|
+
Get trainable parameters with specific group from quantizer
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
group: Enum of variable group
|
|
48
|
+
|
|
49
|
+
Returns:
|
|
50
|
+
List of trainable variables
|
|
51
|
+
"""
|
|
52
|
+
quantizer_trainable = []
|
|
53
|
+
for name, parameter_dict in self.quantizer_parameters.items():
|
|
54
|
+
quantizer_parameter, parameter_group = parameter_dict[VAR], parameter_dict[GROUP]
|
|
55
|
+
if quantizer_parameter.requires_grad and parameter_group == group:
|
|
56
|
+
quantizer_trainable.append(quantizer_parameter)
|
|
57
|
+
return quantizer_trainable
|
|
58
|
+
|
|
38
59
|
else:
|
|
39
60
|
class BasePytorchTrainableQuantizer(BaseTrainableQuantizer):
|
|
40
61
|
def __init__(self,
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
# TP Model constants
|
|
17
|
+
OPS_SET_LIST = 'ops_set_list'
|
|
18
|
+
|
|
19
|
+
# Version
|
|
20
|
+
LATEST = 'latest'
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Supported TP models names:
|
|
24
|
+
DEFAULT_TP_MODEL = 'default'
|
|
25
|
+
IMX500_TP_MODEL = 'imx500'
|
|
26
|
+
TFLITE_TP_MODEL = 'tflite'
|
|
27
|
+
QNNPACK_TP_MODEL = 'qnnpack'
|
model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py
RENAMED
|
@@ -13,16 +13,16 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from model_compression_toolkit.
|
|
17
|
-
from model_compression_toolkit.
|
|
16
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.fusing import Fusing
|
|
17
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
|
|
18
18
|
TargetPlatformCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater, LayerFilterParams, OperationsToLayers, get_current_tpc
|
|
19
19
|
|
|
20
|
-
from model_compression_toolkit.
|
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import \
|
|
21
21
|
get_default_quantization_config_options, TargetPlatformModel
|
|
22
22
|
|
|
23
|
-
from model_compression_toolkit.
|
|
23
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
|
|
24
24
|
QuantizationConfigOptions, QuantizationMethod
|
|
25
|
-
from model_compression_toolkit.
|
|
25
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSet, OperatorSetConcat
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from model_compression_toolkit.
|
|
16
|
+
from model_compression_toolkit.logger import Logger
|
|
17
17
|
|
|
18
18
|
def get_current_tp_model():
|
|
19
19
|
"""
|
model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py
RENAMED
|
@@ -16,8 +16,8 @@
|
|
|
16
16
|
|
|
17
17
|
from typing import Any
|
|
18
18
|
|
|
19
|
-
from model_compression_toolkit.
|
|
20
|
-
from model_compression_toolkit.
|
|
19
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat
|
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import TargetPlatformModelComponent
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class Fusing(TargetPlatformModelComponent):
|
model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py
RENAMED
|
@@ -14,10 +14,10 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Dict, Any
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
20
|
-
from model_compression_toolkit.
|
|
17
|
+
from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
|
|
18
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import TargetPlatformModelComponent
|
|
19
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model
|
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import QuantizationConfigOptions
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class OperatorsSetBase(TargetPlatformModelComponent):
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from enum import Enum
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class QuantizationFormat(Enum):
|
|
19
|
+
FAKELY_QUANT = 0
|
|
20
|
+
INT8 = 1
|
|
@@ -16,16 +16,16 @@
|
|
|
16
16
|
import pprint
|
|
17
17
|
from typing import Any, Dict
|
|
18
18
|
|
|
19
|
-
from model_compression_toolkit.
|
|
19
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model, \
|
|
20
20
|
get_current_tp_model
|
|
21
|
-
from model_compression_toolkit.
|
|
22
|
-
from model_compression_toolkit.
|
|
21
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.fusing import Fusing
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import \
|
|
23
23
|
TargetPlatformModelComponent
|
|
24
|
-
from model_compression_toolkit.
|
|
24
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
|
|
25
25
|
QuantizationConfigOptions
|
|
26
|
-
from model_compression_toolkit.
|
|
27
|
-
from model_compression_toolkit.
|
|
28
|
-
from model_compression_toolkit.
|
|
26
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSetBase
|
|
27
|
+
from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
|
|
28
|
+
from model_compression_toolkit.logger import Logger
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
def get_default_quantization_config_options() -> QuantizationConfigOptions:
|
|
@@ -223,3 +223,12 @@ class TargetPlatformModel(ImmutableClass):
|
|
|
223
223
|
|
|
224
224
|
"""
|
|
225
225
|
pprint.pprint(self.get_info(), sort_dicts=False)
|
|
226
|
+
|
|
227
|
+
def set_quantization_format(self,
|
|
228
|
+
quantization_format: Any):
|
|
229
|
+
"""
|
|
230
|
+
Set quantization format.
|
|
231
|
+
Args:
|
|
232
|
+
quantization_format: A quantization format (fake-quant, int8 etc.) from enum QuantizationFormat.
|
|
233
|
+
"""
|
|
234
|
+
self.quantization_format = quantization_format
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Any, Dict
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.
|
|
17
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class TargetPlatformModelComponent:
|
|
@@ -13,13 +13,13 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from model_compression_toolkit.
|
|
17
|
-
from model_compression_toolkit.
|
|
18
|
-
from model_compression_toolkit.
|
|
16
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import get_current_tpc
|
|
17
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities import TargetPlatformCapabilities
|
|
18
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import \
|
|
19
19
|
Eq, GreaterEq, NotEq, SmallerEq, Greater, Smaller
|
|
20
|
-
from model_compression_toolkit.
|
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import \
|
|
21
21
|
LayerFilterParams
|
|
22
|
-
from model_compression_toolkit.
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.operations_to_layers import \
|
|
23
23
|
OperationsToLayers, OperationsSetToLayers
|
|
24
24
|
|
|
25
25
|
|
|
@@ -13,10 +13,8 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from typing import Any
|
|
17
|
-
|
|
18
|
-
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
19
|
-
from model_compression_toolkit.core.common.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
|
|
16
|
+
from typing import Any
|
|
17
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
|
|
20
18
|
|
|
21
19
|
|
|
22
20
|
class LayerFilterParams:
|
|
@@ -87,34 +85,34 @@ class LayerFilterParams:
|
|
|
87
85
|
params.extend([str(c) for c in self.conditions])
|
|
88
86
|
params_str = ', '.join(params)
|
|
89
87
|
return f'{self.layer.__name__}({params_str})'
|
|
90
|
-
|
|
91
|
-
def match(self,
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
88
|
+
#
|
|
89
|
+
# def match(self,
|
|
90
|
+
# node: BaseNode) -> bool:
|
|
91
|
+
# """
|
|
92
|
+
# Check if a node matches the layer, conditions and keyword-arguments of
|
|
93
|
+
# the LayerFilterParams.
|
|
94
|
+
#
|
|
95
|
+
# Args:
|
|
96
|
+
# node: Node to check if matches to the LayerFilterParams properties.
|
|
97
|
+
#
|
|
98
|
+
# Returns:
|
|
99
|
+
# Whether the node matches to the LayerFilterParams properties.
|
|
100
|
+
# """
|
|
101
|
+
# # Check the node has the same type as the layer in LayerFilterParams
|
|
102
|
+
# if self.layer != node.type:
|
|
103
|
+
# return False
|
|
104
|
+
#
|
|
105
|
+
# # Get attributes from node to filter
|
|
106
|
+
# layer_config = node.framework_attr
|
|
107
|
+
# if hasattr(node, "op_call_kwargs"):
|
|
108
|
+
# layer_config.update(node.op_call_kwargs)
|
|
109
|
+
#
|
|
110
|
+
# for attr, value in self.kwargs.items():
|
|
111
|
+
# if layer_config.get(attr) != value:
|
|
112
|
+
# return False
|
|
113
|
+
#
|
|
114
|
+
# for c in self.conditions:
|
|
115
|
+
# if not c.match(layer_config):
|
|
116
|
+
# return False
|
|
117
|
+
#
|
|
118
|
+
# return True
|
|
@@ -15,10 +15,10 @@
|
|
|
15
15
|
|
|
16
16
|
from typing import List, Any
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
20
|
-
from model_compression_toolkit.
|
|
21
|
-
from model_compression_toolkit.
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
|
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
|
|
21
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat, \
|
|
22
22
|
OperatorsSetBase
|
|
23
23
|
|
|
24
24
|
|
|
@@ -18,18 +18,17 @@ import itertools
|
|
|
18
18
|
import pprint
|
|
19
19
|
from typing import List, Any, Dict, Tuple
|
|
20
20
|
|
|
21
|
-
from model_compression_toolkit.
|
|
22
|
-
from model_compression_toolkit.
|
|
21
|
+
from model_compression_toolkit.logger import Logger
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.operations_to_layers import \
|
|
23
23
|
OperationsToLayers, OperationsSetToLayers
|
|
24
|
-
from model_compression_toolkit.
|
|
25
|
-
from model_compression_toolkit.
|
|
26
|
-
from model_compression_toolkit.
|
|
27
|
-
from model_compression_toolkit.
|
|
28
|
-
from model_compression_toolkit.core.common.target_platform.op_quantization_config import QuantizationConfigOptions, \
|
|
24
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
|
|
25
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import LayerFilterParams
|
|
26
|
+
from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
|
|
27
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import QuantizationConfigOptions, \
|
|
29
28
|
OpQuantizationConfig
|
|
30
|
-
from model_compression_toolkit.
|
|
31
|
-
from model_compression_toolkit.
|
|
32
|
-
from model_compression_toolkit.
|
|
29
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSetBase
|
|
30
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import TargetPlatformModel
|
|
31
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
|
|
33
32
|
|
|
34
33
|
|
|
35
34
|
class TargetPlatformCapabilities(ImmutableClass):
|
|
@@ -163,26 +162,6 @@ class TargetPlatformCapabilities(ImmutableClass):
|
|
|
163
162
|
"""
|
|
164
163
|
return self.tp_model.get_default_op_quantization_config()
|
|
165
164
|
|
|
166
|
-
def get_qco_by_node(self,
|
|
167
|
-
node: BaseNode) -> QuantizationConfigOptions:
|
|
168
|
-
"""
|
|
169
|
-
Get the QuantizationConfigOptions of a node in a graph according
|
|
170
|
-
to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformModel.
|
|
171
|
-
|
|
172
|
-
Args:
|
|
173
|
-
node: Node from graph to get its QuantizationConfigOptions.
|
|
174
|
-
|
|
175
|
-
Returns:
|
|
176
|
-
QuantizationConfigOptions of the node.
|
|
177
|
-
"""
|
|
178
|
-
if node is None:
|
|
179
|
-
Logger.error(f'Can not retrieve QC options for None node') # pragma: no cover
|
|
180
|
-
for fl, qco in self.filterlayer2qco.items():
|
|
181
|
-
if fl.match(node):
|
|
182
|
-
return qco
|
|
183
|
-
if node.type in self.layer2qco:
|
|
184
|
-
return self.layer2qco.get(node.type)
|
|
185
|
-
return self.tp_model.default_qco
|
|
186
165
|
|
|
187
166
|
def _get_config_options_mapping(self) -> Tuple[Dict[Any, QuantizationConfigOptions],
|
|
188
167
|
Dict[LayerFilterParams, QuantizationConfigOptions]]:
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from model_compression_toolkit.
|
|
16
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
class TargetPlatformCapabilitiesComponent:
|
|
File without changes
|