mct-nightly 1.8.0.8032023.post421__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +10 -9
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +303 -291
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +12 -41
- model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
- model_compression_toolkit/core/__init__.py +14 -0
- model_compression_toolkit/core/analyzer.py +3 -2
- model_compression_toolkit/core/common/__init__.py +0 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +1 -8
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
- model_compression_toolkit/core/common/graph/base_graph.py +2 -2
- model_compression_toolkit/core/common/graph/base_node.py +57 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/memory_computation.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/model_validation.py +1 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -1
- model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
- model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
- model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/keras/constants.py +0 -7
- model_compression_toolkit/core/keras/default_framework_info.py +3 -3
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -10
- model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
- model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
- model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +1 -1
- model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +4 -6
- model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
- model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
- model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/runner.py +7 -7
- model_compression_toolkit/exporter/__init__.py +5 -0
- model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -38
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +43 -26
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +51 -43
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -35
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +27 -7
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
- model_compression_toolkit/gptq/__init__.py +6 -0
- model_compression_toolkit/gptq/common/gptq_config.py +57 -104
- model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
- model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +30 -39
- model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -32
- model_compression_toolkit/gptq/keras/graph_info.py +8 -33
- model_compression_toolkit/gptq/keras/quantization_facade.py +25 -24
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -3
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +110 -0
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +40 -137
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +13 -43
- model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
- model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -25
- model_compression_toolkit/gptq/pytorch/graph_info.py +7 -27
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +21 -34
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -23
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +114 -0
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +32 -119
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +11 -33
- model_compression_toolkit/gptq/runner.py +3 -2
- model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +12 -13
- model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
- model_compression_toolkit/{core/common/logger.py → logger.py} +10 -2
- model_compression_toolkit/ptq/__init__.py +3 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/qat/__init__.py +4 -0
- model_compression_toolkit/qat/common/__init__.py +1 -2
- model_compression_toolkit/qat/common/qat_config.py +3 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +43 -48
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +34 -43
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +25 -24
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +32 -31
- model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +15 -5
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +61 -10
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -5
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +24 -6
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +26 -5
- model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
- model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
- model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.8032023.post421.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
- /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py
CHANGED
|
@@ -20,7 +20,7 @@ from model_compression_toolkit.core.common import BaseNode
|
|
|
20
20
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
|
21
21
|
from model_compression_toolkit.core.common.substitutions.residual_collapsing import ResidualCollapsing
|
|
22
22
|
from model_compression_toolkit.core.pytorch.constants import KERNEL
|
|
23
|
-
from model_compression_toolkit.
|
|
23
|
+
from model_compression_toolkit.logger import Logger
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
def residual_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
|
|
@@ -21,7 +21,7 @@ from torch import reshape
|
|
|
21
21
|
from torch.nn.functional import hardswish, silu, prelu, elu
|
|
22
22
|
from torch.nn.functional import avg_pool2d
|
|
23
23
|
|
|
24
|
-
from model_compression_toolkit import CoreConfig, FrameworkInfo
|
|
24
|
+
from model_compression_toolkit.core import CoreConfig, FrameworkInfo
|
|
25
25
|
from model_compression_toolkit.core import common
|
|
26
26
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
|
27
27
|
from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher
|
|
@@ -15,21 +15,21 @@
|
|
|
15
15
|
|
|
16
16
|
from typing import Callable
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
20
|
-
from model_compression_toolkit.
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
|
+
from model_compression_toolkit.constants import PYTORCH
|
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
|
21
21
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
22
22
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
23
23
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_data import compute_kpi_data
|
|
24
24
|
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
|
|
25
25
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
26
26
|
MixedPrecisionQuantizationConfig, DEFAULT_MIXEDPRECISION_CONFIG, MixedPrecisionQuantizationConfigV2
|
|
27
|
-
from model_compression_toolkit.
|
|
27
|
+
from model_compression_toolkit.constants import FOUND_TORCH
|
|
28
28
|
|
|
29
29
|
if FOUND_TORCH:
|
|
30
30
|
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
|
31
31
|
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
|
|
32
|
-
from model_compression_toolkit.
|
|
32
|
+
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
|
33
33
|
from torch.nn import Module
|
|
34
34
|
|
|
35
35
|
from model_compression_toolkit import get_target_platform_capabilities
|
|
@@ -51,7 +51,7 @@ if FOUND_TORCH:
|
|
|
51
51
|
representative_data_gen (Callable): Dataset used for calibration.
|
|
52
52
|
quant_config (MixedPrecisionQuantizationConfig): MixedPrecisionQuantizationConfig containing parameters of how the model should be quantized.
|
|
53
53
|
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default PyTorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
|
|
54
|
-
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
|
|
54
|
+
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
|
|
55
55
|
|
|
56
56
|
Returns:
|
|
57
57
|
A KPI object with total weights parameters sum, max activation tensor and total kpi.
|
|
@@ -75,7 +75,7 @@ if FOUND_TORCH:
|
|
|
75
75
|
Import mct and call for KPI data calculation:
|
|
76
76
|
|
|
77
77
|
>>> import model_compression_toolkit as mct
|
|
78
|
-
>>> kpi_data = mct.pytorch_kpi_data(module, repr_datagen)
|
|
78
|
+
>>> kpi_data = mct.core.pytorch_kpi_data(module, repr_datagen)
|
|
79
79
|
|
|
80
80
|
"""
|
|
81
81
|
|
|
@@ -111,7 +111,7 @@ if FOUND_TORCH:
|
|
|
111
111
|
representative_data_gen (Callable): Dataset used for calibration.
|
|
112
112
|
core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision
|
|
113
113
|
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default PyTorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
|
|
114
|
-
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
|
|
114
|
+
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
|
|
115
115
|
|
|
116
116
|
Returns:
|
|
117
117
|
|
|
@@ -132,7 +132,7 @@ if FOUND_TORCH:
|
|
|
132
132
|
Import mct and call for KPI data calculation:
|
|
133
133
|
|
|
134
134
|
>>> import model_compression_toolkit as mct
|
|
135
|
-
>>> kpi_data = mct.pytorch_kpi_data(module, repr_datagen)
|
|
135
|
+
>>> kpi_data = mct.core.pytorch_kpi_data(module, repr_datagen)
|
|
136
136
|
|
|
137
137
|
"""
|
|
138
138
|
|
|
@@ -18,7 +18,7 @@ from typing import Any, List
|
|
|
18
18
|
import torch
|
|
19
19
|
import copy
|
|
20
20
|
|
|
21
|
-
from model_compression_toolkit import FrameworkInfo
|
|
21
|
+
from model_compression_toolkit.core import FrameworkInfo
|
|
22
22
|
from model_compression_toolkit.core.common import BaseNode
|
|
23
23
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
|
24
24
|
from model_compression_toolkit.core.pytorch.utils import set_model, to_torch_tensor
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
import operator
|
|
16
|
+
from copy import deepcopy
|
|
16
17
|
from typing import List, Any, Tuple, Callable, Type, Dict
|
|
17
18
|
|
|
18
19
|
import numpy as np
|
|
@@ -22,7 +23,7 @@ from torch.nn import Conv2d, ConvTranspose2d, Linear
|
|
|
22
23
|
from torch.nn import Module, Sigmoid, Softmax
|
|
23
24
|
|
|
24
25
|
import model_compression_toolkit.core.pytorch.constants as pytorch_constants
|
|
25
|
-
from model_compression_toolkit import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfigV2
|
|
26
|
+
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfigV2
|
|
26
27
|
from model_compression_toolkit.core import common
|
|
27
28
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
28
29
|
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
|
@@ -74,10 +75,7 @@ from model_compression_toolkit.core.pytorch.pytorch_node_prior_info import creat
|
|
|
74
75
|
from model_compression_toolkit.core.pytorch.reader.reader import model_reader
|
|
75
76
|
from model_compression_toolkit.core.pytorch.statistics_correction.apply_second_moment_correction import \
|
|
76
77
|
pytorch_apply_second_moment_correction
|
|
77
|
-
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
|
|
78
|
-
from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy
|
|
79
|
-
from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
|
|
80
|
-
from model_compression_toolkit.gptq.pytorch.gptq_training import PytorchGPTQTrainer
|
|
78
|
+
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model
|
|
81
79
|
|
|
82
80
|
|
|
83
81
|
class PytorchImplementation(FrameworkImplementation):
|
|
@@ -127,7 +125,9 @@ class PytorchImplementation(FrameworkImplementation):
|
|
|
127
125
|
Returns:
|
|
128
126
|
Graph representing the input module.
|
|
129
127
|
"""
|
|
130
|
-
|
|
128
|
+
_module = deepcopy(module)
|
|
129
|
+
_module.eval()
|
|
130
|
+
return model_reader(_module, representative_data_gen, self.to_numpy, self.to_tensor)
|
|
131
131
|
|
|
132
132
|
def model_builder(self,
|
|
133
133
|
graph: Graph,
|
|
@@ -323,12 +323,6 @@ class PytorchImplementation(FrameworkImplementation):
|
|
|
323
323
|
substitutions_list.append(pytorch_batchnorm_refusing())
|
|
324
324
|
return substitutions_list
|
|
325
325
|
|
|
326
|
-
def get_gptq_trainer_obj(self) -> Type[GPTQTrainer]:
|
|
327
|
-
"""
|
|
328
|
-
Returns: GPTQTrainer object
|
|
329
|
-
"""
|
|
330
|
-
return PytorchGPTQTrainer
|
|
331
|
-
|
|
332
326
|
def get_sensitivity_evaluator(self,
|
|
333
327
|
graph: Graph,
|
|
334
328
|
quant_config: MixedPrecisionQuantizationConfigV2,
|
|
@@ -16,7 +16,7 @@ from typing import Any, Tuple
|
|
|
16
16
|
import numpy as np
|
|
17
17
|
from torch.nn import BatchNorm2d
|
|
18
18
|
|
|
19
|
-
from model_compression_toolkit import FrameworkInfo
|
|
19
|
+
from model_compression_toolkit.core import FrameworkInfo
|
|
20
20
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
|
21
21
|
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
|
|
22
22
|
from model_compression_toolkit.core.pytorch.constants import MOVING_MEAN, MOVING_VARIANCE, GAMMA, BETA
|
|
@@ -12,10 +12,10 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
from typing import
|
|
15
|
+
from typing import Callable
|
|
16
16
|
import torch
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.
|
|
18
|
+
from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
|
|
19
19
|
from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import threshold_is_power_of_two
|
|
20
20
|
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
|
|
21
21
|
|
|
@@ -3,7 +3,7 @@ from typing import Dict, Callable
|
|
|
3
3
|
import torch
|
|
4
4
|
import numpy as np
|
|
5
5
|
|
|
6
|
-
from model_compression_toolkit.
|
|
6
|
+
from model_compression_toolkit.constants import SIGNED, CLUSTER_CENTERS, THRESHOLD, MULTIPLIER_N_BITS, EPS
|
|
7
7
|
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
|
|
8
8
|
|
|
9
9
|
|
|
@@ -25,6 +25,7 @@ from model_compression_toolkit.core.common.graph.functional_node import Function
|
|
|
25
25
|
from model_compression_toolkit.core.pytorch.constants import OUTPUT, PLACEHOLDER, TENSOR_META, CALL_FUNCTION, TYPE, \
|
|
26
26
|
CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, GET_ATTR, CONSTANT, BUFFER
|
|
27
27
|
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder, ConstantHolder, BufferHolder
|
|
28
|
+
from model_compression_toolkit.logger import Logger
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
def extract_holder_weights(constant_name, node_target, model, weights, to_numpy):
|
|
@@ -64,6 +65,7 @@ def nodes_builder(model: GraphModule,
|
|
|
64
65
|
Args:
|
|
65
66
|
model: Pytorch FX model.
|
|
66
67
|
module_dict: A dictionary of the Pyotrch model's named modules.
|
|
68
|
+
to_numpy: A function to convert a Tensor to numpy array
|
|
67
69
|
|
|
68
70
|
Returns:
|
|
69
71
|
A list of Graph nodes that were built from the fx GraphModule nodes.
|
|
@@ -91,7 +93,7 @@ def nodes_builder(model: GraphModule,
|
|
|
91
93
|
node_type = node.target
|
|
92
94
|
if node_type == getattr:
|
|
93
95
|
node_has_activation = False
|
|
94
|
-
|
|
96
|
+
Logger.warning(
|
|
95
97
|
'Pytorch model has a parameter or constant Tensor value. This can cause unexpected behaviour when '
|
|
96
98
|
'converting the model.')
|
|
97
99
|
elif node.op == PLACEHOLDER:
|
|
@@ -112,7 +114,7 @@ def nodes_builder(model: GraphModule,
|
|
|
112
114
|
else:
|
|
113
115
|
node_type = ConstantHolder
|
|
114
116
|
node_has_activation = False
|
|
115
|
-
|
|
117
|
+
Logger.warning(
|
|
116
118
|
'Pytorch model has a parameter or constant Tensor value. This can cause unexpected behaviour when '
|
|
117
119
|
'converting the model.')
|
|
118
120
|
else:
|
model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py
CHANGED
|
@@ -18,7 +18,7 @@ from typing import Any, Callable
|
|
|
18
18
|
import torch
|
|
19
19
|
from tqdm import tqdm
|
|
20
20
|
|
|
21
|
-
from model_compression_toolkit import CoreConfig
|
|
21
|
+
from model_compression_toolkit.core import CoreConfig
|
|
22
22
|
from model_compression_toolkit.core import common
|
|
23
23
|
from model_compression_toolkit.core.pytorch.constants import GAMMA, BETA, MOVING_MEAN, MOVING_VARIANCE
|
|
24
24
|
from model_compression_toolkit.core.pytorch.utils import set_model, to_torch_tensor
|
|
@@ -22,7 +22,7 @@ from tqdm import tqdm
|
|
|
22
22
|
|
|
23
23
|
from model_compression_toolkit.core import common
|
|
24
24
|
from model_compression_toolkit.core.common import FrameworkInfo
|
|
25
|
-
from model_compression_toolkit.
|
|
25
|
+
from model_compression_toolkit.logger import Logger
|
|
26
26
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
27
27
|
from model_compression_toolkit.core.common.fusion.layer_fusing import fusion
|
|
28
28
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
@@ -48,7 +48,7 @@ from model_compression_toolkit.core.common.statistics_correction.statistics_corr
|
|
|
48
48
|
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
|
|
49
49
|
from model_compression_toolkit.core.common.substitutions.linear_collapsing_substitution import \
|
|
50
50
|
linear_collapsing_substitute
|
|
51
|
-
from model_compression_toolkit.
|
|
51
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
|
52
52
|
from model_compression_toolkit.core.common.visualization.final_config_visualizer import \
|
|
53
53
|
WeightsFinalBitwidthConfigVisualizer, \
|
|
54
54
|
ActivationFinalBitwidthConfigVisualizer
|
|
@@ -143,9 +143,9 @@ def core_runner(in_model: Any,
|
|
|
143
143
|
weights_conf_nodes_bitwidth = tg.get_final_weights_config()
|
|
144
144
|
activation_conf_nodes_bitwidth = tg.get_final_activation_config()
|
|
145
145
|
|
|
146
|
-
|
|
146
|
+
Logger.info(
|
|
147
147
|
f'Final weights bit-width configuration: {[node_b[1] for node_b in weights_conf_nodes_bitwidth]}')
|
|
148
|
-
|
|
148
|
+
Logger.info(
|
|
149
149
|
f'Final activation bit-width configuration: {[node_b[1] for node_b in activation_conf_nodes_bitwidth]}')
|
|
150
150
|
|
|
151
151
|
if tb_w is not None:
|
|
@@ -259,9 +259,9 @@ def _init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
|
|
|
259
259
|
A TensorBoardWriter object.
|
|
260
260
|
"""
|
|
261
261
|
tb_w = None
|
|
262
|
-
if
|
|
263
|
-
tb_log_dir = os.path.join(os.getcwd(),
|
|
264
|
-
|
|
262
|
+
if Logger.LOG_PATH is not None:
|
|
263
|
+
tb_log_dir = os.path.join(os.getcwd(), Logger.LOG_PATH, 'tensorboard_logs')
|
|
264
|
+
Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
|
|
265
265
|
tb_w = TensorboardWriter(tb_log_dir, fw_info)
|
|
266
266
|
return tb_w
|
|
267
267
|
|
|
@@ -12,3 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from model_compression_toolkit.exporter.model_exporter.keras.keras_export_facade import keras_export_model, KerasExportMode
|
|
17
|
+
from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import PyTorchExportMode, pytorch_export_model
|
|
18
|
+
from model_compression_toolkit.exporter.model_exporter.tflite.tflite_export_facade import tflite_export_model, TFLiteExportMode
|
|
19
|
+
|
|
@@ -13,6 +13,3 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from model_compression_toolkit.exporter.model_exporter.keras.keras_export_facade import keras_export_model, KerasExportMode
|
|
17
|
-
from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import PyTorchExportMode, pytorch_export_model
|
|
18
|
-
from model_compression_toolkit.exporter.model_exporter.tflite.tflite_export_facade import tflite_export_model, TFLiteExportMode
|
|
@@ -19,7 +19,7 @@ import keras.models
|
|
|
19
19
|
import tensorflow as tf
|
|
20
20
|
from keras.engine.base_layer import Layer
|
|
21
21
|
|
|
22
|
-
from model_compression_toolkit.
|
|
22
|
+
from model_compression_toolkit.logger import Logger
|
|
23
23
|
from model_compression_toolkit.exporter.model_exporter.keras.base_keras_exporter import \
|
|
24
24
|
BaseKerasExporter
|
|
25
25
|
from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
|
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
from enum import Enum
|
|
16
16
|
from typing import Callable, Dict
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
|
+
from model_compression_toolkit.constants import FOUND_TF
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class KerasExportMode(Enum):
|
model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py
CHANGED
|
@@ -16,17 +16,21 @@ from typing import Callable
|
|
|
16
16
|
|
|
17
17
|
import torch.nn
|
|
18
18
|
|
|
19
|
-
from model_compression_toolkit.
|
|
19
|
+
from model_compression_toolkit.logger import Logger
|
|
20
20
|
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
|
|
21
21
|
from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
|
|
22
22
|
from packaging import version
|
|
23
23
|
|
|
24
|
+
from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
|
|
25
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import LAYER
|
|
26
|
+
|
|
24
27
|
# ONNX opset version 16 is supported from PyTorch 1.12
|
|
25
28
|
if version.parse(torch.__version__) < version.parse("1.12"):
|
|
26
29
|
OPSET_VERSION = 15
|
|
27
30
|
else:
|
|
28
31
|
OPSET_VERSION = 16
|
|
29
32
|
|
|
33
|
+
|
|
30
34
|
class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
|
|
31
35
|
"""
|
|
32
36
|
Exporter for fakely-quant PyTorch models.
|
|
@@ -70,6 +74,16 @@ class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
|
|
|
70
74
|
|
|
71
75
|
Logger.info(f"Exporting PyTorch fake quant onnx model: {self.save_model_path}")
|
|
72
76
|
|
|
77
|
+
# Replace float weight with wrapped quantized weights
|
|
78
|
+
for layer in self.model.modules():
|
|
79
|
+
if isinstance(layer, PytorchQuantizationWrapper):
|
|
80
|
+
for name in layer.weights_quantizers.keys():
|
|
81
|
+
quantized_weight = torch.nn.Parameter(layer.get_quantized_weights()[name]).detach()
|
|
82
|
+
linear_layer = getattr(layer, LAYER)
|
|
83
|
+
delattr(linear_layer, name)
|
|
84
|
+
setattr(linear_layer, name, torch.nn.Parameter(quantized_weight))
|
|
85
|
+
layer.weights_quantizers = {}
|
|
86
|
+
|
|
73
87
|
torch.onnx.export(self.model,
|
|
74
88
|
model_input,
|
|
75
89
|
self.save_model_path,
|
|
@@ -16,7 +16,7 @@ from typing import Callable
|
|
|
16
16
|
|
|
17
17
|
import torch.nn
|
|
18
18
|
|
|
19
|
-
from model_compression_toolkit.
|
|
19
|
+
from model_compression_toolkit.logger import Logger
|
|
20
20
|
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
|
|
21
21
|
from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
|
|
22
22
|
|
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
from enum import Enum
|
|
16
16
|
from typing import Callable
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
|
+
from model_compression_toolkit.constants import FOUND_TORCH
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class PyTorchExportMode(Enum):
|
|
@@ -19,8 +19,8 @@ from typing import Callable
|
|
|
19
19
|
import keras.models
|
|
20
20
|
import tensorflow as tf
|
|
21
21
|
|
|
22
|
-
from model_compression_toolkit import keras_load_quantized_model
|
|
23
|
-
from model_compression_toolkit.
|
|
22
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
|
|
23
|
+
from model_compression_toolkit.logger import Logger
|
|
24
24
|
from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
|
|
25
25
|
|
|
26
26
|
|
|
@@ -23,7 +23,7 @@ from keras.layers import Dense, Conv2D, Reshape
|
|
|
23
23
|
from keras.models import clone_model
|
|
24
24
|
|
|
25
25
|
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
26
|
-
from model_compression_toolkit.
|
|
26
|
+
from model_compression_toolkit.logger import Logger
|
|
27
27
|
from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
|
|
28
28
|
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
|
|
29
29
|
constants as keras_inferable_constants
|
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
from enum import Enum
|
|
16
16
|
from typing import Callable
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
|
+
from model_compression_toolkit.constants import FOUND_TF
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class TFLiteExportMode(Enum):
|
|
@@ -13,12 +13,8 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from model_compression_toolkit.
|
|
16
|
+
from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
|
|
17
|
+
from model_compression_toolkit.exporter.model_wrapper.keras.builder.fully_quantized_model_builder import get_exportable_keras_model
|
|
17
18
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
from model_compression_toolkit.exporter.model_wrapper.keras.builder.fully_quantized_model_builder import get_exportable_keras_model
|
|
21
|
-
|
|
22
|
-
if FOUND_TORCH:
|
|
23
|
-
from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
|
|
24
|
-
from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
|
|
19
|
+
from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
|
|
20
|
+
from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
CHANGED
|
@@ -14,46 +14,53 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Tuple
|
|
16
16
|
|
|
17
|
-
import tensorflow as tf
|
|
18
|
-
from tensorflow.keras.layers import Layer
|
|
19
17
|
|
|
20
18
|
from model_compression_toolkit import quantizers_infrastructure as qi
|
|
21
19
|
from model_compression_toolkit.core import common
|
|
22
20
|
from model_compression_toolkit.core.common import Graph
|
|
21
|
+
from model_compression_toolkit.constants import FOUND_TF
|
|
23
22
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
|
24
|
-
from model_compression_toolkit.
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
23
|
+
from model_compression_toolkit.logger import Logger
|
|
24
|
+
|
|
25
|
+
if FOUND_TF:
|
|
26
|
+
import tensorflow as tf
|
|
27
|
+
from tensorflow.keras.layers import Layer
|
|
28
|
+
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
|
|
29
|
+
from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizers import get_quantization_quantizers
|
|
30
|
+
|
|
31
|
+
def _get_wrapper(node: common.BaseNode,
|
|
32
|
+
layer: Layer) -> qi.KerasQuantizationWrapper:
|
|
33
|
+
"""
|
|
34
|
+
A function which takes a computational graph node and a keras layer and perform the quantization wrapping
|
|
35
|
+
Args:
|
|
36
|
+
n: A node of mct graph.
|
|
37
|
+
layer: A keras layer
|
|
38
|
+
|
|
39
|
+
Returns: Wrapped layer with weights quantizers and activation quantizers
|
|
40
|
+
|
|
41
|
+
"""
|
|
42
|
+
weights_quantizers, activation_quantizers = get_quantization_quantizers(node)
|
|
43
|
+
return qi.KerasQuantizationWrapper(layer, weights_quantizers, activation_quantizers)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model, UserInformation]:
|
|
47
|
+
"""
|
|
48
|
+
Convert graph to an exportable Keras model (model with all quantization parameters).
|
|
49
|
+
An exportable model can then be exported using model_exporter, to retrieve the
|
|
50
|
+
final exported model.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
graph: Graph to convert to an exportable Keras model.
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
Exportable Keras model and user information.
|
|
57
|
+
"""
|
|
58
|
+
exportable_model, user_info = KerasModelBuilder(graph=graph,
|
|
59
|
+
wrapper=_get_wrapper).build_model()
|
|
60
|
+
exportable_model.trainable = False
|
|
61
|
+
return exportable_model, user_info
|
|
62
|
+
else:
|
|
63
|
+
def get_exportable_keras_model(*args, **kwargs): # pragma: no cover
|
|
64
|
+
Logger.error('Installing tensorflow and tensorflow_model_optimization is mandatory '
|
|
65
|
+
'when using get_exportable_keras_model. '
|
|
66
|
+
'Could not find Tensorflow package.')
|
|
@@ -14,16 +14,15 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Dict, Any
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit.core.common import BaseNode
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.core.common.target_platform import QuantizationMethod
|
|
20
|
-
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
|
|
21
|
-
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
|
|
22
|
-
get_inferable_quantizer_class
|
|
23
|
-
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer \
|
|
24
|
-
import \
|
|
25
|
-
BaseKerasInferableQuantizer
|
|
17
|
+
from model_compression_toolkit.core.common import BaseNode
|
|
18
|
+
from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
|
|
26
19
|
|
|
20
|
+
from model_compression_toolkit.logger import Logger
|
|
21
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
22
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
|
|
23
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import get_inferable_quantizer_class
|
|
24
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer
|
|
25
|
+
from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import constants as qi_keras_consts
|
|
27
26
|
|
|
28
27
|
def get_inferable_quantizer_kwargs(node: BaseNode,
|
|
29
28
|
quantization_target: QuantizationTarget) -> Dict[str, Any]:
|
|
@@ -44,19 +43,29 @@ def get_inferable_quantizer_kwargs(node: BaseNode,
|
|
|
44
43
|
# Return the appropriate quantization parameters based on the quantization method
|
|
45
44
|
if quantization_method in [QuantizationMethod.POWER_OF_TWO,
|
|
46
45
|
QuantizationMethod.SYMMETRIC]:
|
|
47
|
-
return {
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
46
|
+
return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
|
|
47
|
+
qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[THRESHOLD].flatten()),
|
|
48
|
+
qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
|
|
49
|
+
qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
|
|
50
|
+
qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[THRESHOLD].shape)}
|
|
52
51
|
|
|
53
52
|
elif quantization_method in [QuantizationMethod.UNIFORM]:
|
|
54
|
-
return {
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
53
|
+
return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
|
|
54
|
+
qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
|
|
55
|
+
qi_keras_consts.MIN_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MIN].flatten()),
|
|
56
|
+
qi_keras_consts.MAX_RANGE: list(node_w_qc.weights_quantization_params[RANGE_MAX].flatten()),
|
|
57
|
+
qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
|
|
58
|
+
qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[RANGE_MIN].shape)}
|
|
59
|
+
|
|
60
|
+
elif quantization_method in [QuantizationMethod.LUT_SYM_QUANTIZER, QuantizationMethod.LUT_POT_QUANTIZER]:
|
|
61
|
+
return {qi_keras_consts.NUM_BITS: node_w_qc.weights_n_bits,
|
|
62
|
+
qi_keras_consts.PER_CHANNEL: node_w_qc.weights_per_channel_threshold,
|
|
63
|
+
qi_keras_consts.CLUSTER_CENTERS: node_w_qc.weights_quantization_params[CLUSTER_CENTERS],
|
|
64
|
+
qi_keras_consts.THRESHOLD: list(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten()),
|
|
65
|
+
qi_keras_consts.CHANNEL_AXIS: node_w_qc.weights_channels_axis,
|
|
66
|
+
# TODO: how to pass multiplier nbits and eps for a specific node?
|
|
67
|
+
qi_keras_consts.INPUT_RANK: len(node_w_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)}
|
|
68
|
+
|
|
60
69
|
else:
|
|
61
70
|
Logger.critical(f'Not supported quantization method for inferable quantizers.') # pragma: no cover
|
|
62
71
|
|
|
@@ -68,16 +77,24 @@ def get_inferable_quantizer_kwargs(node: BaseNode,
|
|
|
68
77
|
# Return the appropriate quantization parameters based on the quantization method
|
|
69
78
|
if quantization_method in [QuantizationMethod.POWER_OF_TWO,
|
|
70
79
|
QuantizationMethod.SYMMETRIC]:
|
|
71
|
-
return {
|
|
80
|
+
return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
|
|
72
81
|
# In activation quantization is per-tensor only - thus we hold the threshold as a list with a len of 1
|
|
73
|
-
|
|
74
|
-
|
|
82
|
+
qi_keras_consts.THRESHOLD: [node_qc.activation_quantization_params[THRESHOLD]],
|
|
83
|
+
qi_keras_consts.SIGNED: node_qc.activation_quantization_params[SIGNED]}
|
|
75
84
|
|
|
76
85
|
elif quantization_method in [QuantizationMethod.UNIFORM]:
|
|
77
|
-
return {
|
|
86
|
+
return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
|
|
78
87
|
# In activation quantization is per-tensor only - thus we hold the min/max as a list with a len of 1
|
|
79
|
-
|
|
80
|
-
|
|
88
|
+
qi_keras_consts.MIN_RANGE: [node_qc.activation_quantization_params[RANGE_MIN]],
|
|
89
|
+
qi_keras_consts.MAX_RANGE: [node_qc.activation_quantization_params[RANGE_MAX]]}
|
|
90
|
+
|
|
91
|
+
elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER]:
|
|
92
|
+
return {qi_keras_consts.NUM_BITS: node_qc.activation_n_bits,
|
|
93
|
+
qi_keras_consts.SIGNED: node_qc.activation_quantization_params[SIGNED],
|
|
94
|
+
qi_keras_consts.CLUSTER_CENTERS: node_qc.activation_quantization_params[CLUSTER_CENTERS],
|
|
95
|
+
qi_keras_consts.THRESHOLD: [node_qc.activation_quantization_params[THRESHOLD]]
|
|
96
|
+
# TODO: how to pass multiplier nbits and eps for a specific node?
|
|
97
|
+
}
|
|
81
98
|
else:
|
|
82
99
|
Logger.critical(f'Not supported quantization method for inferable quantizers.') # pragma: no cover
|
|
83
100
|
else:
|