mct-nightly 1.8.0.8042023.post345-py3-none-any.whl → 1.8.0.8052023.post414-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +4 -3
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +285 -277
- model_compression_toolkit/__init__.py +9 -32
- model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
- model_compression_toolkit/core/__init__.py +14 -0
- model_compression_toolkit/core/analyzer.py +3 -2
- model_compression_toolkit/core/common/__init__.py +0 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +1 -8
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
- model_compression_toolkit/core/common/graph/base_graph.py +2 -2
- model_compression_toolkit/core/common/graph/base_node.py +57 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/memory_computation.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/model_validation.py +1 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -1
- model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
- model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
- model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +2 -2
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/keras/constants.py +0 -7
- model_compression_toolkit/core/keras/default_framework_info.py +3 -3
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -10
- model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
- model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
- model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +1 -1
- model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +2 -2
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +0 -6
- model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
- model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
- model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/runner.py +7 -7
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +3 -2
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +5 -3
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +4 -3
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
- model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +2 -1
- model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +5 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +27 -20
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -5
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +2 -2
- model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
- model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -13
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +8 -3
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -2
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +9 -11
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +2 -2
- model_compression_toolkit/gptq/runner.py +3 -2
- model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +11 -12
- model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
- model_compression_toolkit/ptq/__init__.py +3 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/qat/__init__.py +4 -0
- model_compression_toolkit/qat/common/__init__.py +1 -2
- model_compression_toolkit/qat/common/qat_config.py +3 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +13 -11
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +9 -9
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +4 -3
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +7 -5
- model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -6
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
- model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
- model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
- model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +0 -0
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
- /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
- /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
@@ -19,7 +19,7 @@ import numpy as np

 from model_compression_toolkit.core.common import max_power_of_two
 from model_compression_toolkit import quantizers_infrastructure as qi
-from model_compression_toolkit.
+from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.gptq.common.gptq_config import RoundingType
 from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
     BasePytorchGPTQTrainableQuantizer

@@ -27,7 +27,7 @@ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_
 from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
 from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
     SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
-from model_compression_toolkit.
+from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
 from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
 from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
 from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
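The two truncated removed lines above are instances of this release's recurring import migration: per the rename list, constants.py moved from model_compression_toolkit/core/common/ to the package root, and the target_platform package moved from core/common/ into target_platform_capabilities/. A minimal before/after sketch of the pattern (the old paths are inferred from the rename entries, since the removed lines are truncated in this view):

# Old locations (pre-1.8.0.8052023), inferred from the rename entries above:
#   from model_compression_toolkit.core.common.constants import THRESHOLD, MIN_THRESHOLD
#   from model_compression_toolkit.core.common.target_platform import QuantizationMethod

# New locations, as they appear in the added lines:
from model_compression_toolkit.constants import THRESHOLD, MIN_THRESHOLD
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod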
@@ -18,7 +18,8 @@ from typing import Dict
 import numpy as np

 from model_compression_toolkit import quantizers_infrastructure as qi
-from model_compression_toolkit.
+from model_compression_toolkit.quantizers_infrastructure.constants import FQ_MIN, FQ_MAX
+from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.gptq.common.gptq_config import RoundingType
 from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
     BasePytorchGPTQTrainableQuantizer

@@ -31,8 +32,8 @@ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructur
     mark_quantizer
 from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import \
     VariableGroup
-from model_compression_toolkit.
-
+from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
+

 def soft_rounding_unifrom_quantizer(input_tensor: torch.Tensor,
                                     auxvar_tensor: torch.Tensor,

@@ -54,13 +55,12 @@ def soft_rounding_unifrom_quantizer(input_tensor: torch.Tensor,
     """
     # adjusts the quantization range so the quantization grid includes zero.
     min_range, max_range = fix_range_to_include_zero(min_range, max_range, num_bits)
-    delta = qutils.calculate_delta_uniform(
-
-    input_tensor_int = torch.floor(input_tensor / delta)
+    delta = qutils.calculate_delta_uniform(min_range, max_range, num_bits)
+    input_tensor_int = qutils.ste_floor((input_tensor - min_range) / delta)
     tensor_q = input_tensor_int + auxvar_tensor
     return delta * qutils.ste_clip(tensor_q,
                                    min_val=0,
-                                   max_val=2 ** num_bits - 1)
+                                   max_val=2 ** num_bits - 1) + min_range


 @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,

@@ -122,8 +122,8 @@ class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
         layer.register_parameter(name+"_"+FQ_MAX, nn.Parameter(max_values, requires_grad=self.quantization_parameter_learning))

         w = layer.layer.weight
-        delta = qutils.calculate_delta_uniform(
-        w_clipped_normed = torch.clip(w / delta, 0, 2 ** self.num_bits - 1)
+        delta = qutils.calculate_delta_uniform(min_values, max_values, self.num_bits)
+        w_clipped_normed = torch.clip((w - min_values) / delta, 0, 2 ** self.num_bits - 1)
         rest = w_clipped_normed - torch.floor(w_clipped_normed)  # rest of rounding [0, 1)
         alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1)  # => sigmoid(alpha) = rest
         layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(alpha, requires_grad=True))

@@ -133,7 +133,6 @@ class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
         self.add_quantizer_variable(FQ_MAX, layer.get_parameter(name+"_"+FQ_MAX), VariableGroup.QPARAMS)
         self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)

-
     def get_soft_targets(self) -> torch.Tensor:
         """
         Computes the rectified sigmoid function for the quantization target parameters.

@@ -192,5 +191,4 @@ class UniformSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
                                          max_range=max_range,
                                          num_bits=self.num_bits)

-
         return q_tensor
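The hunks above fix soft_rounding_unifrom_quantizer to honor a two-sided uniform range: delta is now computed from (min_range, max_range, num_bits), the tensor is shifted by min_range before flooring onto the grid (with a straight-through floor), and min_range is added back after clipping; the auxiliary-variable initialization is updated to match. A self-contained sanity check of the corrected math follows, using plain-torch stand-ins for the toolkit helpers (qutils.calculate_delta_uniform, ste_floor, ste_clip, which additionally provide straight-through gradients):

import torch

def uniform_soft_round(x, auxvar, min_range, max_range, num_bits):
    # Step size of a uniform grid spanning [min_range, max_range].
    delta = (max_range - min_range) / (2 ** num_bits - 1)
    # Shift by min_range before flooring onto the grid (ste_floor in MCT).
    x_int = torch.floor((x - min_range) / delta)
    # Soft rounding: the learned auxiliary variable nudges floor toward round.
    x_q = torch.clamp(x_int + auxvar, 0, 2 ** num_bits - 1)  # ste_clip in MCT
    # Rescale and shift back into the original range.
    return delta * x_q + min_range

x = torch.tensor([-0.4, 0.0, 0.7])
aux = torch.zeros_like(x)  # auxvar == 0 reduces to plain floor-to-grid
print(uniform_soft_round(x, aux, min_range=-1.0, max_range=1.0, num_bits=8))
# Before the fix, min_range was ignored, so negative values collapsed against
# the bottom of a [0, max_range] grid instead of quantizing over the full range.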
@@ -19,14 +19,14 @@ import numpy as np
 from model_compression_toolkit.core.common.defaultdict import DefaultDict

 from model_compression_toolkit import quantizers_infrastructure as qi
-from model_compression_toolkit.
+from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.gptq.common.gptq_config import RoundingType
 from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
     BasePytorchGPTQTrainableQuantizer
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
 from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
 from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD, MAX_LSB_CHANGE
-from model_compression_toolkit.
+from model_compression_toolkit.constants import THRESHOLD
 from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
 from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import \
     mark_quantizer

@@ -15,7 +15,7 @@

 from typing import Callable

-from model_compression_toolkit import CoreConfig
+from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.statistics_correction.statistics_correction import \
     apply_statistics_correction

@@ -28,6 +28,7 @@ from model_compression_toolkit.gptq.common.gptq_training import gptq_training
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
 from model_compression_toolkit.core.common.statistics_correction.apply_bias_correction_to_graph import \
     apply_bias_correction_to_graph
+from model_compression_toolkit.logger import Logger


 def _apply_gptq(gptq_config: GradientPTQConfigV2,

@@ -55,7 +56,7 @@ def _apply_gptq(gptq_config: GradientPTQConfigV2,

     """
     if gptq_config is not None and gptq_config.n_epochs > 0:
-
+        Logger.info("Using experimental Gradient Based PTQ: If you encounter an issue "
                     "please file a bug. To disable it, do not pass a gptq configuration.")

         tg_bias = gptq_training(tg,
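Per the rename list at the top (core/common/logger.py → logger.py), the toolkit's Logger now lives at the package root, which is why the added lines in this hunk import it as below before emitting the experimental-feature notice:

from model_compression_toolkit.logger import Logger  # new top-level location

Logger.info("Using experimental Gradient Based PTQ: If you encounter an issue "
            "please file a bug. To disable it, do not pass a gptq configuration.")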
model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py}
RENAMED
@@ -15,9 +15,8 @@

 from typing import Callable, List, Tuple

-from model_compression_toolkit.
-from model_compression_toolkit.
-from model_compression_toolkit.core.common.constants import TENSORFLOW
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import TENSORFLOW
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.gptq import GradientPTQConfig, GradientPTQConfigV2
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI

@@ -35,15 +34,15 @@ from model_compression_toolkit.ptq.runner import ptq_runner
 from model_compression_toolkit.core.exporter import export_model
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization

-from model_compression_toolkit.
-from model_compression_toolkit.
+from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
+from model_compression_toolkit.constants import FOUND_TF

 if FOUND_TF:
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
     from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from tensorflow.keras.models import Model
-    from model_compression_toolkit.
+    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL

     from model_compression_toolkit import get_target_platform_capabilities

@@ -81,7 +80,7 @@ if FOUND_TF:
         network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
         gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
         analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
-        target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
+        target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.

     Returns:
         A quantized model and information the user may need to handle the quantized model.

@@ -184,7 +183,7 @@ if FOUND_TF:
         network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
         gptq_config (GradientPTQConfig): Configuration for using GPTQ (e.g. optimizer).
         analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
-        target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
+        target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.


     Returns:

@@ -209,13 +208,13 @@ if FOUND_TF:
     Create a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
     The candidates bitwidth for quantization should be defined in the target platform model:

-    >>> config = mct.MixedPrecisionQuantizationConfig()
+    >>> config = mct.core.MixedPrecisionQuantizationConfig()

     Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
     that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
     while the bias will not):

-    >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
+    >>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.

     Pass the model, the representative dataset generator, the configuration and the target KPI to get a
     quantized model:
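These docstring updates (mirrored for the PyTorch facade below) track the public-API reshuffle: configuration and KPI classes are now addressed through the mct.core namespace instead of the top-level mct module. Collected in one place, assuming an already-built Keras model named model:

import model_compression_toolkit as mct

# Previously: mct.MixedPrecisionQuantizationConfig(), mct.KPI(...)
config = mct.core.MixedPrecisionQuantizationConfig()
# Target roughly 0.75x the size of an 8-bit quantized model (coefficients only;
# e.g. Conv2D kernels count toward the KPI, biases do not).
kpi = mct.core.KPI(model.count_params() * 0.75)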
@@ -229,11 +228,11 @@ if FOUND_TF:
                                     fw_info=fw_info).validate()

         if not isinstance(quant_config, MixedPrecisionQuantizationConfig):
-
+            Logger.error("Given quantization config to mixed-precision facade is not of type "
                          "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
                          "or pass a valid mixed precision configuration.")

-
+        Logger.info("Using experimental mixed-precision quantization. "
                     "If you encounter an issue please file a bug.")

         quantization_config, mp_config = quant_config.separate_configs()

@@ -14,12 +14,11 @@
 # ==============================================================================
 from typing import Callable, List, Tuple

-from model_compression_toolkit.
-from model_compression_toolkit.
-from model_compression_toolkit.core.common.constants import PYTORCH
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import PYTORCH
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GradientPTQConfigV2
-from model_compression_toolkit.
+from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.network_editors.actions import EditRule

@@ -34,12 +33,12 @@ from model_compression_toolkit.gptq.runner import gptq_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 from model_compression_toolkit.core.exporter import export_model
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization
-from model_compression_toolkit.
+from model_compression_toolkit.constants import FOUND_TORCH

 if FOUND_TORCH:
     from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
-    from model_compression_toolkit.
+    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from torch.nn import Module

     from model_compression_toolkit import get_target_platform_capabilities

@@ -76,7 +75,7 @@ if FOUND_TORCH:
         network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
         gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
         analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
-        target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
+        target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.


     Returns:

@@ -175,7 +174,7 @@ if FOUND_TORCH:
         network_editor (List[EditRule]): List of EditRules. Each EditRule consists of a node filter and an action to change quantization settings of the filtered nodes.
         gptq_config (GradientPTQConfig): Configuration for using GPTQ (e.g. optimizer).
         analyze_similarity (bool): Whether to plot similarity figures within TensorBoard (when logger is enabled) or not.
-        target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
+        target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.

     Returns:
         A quantized model and information the user may need to handle the quantized model.

@@ -199,13 +198,13 @@ if FOUND_TORCH:
     Create a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
     The candidates bitwidth for quantization should be defined in the target platform model:

-    >>> config = mct.MixedPrecisionQuantizationConfig()
+    >>> config = mct.core.MixedPrecisionQuantizationConfig()

     Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
     that should be quantized (for example, the kernel of Conv2D in PyTorch will be affected by this value,
     while the bias will not):

-    >>> kpi = mct.KPI(sum(p.numel() for p in module.parameters()) * 0.75) # About 0.75 of the model size when quantized with 8 bits.
+    >>> kpi = mct.core.KPI(sum(p.numel() for p in module.parameters()) * 0.75) # About 0.75 of the model size when quantized with 8 bits.

     Pass the model, the representative dataset generator, the configuration and the target KPI to get a
     quantized model:

@@ -217,11 +216,11 @@ if FOUND_TORCH:
     """

     if not isinstance(quant_config, MixedPrecisionQuantizationConfig):
-
+        Logger.error("Given quantization config to mixed-precision facade is not of type "
                      "MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API, "
                      "or pass a valid mixed precision configuration.")

-
+    Logger.info("Using experimental mixed-precision quantization. "
                 "If you encounter an issue please file a bug.")

     quantization_config, mp_config = quant_config.separate_configs()
@@ -12,3 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
+from model_compression_toolkit.ptq.pytorch.quantization_facade import pytorch_post_training_quantization_experimental
+from model_compression_toolkit.ptq.keras.quantization_facade import keras_post_training_quantization_experimental
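With these re-exports, the experimental post-training quantization entry points hang off mct.ptq, matching the docstring updates elsewhere in this diff (mct.keras_post_training_quantization_experimental → mct.ptq.keras_post_training_quantization_experimental). A usage sketch under those docstrings' conventions; the input shape in the generator is only an example, and model is assumed to be an existing Keras model:

import numpy as np
import model_compression_toolkit as mct

def repr_datagen():
    # Representative-data generator; the shape is a placeholder for your model's input.
    yield [np.random.random((1, 224, 224, 3))]

# PyTorch modules use mct.ptq.pytorch_post_training_quantization_experimental analogously.
quantized_model, quantization_info = \
    mct.ptq.keras_post_training_quantization_experimental(model, repr_datagen)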
@@ -15,15 +15,14 @@

 from typing import Callable

-from model_compression_toolkit import CoreConfig
-from model_compression_toolkit.core import common
+from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization
-from model_compression_toolkit.
-from model_compression_toolkit.
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfigV2
-from model_compression_toolkit.
+from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.exporter import export_model
 from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
 from model_compression_toolkit.ptq.runner import ptq_runner

@@ -33,7 +32,7 @@ if FOUND_TF:
     from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
     from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from tensorflow.keras.models import Model
-    from model_compression_toolkit.
+    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model

     from model_compression_toolkit import get_target_platform_capabilities

@@ -93,25 +92,25 @@ if FOUND_TF:

     Create a MCT core config, containing the quantization configuration:

-    >>> config = mct.CoreConfig()
+    >>> config = mct.core.CoreConfig()

     If mixed precision is desired, create a MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
     The candidates bitwidth for quantization should be defined in the target platform model.
     In this example we use 1 image to search mixed-precision configuration:

-    >>> config = mct.CoreConfig(mixed_precision_config=mct.MixedPrecisionQuantizationConfigV2(num_of_images=1))
+    >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1))

     For mixed-precision set a target KPI object:
     Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
     that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
     while the bias will not):

-    >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
+    >>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.

     Pass the model, the representative dataset generator, the configuration and the target KPI to get a
     quantized model:

-    >>> quantized_model, quantization_info = mct.keras_post_training_quantization_experimental(model, repr_datagen, kpi, core_config=config)
+    >>> quantized_model, quantization_info = mct.ptq.keras_post_training_quantization_experimental(model, repr_datagen, kpi, core_config=config)

     For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.

@@ -124,11 +123,11 @@ if FOUND_TF:

     if core_config.mixed_precision_enable:
         if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
-
+            Logger.error("Given quantization config to mixed-precision facade is not of type "
                          "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
                          "API, or pass a valid mixed precision configuration.")  # pragma: no cover

-
+        Logger.info("Using experimental mixed-precision quantization. "
                     "If you encounter an issue please file a bug.")

     tb_w = _init_tensorboard_writer(fw_info)

@@ -15,11 +15,11 @@
 from typing import Callable

 from model_compression_toolkit.core import common
-from model_compression_toolkit.
-from model_compression_toolkit.
-from model_compression_toolkit.
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import PYTORCH, FOUND_TORCH
+from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
-from model_compression_toolkit import CoreConfig
+from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfigV2
 from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer

@@ -31,7 +31,7 @@ from model_compression_toolkit.core.analyzer import analyzer_model_quantization
 if FOUND_TORCH:
     from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
-    from model_compression_toolkit.
+    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from torch.nn import Module
     from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
     from model_compression_toolkit import get_target_platform_capabilities

@@ -88,18 +88,18 @@ if FOUND_TORCH:
     Set number of clibration iterations to 1:

     >>> import model_compression_toolkit as mct
-    >>> quantized_module, quantization_info = mct.pytorch_post_training_quantization_experimental(module, repr_datagen)
+    >>> quantized_module, quantization_info = mct.ptq.pytorch_post_training_quantization_experimental(module, repr_datagen)

     """

     if core_config.mixed_precision_enable:
         if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
-
+            Logger.error("Given quantization config to mixed-precision facade is not of type "
                          "MixedPrecisionQuantizationConfigV2. Please use "
                          "pytorch_post_training_quantization API, or pass a valid mixed precision "
                          "configuration.")  # pragma: no cover

-
+        Logger.info("Using experimental mixed-precision quantization. "
                     "If you encounter an issue please file a bug.")

     tb_w = _init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
@@ -12,3 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from model_compression_toolkit.qat.common.qat_config import QATConfig, TrainingMethod
+
+from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init, keras_quantization_aware_training_finalize
+from model_compression_toolkit.qat.pytorch.quantization_facade import pytorch_quantization_aware_training_init, pytorch_quantization_aware_training_finalize
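With the exports added above, the QAT entry points and configuration objects all live under the qat package. A short sketch of building a QATConfig from the exported names (the keyword argument names are assumptions based on the exported TrainingMethod enum):

    from model_compression_toolkit.qat import QATConfig, TrainingMethod

    # Straight-through-estimator training for both weights and activations;
    # the parameter names here are assumptions, not confirmed by this diff.
    qat_config = QATConfig(weight_training_method=TrainingMethod.STE,
                           activation_training_method=TrainingMethod.STE)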
@@ -12,5 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-from model_compression_toolkit.qat.common.constants import THRESHOLD_TENSOR, WEIGHTS_QUANTIZATION_PARAMS
+from model_compression_toolkit.quantizers_infrastructure.constants import THRESHOLD_TENSOR, WEIGHTS_QUANTIZATION_PARAMS
@@ -17,6 +17,8 @@ from typing import Dict
 from enum import Enum
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.logger import Logger
+

 def _is_qat_applicable(node: common.BaseNode,
                        fw_info: FrameworkInfo) -> bool:
@@ -31,7 +33,7 @@ def _is_qat_applicable(node: common.BaseNode,
     """

     if node.is_weights_quantization_enabled() and not fw_info.is_kernel_op(node.type):
-
+        Logger.error("QAT Error: Quantizing a node without a kernel isn't supported")
     return node.is_weights_quantization_enabled() or node.is_activation_quantization_enabled()


@@ -16,15 +16,13 @@
 from typing import Callable
 from functools import partial

-from model_compression_toolkit import CoreConfig
-from model_compression_toolkit.
-from model_compression_toolkit.
-from model_compression_toolkit.core.common.constants import TENSORFLOW, FOUND_TF
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core import CoreConfig
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfigV2
-from model_compression_toolkit.
+from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
 from model_compression_toolkit.ptq.runner import ptq_runner

@@ -36,7 +34,7 @@ if FOUND_TF:
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
     from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
-    from model_compression_toolkit.
+    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL

     from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder

@@ -46,10 +44,10 @@ if FOUND_TF:
     from model_compression_toolkit import get_target_platform_capabilities
     from model_compression_toolkit.core import common
     from model_compression_toolkit.core.common import BaseNode
-    from model_compression_toolkit.
+    from model_compression_toolkit.constants import TENSORFLOW
     from model_compression_toolkit.core.common.framework_info import FrameworkInfo
     from model_compression_toolkit.qat.common.qat_config import _is_qat_applicable
-    from model_compression_toolkit.
+    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder
    from model_compression_toolkit.qat.common.qat_config import QATConfig
@@ -134,24 +132,24 @@ if FOUND_TF:

     Create an MCT core config, containing the quantization configuration:

-    >>> config = mct.CoreConfig()
+    >>> config = mct.core.CoreConfig()

     If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
     The candidate bitwidths for quantization should be defined in the target platform model:

-    >>> config = mct.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())
+    >>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())

     For mixed-precision set a target KPI object:
     Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
     that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
     while the bias will not):

-    >>> kpi = mct.KPI(model.count_params() * 0.75)  # About 0.75 of the model size when quantized with 8 bits.
+    >>> kpi = mct.core.KPI(model.count_params() * 0.75)  # About 0.75 of the model size when quantized with 8 bits.

     Pass the model, the representative dataset generator, the configuration and the target KPI to get a
     quantized model:

-    >>> quantized_model, quantization_info, custom_objects = mct.keras_quantization_aware_training_init(model, repr_datagen, kpi, core_config=config)
+    >>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init(model, repr_datagen, kpi, core_config=config)

     Use the quantized model for fine-tuning. For loading the model from file, use the custom_objects dictionary:

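Collected into one runnable sketch, the docstring walkthrough above looks roughly like this (MobileNet and the random representative generator are placeholders):

    import numpy as np
    import model_compression_toolkit as mct
    from tensorflow.keras.applications.mobilenet import MobileNet  # placeholder model

    def repr_datagen():
        # Random data in place of a real calibration set; shape is an assumption.
        return [np.random.random((1, 224, 224, 3))]

    model = MobileNet()
    config = mct.core.CoreConfig()
    kpi = mct.core.KPI(model.count_params() * 0.75)  # weights-memory target, per the docstring
    quantized_model, quantization_info, custom_objects = \
        mct.qat.keras_quantization_aware_training_init(model, repr_datagen, kpi, core_config=config)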
@@ -165,11 +163,11 @@ if FOUND_TF:

     if core_config.mixed_precision_enable:
         if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
-
+            Logger.error("Given quantization config to mixed-precision facade is not of type "
                          "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization API,"
                          "or pass a valid mixed precision configuration.")

-
+        Logger.info("Using experimental mixed-precision quantization. "
                     "If you encounter an issue please file a bug.")

     tb_w = _init_tensorboard_writer(fw_info)
@@ -223,29 +221,29 @@ if FOUND_TF:

     Create an MCT core config, containing the quantization configuration:

-    >>> config = mct.CoreConfig()
+    >>> config = mct.core.CoreConfig()

     If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
     The candidate bitwidths for quantization should be defined in the target platform model:

-    >>> config = mct.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())
+    >>> config = mct.core.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())

     For mixed-precision set a target KPI object:
     Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
     that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
     while the bias will not):

-    >>> kpi = mct.KPI(model.count_params() * 0.75)  # About 0.75 of the model size when quantized with 8 bits.
+    >>> kpi = mct.core.KPI(model.count_params() * 0.75)  # About 0.75 of the model size when quantized with 8 bits.

     Pass the model, the representative dataset generator, the configuration and the target KPI to get a
     quantized model:

-    >>> quantized_model, quantization_info, custom_objects = mct.keras_quantization_aware_training_init(model, repr_datagen, kpi, core_config=config)
+    >>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init(model, repr_datagen, kpi, core_config=config)

     Use the quantized model for fine-tuning. For loading the model from file, use the custom_objects dictionary:

     >>> quantized_model = tf.keras.models.load_model(model_file, custom_objects=custom_objects)
-    >>> quantized_model = mct.keras_quantization_aware_training_finalize(quantized_model)
+    >>> quantized_model = mct.qat.keras_quantization_aware_training_finalize(quantized_model)

     """
     def _export(layer):
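The save/load/finalize round trip described in that docstring, sketched end to end; it assumes quantized_model and custom_objects from the init call above, and the file path is a placeholder:

    import tensorflow as tf
    import model_compression_toolkit as mct

    # Persist the QAT model, reload it with the custom quantizer objects,
    # then strip the training wrappers to obtain the final quantized model.
    quantized_model.save('qat_model.h5')
    quantized_model = tf.keras.models.load_model('qat_model.h5', custom_objects=custom_objects)
    quantized_model = mct.qat.keras_quantization_aware_training_finalize(quantized_model)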
@@ -14,8 +14,8 @@
 # ==============================================================================
 from typing import Union

-from model_compression_toolkit.
-from model_compression_toolkit.
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import FOUND_TF

 from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \
     TrainableQuantizerActivationConfig, BaseKerasTrainableQuantizer