mct-nightly 1.8.0.8042023.post345__py3-none-any.whl → 1.8.0.8052023.post414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/METADATA +4 -3
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/RECORD +285 -277
- model_compression_toolkit/__init__.py +9 -32
- model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
- model_compression_toolkit/core/__init__.py +14 -0
- model_compression_toolkit/core/analyzer.py +3 -2
- model_compression_toolkit/core/common/__init__.py +0 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +1 -8
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
- model_compression_toolkit/core/common/graph/base_graph.py +2 -2
- model_compression_toolkit/core/common/graph/base_node.py +57 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/memory_computation.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/model_validation.py +1 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -1
- model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
- model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
- model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -4
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +2 -2
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/keras/constants.py +0 -7
- model_compression_toolkit/core/keras/default_framework_info.py +3 -3
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -10
- model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
- model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
- model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +1 -1
- model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +2 -2
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +0 -6
- model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
- model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
- model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/runner.py +7 -7
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/int8_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +3 -2
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +5 -3
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +4 -3
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
- model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +2 -1
- model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +5 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +27 -20
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -5
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +2 -2
- model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
- model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -13
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +8 -3
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -2
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +9 -11
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +2 -2
- model_compression_toolkit/gptq/runner.py +3 -2
- model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +11 -12
- model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
- model_compression_toolkit/ptq/__init__.py +3 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/qat/__init__.py +4 -0
- model_compression_toolkit/qat/common/__init__.py +1 -2
- model_compression_toolkit/qat/common/qat_config.py +3 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +18 -20
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +13 -11
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +9 -9
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +4 -3
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +7 -5
- model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -6
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
- model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
- model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
- model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/WHEEL +0 -0
- {mct_nightly-1.8.0.8042023.post345.dist-info → mct_nightly-1.8.0.8052023.post414.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
- /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
- /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
|
@@ -26,13 +26,13 @@ else:
|
|
|
26
26
|
|
|
27
27
|
from typing import Any, Dict, List, Tuple
|
|
28
28
|
from tensorflow.python.util.object_identity import Reference as TFReference
|
|
29
|
-
from model_compression_toolkit.
|
|
29
|
+
from model_compression_toolkit.constants import EPS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE
|
|
30
30
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
|
31
31
|
from model_compression_toolkit.core import common
|
|
32
32
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
|
33
33
|
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
|
|
34
34
|
from model_compression_toolkit.core.keras.back2framework.instance_builder import OperationHandler
|
|
35
|
-
from model_compression_toolkit.
|
|
35
|
+
from model_compression_toolkit.logger import Logger
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
def build_input_tensors_list(node: BaseNode,
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import List
|
|
16
16
|
|
|
17
|
-
from model_compression_toolkit import FrameworkInfo
|
|
17
|
+
from model_compression_toolkit.core import FrameworkInfo
|
|
18
18
|
from model_compression_toolkit.core import common
|
|
19
19
|
from model_compression_toolkit.core.common import BaseNode
|
|
20
20
|
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
|
|
@@ -99,13 +99,6 @@ OUTPUT_BIAS = '/attention_output/bias'
|
|
|
99
99
|
# ReLU bound constants
|
|
100
100
|
RELU_POT_BOUND = 8.0
|
|
101
101
|
|
|
102
|
-
# Supported TP models names for Tensorflow:
|
|
103
|
-
DEFAULT_TP_MODEL = 'default'
|
|
104
|
-
IMX500_TP_MODEL = 'imx500'
|
|
105
|
-
TFLITE_TP_MODEL = 'tflite'
|
|
106
|
-
QNNPACK_TP_MODEL = 'qnnpack'
|
|
107
|
-
|
|
108
|
-
|
|
109
102
|
# TFOpLambda functions:
|
|
110
103
|
ADD = 'add'
|
|
111
104
|
PAD = 'pad'
|
|
@@ -25,9 +25,9 @@ else:
|
|
|
25
25
|
from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU
|
|
26
26
|
|
|
27
27
|
from model_compression_toolkit.core.common.defaultdict import DefaultDict
|
|
28
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
29
|
-
from model_compression_toolkit.
|
|
30
|
-
from model_compression_toolkit.
|
|
28
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
29
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
30
|
+
from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
|
|
31
31
|
from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
|
|
32
32
|
KERNEL, DEPTHWISE_KERNEL
|
|
33
33
|
from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py
CHANGED
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
from tensorflow.keras.layers import Dense, DepthwiseConv2D, Conv2D, Conv2DTranspose, Activation, SeparableConv2D
|
|
18
18
|
|
|
19
19
|
from model_compression_toolkit.core import common
|
|
20
|
-
from model_compression_toolkit.
|
|
20
|
+
from model_compression_toolkit.constants import FLOAT_32, DATA_TYPE
|
|
21
21
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
22
22
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, \
|
|
23
23
|
NodeFrameworkAttrMatcher
|
|
@@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
|
23
23
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, EdgeMatcher, WalkMatcher
|
|
24
24
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
25
25
|
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
|
|
26
|
-
from model_compression_toolkit.
|
|
26
|
+
from model_compression_toolkit.constants import THRESHOLD
|
|
27
27
|
from model_compression_toolkit.core.keras.constants import KERNEL
|
|
28
28
|
|
|
29
29
|
input_node = NodeOperationMatcher(InputLayer)
|
|
@@ -21,7 +21,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
|
|
|
21
21
|
from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing
|
|
22
22
|
from model_compression_toolkit.core.keras.constants import KERNEL, KERNEL_SIZE, STRIDES, DILATIONS, LINEAR, \
|
|
23
23
|
ACTIVATION, BIAS, USE_BIAS, LAYER_NAME, FILTERS, PADDING, GROUPS, DATA_FORMAT
|
|
24
|
-
from model_compression_toolkit.
|
|
24
|
+
from model_compression_toolkit.logger import Logger
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
def linear_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
|
|
@@ -23,17 +23,16 @@ else:
|
|
|
23
23
|
from keras.layers.core import TFOpLambda
|
|
24
24
|
from keras.layers import MultiHeadAttention, Conv2D, Softmax, Concatenate, Reshape, Permute
|
|
25
25
|
|
|
26
|
-
from model_compression_toolkit.
|
|
26
|
+
from model_compression_toolkit.logger import Logger
|
|
27
27
|
from model_compression_toolkit.core import common
|
|
28
28
|
from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
|
|
29
29
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
|
30
30
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
|
31
|
-
from model_compression_toolkit.
|
|
32
|
-
from model_compression_toolkit.core.keras.reader.node_builder import REUSED_IDENTIFIER
|
|
31
|
+
from model_compression_toolkit.constants import REUSE, REUSE_GROUP
|
|
33
32
|
from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, NUM_HEADS, KEY_DIM, VALUE_DIM, \
|
|
34
33
|
QUERY_SHAPE, KEY_SHAPE, VALUE_SHAPE, OUTPUT_SHAPE, ATTENTION_AXES, ACTIVATION, LINEAR, FILTERS, \
|
|
35
34
|
FUNCTION, DIMS, TARGET_SHAPE, F_STRIDED_SLICE, F_STACK, Q_KERNEL, Q_BIAS, K_KERNEL, K_BIAS, V_KERNEL, V_BIAS, \
|
|
36
|
-
OUTPUT_KERNEL, OUTPUT_BIAS, F_MATMUL,
|
|
35
|
+
OUTPUT_KERNEL, OUTPUT_BIAS, F_MATMUL, KERNEL_SIZE, AXIS, F_STRIDED_SLICE_BEGIN, F_STRIDED_SLICE_END
|
|
37
36
|
|
|
38
37
|
|
|
39
38
|
class MHAParams:
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py
CHANGED
|
@@ -23,6 +23,7 @@ from model_compression_toolkit.core import common
|
|
|
23
23
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
24
24
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, WalkMatcher
|
|
25
25
|
from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, RELU_MAX_VALUE, RELU_POT_BOUND
|
|
26
|
+
from model_compression_toolkit.logger import Logger
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
|
|
@@ -81,7 +82,7 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
|
|
|
81
82
|
scale_factor = max_value / self.threshold
|
|
82
83
|
|
|
83
84
|
non_linear_node.framework_attr[RELU_MAX_VALUE] = np.float32(self.threshold)
|
|
84
|
-
|
|
85
|
+
Logger.debug(
|
|
85
86
|
f"Node named:{non_linear_node.name} max value change "
|
|
86
87
|
f"to:{non_linear_node.framework_attr[RELU_MAX_VALUE]}")
|
|
87
88
|
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py
CHANGED
|
@@ -20,7 +20,8 @@ from model_compression_toolkit.core import common
|
|
|
20
20
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
21
21
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher,NodeFrameworkAttrMatcher
|
|
22
22
|
from model_compression_toolkit.core.keras.constants import RELU_MAX_VALUE
|
|
23
|
-
from model_compression_toolkit.
|
|
23
|
+
from model_compression_toolkit.constants import THRESHOLD
|
|
24
|
+
from model_compression_toolkit.logger import Logger
|
|
24
25
|
|
|
25
26
|
MATCHER = NodeOperationMatcher(ReLU) & NodeFrameworkAttrMatcher(RELU_MAX_VALUE, None).logic_not()
|
|
26
27
|
|
|
@@ -56,5 +57,5 @@ class RemoveReLUUpperBound(common.BaseSubstitution):
|
|
|
56
57
|
node.final_activation_quantization_cfg.activation_quantization_params.get(THRESHOLD) == \
|
|
57
58
|
node.framework_attr.get(RELU_MAX_VALUE):
|
|
58
59
|
node.framework_attr[RELU_MAX_VALUE] = None
|
|
59
|
-
|
|
60
|
+
Logger.info(f'Removing upper bound of {node.name}. Threshold and upper bound are equal.')
|
|
60
61
|
return graph
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py
CHANGED
|
@@ -21,7 +21,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
|
|
|
21
21
|
NodeFrameworkAttrMatcher
|
|
22
22
|
from model_compression_toolkit.core.common.substitutions.residual_collapsing import ResidualCollapsing
|
|
23
23
|
from model_compression_toolkit.core.keras.constants import KERNEL, LINEAR, ACTIVATION, LAYER_NAME
|
|
24
|
-
from model_compression_toolkit.
|
|
24
|
+
from model_compression_toolkit.logger import Logger
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
def residual_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py
CHANGED
|
@@ -22,7 +22,7 @@ from tensorflow.python.keras.layers.core import TFOpLambda
|
|
|
22
22
|
from tensorflow.keras.layers import Activation, Conv2D, Dense, DepthwiseConv2D, ZeroPadding2D, Reshape, \
|
|
23
23
|
GlobalAveragePooling2D, Dropout, ReLU, PReLU, ELU
|
|
24
24
|
|
|
25
|
-
from model_compression_toolkit import CoreConfig, FrameworkInfo
|
|
25
|
+
from model_compression_toolkit.core import CoreConfig, FrameworkInfo
|
|
26
26
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
|
27
27
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
|
28
28
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, \
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
from typing import List, Any, Tuple, Callable,
|
|
15
|
+
from typing import List, Any, Tuple, Callable, Dict
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
import tensorflow as tf
|
|
@@ -43,7 +43,7 @@ else:
|
|
|
43
43
|
Concatenate, Add
|
|
44
44
|
from keras.layers.core import TFOpLambda
|
|
45
45
|
|
|
46
|
-
from model_compression_toolkit import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfigV2
|
|
46
|
+
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfigV2
|
|
47
47
|
from model_compression_toolkit.core import common
|
|
48
48
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
49
49
|
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
|
@@ -52,8 +52,6 @@ from model_compression_toolkit.core.common.model_builder_mode import ModelBuilde
|
|
|
52
52
|
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
|
|
53
53
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
|
54
54
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
55
|
-
from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
|
|
56
|
-
from model_compression_toolkit.gptq.keras.gptq_training import KerasGPTQTrainer
|
|
57
55
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.activation_decomposition import \
|
|
58
56
|
ActivationDecomposition
|
|
59
57
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.softmax_shift import \
|
|
@@ -348,12 +346,6 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
348
346
|
substitutions_list.append(keras_batchnorm_refusing())
|
|
349
347
|
return substitutions_list
|
|
350
348
|
|
|
351
|
-
def get_gptq_trainer_obj(self) -> Type[GPTQTrainer]:
|
|
352
|
-
"""
|
|
353
|
-
Returns: Keras object of GPTQTrainer
|
|
354
|
-
"""
|
|
355
|
-
return KerasGPTQTrainer
|
|
356
|
-
|
|
357
349
|
def get_sensitivity_evaluator(self,
|
|
358
350
|
graph: Graph,
|
|
359
351
|
quant_config: MixedPrecisionQuantizationConfigV2,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from tensorflow.keras.models import Model
|
|
2
2
|
|
|
3
|
-
from model_compression_toolkit import FrameworkInfo
|
|
3
|
+
from model_compression_toolkit.core import FrameworkInfo
|
|
4
4
|
from model_compression_toolkit.core.common.framework_info import ChannelAxis
|
|
5
5
|
from model_compression_toolkit.core.common.model_validation import ModelValidation
|
|
6
6
|
from model_compression_toolkit.core.keras.constants import CHANNELS_FORMAT, CHANNELS_FORMAT_LAST, CHANNELS_FORMAT_FIRST
|
|
@@ -8,7 +8,7 @@ if version.parse(tf.__version__) < version.parse("2.6"):
|
|
|
8
8
|
else:
|
|
9
9
|
from keras.layers import Activation, ReLU, BatchNormalization
|
|
10
10
|
|
|
11
|
-
from model_compression_toolkit import FrameworkInfo
|
|
11
|
+
from model_compression_toolkit.core import FrameworkInfo
|
|
12
12
|
from model_compression_toolkit.core.common import BaseNode
|
|
13
13
|
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
|
|
14
14
|
from model_compression_toolkit.core.keras.constants import ACTIVATION, RELU_MAX_VALUE, NEGATIVE_SLOPE, THRESHOLD, \
|
|
@@ -15,19 +15,19 @@
|
|
|
15
15
|
|
|
16
16
|
from typing import Callable
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit import MixedPrecisionQuantizationConfig, CoreConfig, MixedPrecisionQuantizationConfigV2
|
|
18
|
+
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, CoreConfig, MixedPrecisionQuantizationConfigV2
|
|
19
19
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
20
|
-
from model_compression_toolkit.
|
|
21
|
-
from model_compression_toolkit.
|
|
22
|
-
from model_compression_toolkit.
|
|
20
|
+
from model_compression_toolkit.logger import Logger
|
|
21
|
+
from model_compression_toolkit.constants import TENSORFLOW
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
|
23
23
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_data import compute_kpi_data
|
|
24
24
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
25
25
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
26
26
|
DEFAULT_MIXEDPRECISION_CONFIG
|
|
27
|
-
from model_compression_toolkit.
|
|
27
|
+
from model_compression_toolkit.constants import FOUND_TF
|
|
28
28
|
|
|
29
29
|
if FOUND_TF:
|
|
30
|
-
from model_compression_toolkit.
|
|
30
|
+
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
|
31
31
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
32
32
|
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
|
|
33
33
|
from tensorflow.keras.models import Model
|
|
@@ -51,7 +51,7 @@ if FOUND_TF:
|
|
|
51
51
|
representative_data_gen (Callable): Dataset used for calibration.
|
|
52
52
|
quant_config (MixedPrecisionQuantizationConfig): MixedPrecisionQuantizationConfig containing parameters of how the model should be quantized.
|
|
53
53
|
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
|
|
54
|
-
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
|
|
54
|
+
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
|
|
55
55
|
|
|
56
56
|
Returns:
|
|
57
57
|
A KPI object with total weights parameters sum, max activation tensor and total kpi.
|
|
@@ -75,7 +75,7 @@ if FOUND_TF:
|
|
|
75
75
|
Import MCT and call for KPI data calculation:
|
|
76
76
|
|
|
77
77
|
>>> import model_compression_toolkit as mct
|
|
78
|
-
>>> kpi_data = mct.keras_kpi_data(model, repr_datagen)
|
|
78
|
+
>>> kpi_data = mct.core.keras_kpi_data(model, repr_datagen)
|
|
79
79
|
|
|
80
80
|
|
|
81
81
|
"""
|
|
@@ -112,7 +112,7 @@ if FOUND_TF:
|
|
|
112
112
|
representative_data_gen (Callable): Dataset used for calibration.
|
|
113
113
|
core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision of how the model should be quantized.
|
|
114
114
|
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
|
|
115
|
-
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
|
|
115
|
+
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
|
|
116
116
|
|
|
117
117
|
Returns:
|
|
118
118
|
|
|
@@ -133,7 +133,7 @@ if FOUND_TF:
|
|
|
133
133
|
Import MCT and call for KPI data calculation:
|
|
134
134
|
|
|
135
135
|
>>> import model_compression_toolkit as mct
|
|
136
|
-
>>> kpi_data = mct.keras_kpi_data(model, repr_datagen)
|
|
136
|
+
>>> kpi_data = mct.core.keras_kpi_data(model, repr_datagen)
|
|
137
137
|
|
|
138
138
|
"""
|
|
139
139
|
|
|
@@ -20,8 +20,8 @@ import tensorflow as tf
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
from tensorflow.python.util.object_identity import Reference as TFReference
|
|
22
22
|
|
|
23
|
-
from model_compression_toolkit.
|
|
24
|
-
from model_compression_toolkit.
|
|
23
|
+
from model_compression_toolkit.logger import Logger
|
|
24
|
+
from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX
|
|
25
25
|
from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import threshold_is_power_of_two
|
|
26
26
|
|
|
27
27
|
|
|
@@ -23,7 +23,7 @@ from tensorflow_model_optimization.python.core.quantization.keras.quantize_confi
|
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
from model_compression_toolkit.core.common import BaseNode
|
|
26
|
-
from model_compression_toolkit.
|
|
26
|
+
from model_compression_toolkit.constants import INPUT_BASE_NAME
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class InputLayerWrapperTransform(InputLayerQuantize):
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
from typing import Tuple, Dict,
|
|
1
|
+
from typing import Tuple, Dict, Callable
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import tensorflow as tf
|
|
5
5
|
from keras.layers import Layer
|
|
6
6
|
from tensorflow.python.util.object_identity import Reference as TFReference
|
|
7
7
|
|
|
8
|
-
from model_compression_toolkit.
|
|
8
|
+
from model_compression_toolkit.constants import SIGNED, CLUSTER_CENTERS, EPS, \
|
|
9
9
|
MULTIPLIER_N_BITS, THRESHOLD
|
|
10
10
|
|
|
11
11
|
|
|
@@ -24,7 +24,7 @@ from model_compression_toolkit.core.common.quantization.candidate_node_quantizat
|
|
|
24
24
|
from model_compression_toolkit.core.keras.quantizer.mixed_precision.selective_activation_quantizer import \
|
|
25
25
|
SelectiveActivationQuantizer
|
|
26
26
|
from packaging import version
|
|
27
|
-
from model_compression_toolkit.
|
|
27
|
+
from model_compression_toolkit.logger import Logger
|
|
28
28
|
|
|
29
29
|
if version.parse(tf.__version__) < version.parse("2.6"):
|
|
30
30
|
from tensorflow.python.keras.layers import Layer # pragma: no cover
|
|
@@ -29,7 +29,7 @@ else:
|
|
|
29
29
|
from keras.engine.functional import Functional
|
|
30
30
|
from keras.engine.sequential import Sequential
|
|
31
31
|
|
|
32
|
-
from model_compression_toolkit.
|
|
32
|
+
from model_compression_toolkit.logger import Logger
|
|
33
33
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
34
34
|
|
|
35
35
|
|
model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py
CHANGED
|
@@ -19,7 +19,7 @@ from tensorflow.keras.layers import BatchNormalization
|
|
|
19
19
|
from tqdm import tqdm
|
|
20
20
|
|
|
21
21
|
import model_compression_toolkit.core.keras.constants as keras_constants
|
|
22
|
-
from model_compression_toolkit import CoreConfig
|
|
22
|
+
from model_compression_toolkit.core import CoreConfig
|
|
23
23
|
from model_compression_toolkit.core import common
|
|
24
24
|
|
|
25
25
|
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from model_compression_toolkit.
|
|
16
|
+
from model_compression_toolkit.logger import Logger
|
|
17
17
|
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
|
|
18
18
|
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
|
|
19
19
|
from model_compression_toolkit.core.pytorch.back2framework.mixed_precision_model_builder import \
|
|
@@ -17,7 +17,7 @@ from typing import List, Tuple
|
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
19
|
|
|
20
|
-
from model_compression_toolkit import FrameworkInfo
|
|
20
|
+
from model_compression_toolkit.core import FrameworkInfo
|
|
21
21
|
from model_compression_toolkit.core import common
|
|
22
22
|
from model_compression_toolkit.core.common import BaseNode
|
|
23
23
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
|
@@ -17,7 +17,7 @@ from typing import List, Any, Tuple
|
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
19
|
|
|
20
|
-
from model_compression_toolkit import FrameworkInfo
|
|
20
|
+
from model_compression_toolkit.core import FrameworkInfo
|
|
21
21
|
from model_compression_toolkit.core import common
|
|
22
22
|
from model_compression_toolkit.core.common import BaseNode
|
|
23
23
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
|
@@ -22,14 +22,14 @@ import numpy as np
|
|
|
22
22
|
|
|
23
23
|
from model_compression_toolkit.core import common
|
|
24
24
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
|
25
|
-
from model_compression_toolkit.
|
|
25
|
+
from model_compression_toolkit.constants import EPS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE
|
|
26
26
|
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
|
|
27
27
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
|
28
28
|
from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
|
|
29
29
|
from model_compression_toolkit.core.pytorch.constants import BUFFER
|
|
30
30
|
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder, BufferHolder
|
|
31
31
|
from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy, get_working_device
|
|
32
|
-
from model_compression_toolkit.
|
|
32
|
+
from model_compression_toolkit.logger import Logger
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
def build_input_tensors_list(node: BaseNode,
|
|
@@ -18,7 +18,7 @@ from typing import Tuple, Any, Dict, List, Union, Callable
|
|
|
18
18
|
import torch
|
|
19
19
|
from networkx import topological_sort
|
|
20
20
|
|
|
21
|
-
from model_compression_toolkit import FrameworkInfo
|
|
21
|
+
from model_compression_toolkit.core import FrameworkInfo
|
|
22
22
|
from model_compression_toolkit.core import common
|
|
23
23
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
|
24
24
|
from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
|
|
@@ -17,7 +17,7 @@ from typing import List, Tuple
|
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
19
|
|
|
20
|
-
from model_compression_toolkit import FrameworkInfo
|
|
20
|
+
from model_compression_toolkit.core import FrameworkInfo
|
|
21
21
|
from model_compression_toolkit.core import common
|
|
22
22
|
from model_compression_toolkit.core.common import BaseNode
|
|
23
23
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
|
@@ -69,12 +69,6 @@ CPU = 'cpu'
|
|
|
69
69
|
# ReLU bound constants
|
|
70
70
|
RELU_POT_BOUND = 8.0
|
|
71
71
|
|
|
72
|
-
# Supported TP models names for Pytorch:
|
|
73
|
-
DEFAULT_TP_MODEL = 'default'
|
|
74
|
-
IMX500_TP_MODEL = 'imx500'
|
|
75
|
-
TFLITE_TP_MODEL = 'tflite'
|
|
76
|
-
QNNPACK_TP_MODEL = 'qnnpack'
|
|
77
|
-
|
|
78
72
|
# MultiHeadAttention layer attributes:
|
|
79
73
|
EMBED_DIM = 'embed_dim'
|
|
80
74
|
NUM_HEADS = 'num_heads'
|
|
@@ -19,8 +19,8 @@ from torch import sigmoid
|
|
|
19
19
|
|
|
20
20
|
from model_compression_toolkit.core.common.defaultdict import DefaultDict
|
|
21
21
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo, ChannelAxis
|
|
22
|
-
from model_compression_toolkit.
|
|
23
|
-
from model_compression_toolkit.
|
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
|
23
|
+
from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
|
|
24
24
|
from model_compression_toolkit.core.pytorch.constants import KERNEL
|
|
25
25
|
from model_compression_toolkit.core.pytorch.quantizer.fake_quant_builder import power_of_two_quantization, \
|
|
26
26
|
symmetric_quantization, uniform_quantization
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py
CHANGED
|
@@ -22,7 +22,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
|
|
|
22
22
|
from model_compression_toolkit.core.common import BaseNode
|
|
23
23
|
from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing
|
|
24
24
|
from model_compression_toolkit.core.pytorch.constants import KERNEL, KERNEL_SIZE, STRIDES, DILATIONS, BIAS, USE_BIAS, FILTERS, PADDING, GROUPS
|
|
25
|
-
from model_compression_toolkit.
|
|
25
|
+
from model_compression_toolkit.logger import Logger
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
def linear_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
|
|
@@ -20,7 +20,7 @@ import torch.nn as nn
|
|
|
20
20
|
import operator
|
|
21
21
|
from typing import List
|
|
22
22
|
|
|
23
|
-
from model_compression_toolkit.
|
|
23
|
+
from model_compression_toolkit.logger import Logger
|
|
24
24
|
from model_compression_toolkit.core import common
|
|
25
25
|
from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
|
|
26
26
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py
CHANGED
|
@@ -25,6 +25,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
|
|
|
25
25
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
26
26
|
from model_compression_toolkit.core.pytorch.constants import KERNEL, BIAS, INPLACE, HARDTANH_MIN_VAL, HARDTANH_MAX_VAL, \
|
|
27
27
|
RELU_POT_BOUND
|
|
28
|
+
from model_compression_toolkit.logger import Logger
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
|
|
@@ -102,8 +103,8 @@ class ReLUBoundToPowerOfTwo(common.BaseSubstitution):
|
|
|
102
103
|
else:
|
|
103
104
|
return graph
|
|
104
105
|
else:
|
|
105
|
-
|
|
106
|
-
|
|
106
|
+
Logger.error(f"In substitution with wrong matched pattern")
|
|
107
|
+
Logger.debug(
|
|
107
108
|
f"Node named:{non_linear_node.name} changed "
|
|
108
109
|
f"to:{non_linear_node.type}")
|
|
109
110
|
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
from torch import reshape
|
|
16
16
|
import torch
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
19
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
|
20
20
|
from model_compression_toolkit.core import common
|
|
21
21
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py
CHANGED
|
@@ -20,7 +20,7 @@ from model_compression_toolkit.core.common import BaseNode
|
|
|
20
20
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
|
21
21
|
from model_compression_toolkit.core.common.substitutions.residual_collapsing import ResidualCollapsing
|
|
22
22
|
from model_compression_toolkit.core.pytorch.constants import KERNEL
|
|
23
|
-
from model_compression_toolkit.
|
|
23
|
+
from model_compression_toolkit.logger import Logger
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
def residual_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
|
|
@@ -21,7 +21,7 @@ from torch import reshape
|
|
|
21
21
|
from torch.nn.functional import hardswish, silu, prelu, elu
|
|
22
22
|
from torch.nn.functional import avg_pool2d
|
|
23
23
|
|
|
24
|
-
from model_compression_toolkit import CoreConfig, FrameworkInfo
|
|
24
|
+
from model_compression_toolkit.core import CoreConfig, FrameworkInfo
|
|
25
25
|
from model_compression_toolkit.core import common
|
|
26
26
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
|
27
27
|
from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher
|
|
@@ -15,21 +15,21 @@
|
|
|
15
15
|
|
|
16
16
|
from typing import Callable
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.
|
|
19
|
-
from model_compression_toolkit.
|
|
20
|
-
from model_compression_toolkit.
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
|
+
from model_compression_toolkit.constants import PYTORCH
|
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
|
21
21
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
22
22
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
23
23
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_data import compute_kpi_data
|
|
24
24
|
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
|
|
25
25
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
26
26
|
MixedPrecisionQuantizationConfig, DEFAULT_MIXEDPRECISION_CONFIG, MixedPrecisionQuantizationConfigV2
|
|
27
|
-
from model_compression_toolkit.
|
|
27
|
+
from model_compression_toolkit.constants import FOUND_TORCH
|
|
28
28
|
|
|
29
29
|
if FOUND_TORCH:
|
|
30
30
|
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
|
31
31
|
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
|
|
32
|
-
from model_compression_toolkit.
|
|
32
|
+
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
|
33
33
|
from torch.nn import Module
|
|
34
34
|
|
|
35
35
|
from model_compression_toolkit import get_target_platform_capabilities
|
|
@@ -51,7 +51,7 @@ if FOUND_TORCH:
|
|
|
51
51
|
representative_data_gen (Callable): Dataset used for calibration.
|
|
52
52
|
quant_config (MixedPrecisionQuantizationConfig): MixedPrecisionQuantizationConfig containing parameters of how the model should be quantized.
|
|
53
53
|
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default PyTorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
|
|
54
|
-
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
|
|
54
|
+
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
|
|
55
55
|
|
|
56
56
|
Returns:
|
|
57
57
|
A KPI object with total weights parameters sum, max activation tensor and total kpi.
|
|
@@ -75,7 +75,7 @@ if FOUND_TORCH:
|
|
|
75
75
|
Import mct and call for KPI data calculation:
|
|
76
76
|
|
|
77
77
|
>>> import model_compression_toolkit as mct
|
|
78
|
-
>>> kpi_data = mct.pytorch_kpi_data(module, repr_datagen)
|
|
78
|
+
>>> kpi_data = mct.core.pytorch_kpi_data(module, repr_datagen)
|
|
79
79
|
|
|
80
80
|
"""
|
|
81
81
|
|
|
@@ -111,7 +111,7 @@ if FOUND_TORCH:
|
|
|
111
111
|
representative_data_gen (Callable): Dataset used for calibration.
|
|
112
112
|
core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision
|
|
113
113
|
fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default PyTorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
|
|
114
|
-
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
|
|
114
|
+
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
|
|
115
115
|
|
|
116
116
|
Returns:
|
|
117
117
|
|
|
@@ -132,7 +132,7 @@ if FOUND_TORCH:
|
|
|
132
132
|
Import mct and call for KPI data calculation:
|
|
133
133
|
|
|
134
134
|
>>> import model_compression_toolkit as mct
|
|
135
|
-
>>> kpi_data = mct.pytorch_kpi_data(module, repr_datagen)
|
|
135
|
+
>>> kpi_data = mct.core.pytorch_kpi_data(module, repr_datagen)
|
|
136
136
|
|
|
137
137
|
"""
|
|
138
138
|
|
|
@@ -18,7 +18,7 @@ from typing import Any, List
|
|
|
18
18
|
import torch
|
|
19
19
|
import copy
|
|
20
20
|
|
|
21
|
-
from model_compression_toolkit import FrameworkInfo
|
|
21
|
+
from model_compression_toolkit.core import FrameworkInfo
|
|
22
22
|
from model_compression_toolkit.core.common import BaseNode
|
|
23
23
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
|
24
24
|
from model_compression_toolkit.core.pytorch.utils import set_model, to_torch_tensor
|