mct-nightly 1.8.0.22032023.post333__py3-none-any.whl → 1.8.0.22052023.post408__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/METADATA +4 -3
- {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/RECORD +294 -284
- model_compression_toolkit/__init__.py +9 -32
- model_compression_toolkit/{core/common/constants.py → constants.py} +2 -6
- model_compression_toolkit/core/__init__.py +14 -0
- model_compression_toolkit/core/analyzer.py +3 -2
- model_compression_toolkit/core/common/__init__.py +0 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +1 -1
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +1 -8
- model_compression_toolkit/core/common/framework_info.py +1 -1
- model_compression_toolkit/core/common/fusion/layer_fusing.py +4 -4
- model_compression_toolkit/core/common/graph/base_graph.py +2 -2
- model_compression_toolkit/core/common/graph/base_node.py +57 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/memory_computation.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -4
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/model_validation.py +1 -1
- model_compression_toolkit/core/common/network_editors/actions.py +4 -1
- model_compression_toolkit/core/common/network_editors/edit_network.py +0 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +3 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +3 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +2 -2
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +4 -4
- model_compression_toolkit/core/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +4 -2
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +2 -2
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +7 -7
- model_compression_toolkit/core/common/similarity_analyzer.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -4
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +5 -5
- model_compression_toolkit/core/common/substitutions/apply_substitutions.py +2 -5
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -3
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +0 -3
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +66 -21
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/keras/back2framework/model_gradients.py +5 -4
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/keras/constants.py +0 -7
- model_compression_toolkit/core/keras/default_framework_info.py +3 -3
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/activation_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +3 -4
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +2 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -10
- model_compression_toolkit/core/keras/keras_model_validation.py +1 -1
- model_compression_toolkit/core/keras/keras_node_prior_info.py +1 -1
- model_compression_toolkit/core/keras/kpi_data_facade.py +10 -10
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/keras/quantizer/input_layer_quantize_transform.py +1 -1
- model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py +2 -2
- model_compression_toolkit/core/keras/quantizer/mixed_precision/selective_quantize_config.py +1 -1
- model_compression_toolkit/core/keras/reader/common.py +1 -1
- model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/model_gradients.py +15 -8
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +0 -6
- model_compression_toolkit/core/pytorch/default_framework_info.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +3 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -1
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +9 -9
- model_compression_toolkit/core/pytorch/mixed_precision/mixed_precision_wrapper.py +1 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -12
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -1
- model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py +2 -2
- model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +4 -2
- model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py +1 -1
- model_compression_toolkit/core/runner.py +7 -7
- model_compression_toolkit/exporter/__init__.py +6 -3
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/export_serialization_format.py +20 -0
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/{tflite → keras}/fakely_quant_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/{tflite → keras}/int8_tflite_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +60 -22
- model_compression_toolkit/exporter/model_exporter/pytorch/export_serialization_format.py +20 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +15 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +54 -31
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +5 -3
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +5 -3
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +3 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +4 -3
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +2 -2
- model_compression_toolkit/gptq/common/gptq_config.py +2 -4
- model_compression_toolkit/gptq/common/gptq_framework_implementation.py +32 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +5 -4
- model_compression_toolkit/gptq/keras/gptq_keras_implementation.py +29 -0
- model_compression_toolkit/gptq/keras/gptq_training.py +41 -14
- model_compression_toolkit/gptq/keras/graph_info.py +4 -0
- model_compression_toolkit/gptq/keras/quantization_facade.py +27 -20
- model_compression_toolkit/gptq/keras/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +2 -2
- model_compression_toolkit/gptq/keras/quantizer/quant_utils.py +18 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -5
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +21 -16
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +224 -0
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +2 -2
- model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py +29 -0
- model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -13
- model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +1 -0
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -3
- model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +18 -4
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -2
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +13 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +194 -0
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +2 -2
- model_compression_toolkit/gptq/runner.py +3 -2
- model_compression_toolkit/{core/keras/quantization_facade.py → legacy/keras_quantization_facade.py} +11 -12
- model_compression_toolkit/{core/pytorch/quantization_facade.py → legacy/pytorch_quantization_facade.py} +11 -12
- model_compression_toolkit/ptq/__init__.py +3 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +11 -12
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/qat/__init__.py +4 -0
- model_compression_toolkit/qat/common/__init__.py +1 -2
- model_compression_toolkit/qat/common/qat_config.py +5 -1
- model_compression_toolkit/qat/keras/quantization_facade.py +34 -28
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +31 -4
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +13 -11
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +9 -9
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -9
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +4 -3
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +7 -5
- model_compression_toolkit/quantizers_infrastructure/__init__.py +2 -2
- model_compression_toolkit/{qat/common → quantizers_infrastructure}/constants.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +5 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +147 -0
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/load_model.py +5 -5
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +3 -5
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +2 -3
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +3 -3
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +9 -9
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizer_config.py +2 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +4 -6
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/trainable_quantizer_config.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +2 -2
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +2 -2
- model_compression_toolkit/target_platform_capabilities/constants.py +27 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/fusing.py +2 -2
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/operators.py +4 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/quantization_format.py +20 -0
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model.py +16 -7
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/target_platform_model_component.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/__init__.py +5 -5
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/attribute_filter.py +1 -1
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/layer_filter_params.py +33 -35
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/operations_to_layers.py +4 -4
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities.py +9 -30
- model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/target_platform_capabilities_component.py +1 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py +0 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/default_tpc/latest/__init__.py +25 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/target_platform_capabilities.py +19 -17
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v4_lut → target_platform_capabilities/tpc_models/default_tpc/v3_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v5 → target_platform_capabilities/tpc_models/default_tpc/v4}/tpc_keras.py +2 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core/tpc_models/default_tpc/v3_lut → target_platform_capabilities/tpc_models/default_tpc/v4_lut}/tp_model.py +7 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tp_model.py +7 -1
- model_compression_toolkit/{core/tpc_models/default_tpc/v4 → target_platform_capabilities/tpc_models/default_tpc/v5}/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/tpc_pytorch.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/get_target_platform_capabilities.py +6 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py +14 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/latest/__init__.py +6 -6
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tp_model.py +7 -1
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_keras.py +2 -2
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +2 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +22 -0
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/target_platform_capabilities.py +6 -5
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tp_model.py +26 -18
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -3
- model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -3
- model_compression_toolkit/core/tpc_models/default_tpc/latest/__init__.py +0 -25
- model_compression_toolkit/core/tpc_models/qnnpack_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/core/tpc_models/tflite_tpc/latest/__init__.py +0 -22
- model_compression_toolkit/exporter/model_exporter/tflite/__init__.py +0 -14
- model_compression_toolkit/exporter/model_exporter/tflite/tflite_export_facade.py +0 -73
- {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/WHEEL +0 -0
- {mct_nightly-1.8.0.22032023.post333.dist-info → mct_nightly-1.8.0.22052023.post408.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/{core/tpc_models/imx500_tpc → legacy}/__init__.py +0 -0
- /model_compression_toolkit/{core/common/logger.py → logger.py} +0 -0
- /model_compression_toolkit/{core/tpc_models → target_platform_capabilities}/__init__.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/immutable.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/op_quantization_config.py +0 -0
- /model_compression_toolkit/{core/common → target_platform_capabilities}/target_platform/targetplatform2framework/current_tpc.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v2/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v3_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v4_lut/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/default_tpc/v5/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/imx500_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/qnnpack_tpc/v1/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/__init__.py +0 -0
- /model_compression_toolkit/{core → target_platform_capabilities}/tpc_models/tflite_tpc/v1/__init__.py +0 -0
model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py
CHANGED

@@ -12,63 +12,86 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from enum import Enum
 from typing import Callable
 
-from model_compression_toolkit.
-from model_compression_toolkit.
-
-
-
-
-
-
+from model_compression_toolkit.constants import FOUND_TORCH
+from model_compression_toolkit.exporter.model_exporter.pytorch.export_serialization_format import \
+    PytorchExportSerializationFormat
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
+from model_compression_toolkit.target_platform_capabilities.target_platform.quantization_format import \
+    QuantizationFormat
 
 if FOUND_TORCH:
     import torch.nn
-    from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import
-
+    from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
+        FakelyQuantONNXPyTorchExporter
+    from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import \
+        FakelyQuantTorchScriptPyTorchExporter
     from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
 
+    supported_serialization_quantization_export_dict = {
+        PytorchExportSerializationFormat.TORCHSCRIPT: [QuantizationFormat.FAKELY_QUANT],
+        PytorchExportSerializationFormat.ONNX: [QuantizationFormat.FAKELY_QUANT]
+    }
+
     def pytorch_export_model(model: torch.nn.Module,
                              save_model_path: str,
                              repr_dataset: Callable,
+                             target_platform_capabilities: TargetPlatformCapabilities,
                              is_layer_exportable_fn: Callable = is_pytorch_layer_exportable,
-
+                             serialization_format: PytorchExportSerializationFormat =
+                             PytorchExportSerializationFormat.TORCHSCRIPT) -> None:
         """
         Export a PyTorch quantized model to a torchscript or onnx model.
         The model will be saved to the path in save_model_path.
-
-
-
-
-        is in an ONNX format and its weights and activations are float fakely-quantized values)
+        Currently, pytorch_export_model supports only QuantizationFormat.FAKELY_QUANT (where weights
+        and activations are float fakely-quantized values) and PytorchExportSerializationFormat.TORCHSCRIPT
+        (where the model will be saved to TorchScript model) or PytorchExportSerializationFormat.ONNX
+        (where the model will be saved to ONNX model).
 
         Args:
            model: Model to export.
-           is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
-           mode: Mode to export the model according to.
           save_model_path: Path to save the model.
           repr_dataset: Representative dataset for tracing the pytorch model (mandatory for exporting it).
+           target_platform_capabilities: TargetPlatformCapabilities object that describes the desired inference
+               target platform (includes quantization format).
+           is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
+           serialization_format: Format to export the model according to (by default
+               PytorchExportSerializationFormat.TORCHSCRIPT).
 
        """
 
-        if
-
-
-
-
+        if serialization_format == PytorchExportSerializationFormat.TORCHSCRIPT:
+            if target_platform_capabilities.tp_model.quantization_format in \
+                    supported_serialization_quantization_export_dict[serialization_format]:
+                exporter = FakelyQuantTorchScriptPyTorchExporter(model,
+                                                                 is_layer_exportable_fn,
+                                                                 save_model_path,
+                                                                 repr_dataset)
+            else:
+                Logger.critical(
+                    f'Unsupported quantization {target_platform_capabilities.tp_model.quantization_format} for '
+                    f'serialization {serialization_format} was used to export Pytorch model. Please see API for '
+                    f'supported formats.')  # pragma: no cover
 
-        elif
-
-
-
-
+        elif serialization_format == PytorchExportSerializationFormat.ONNX:
+            if target_platform_capabilities.tp_model.quantization_format in \
+                    supported_serialization_quantization_export_dict[serialization_format]:
+                exporter = FakelyQuantONNXPyTorchExporter(model,
+                                                          is_layer_exportable_fn,
+                                                          save_model_path,
+                                                          repr_dataset)
+            else:
+                Logger.critical(
+                    f'Unsupported quantization {target_platform_capabilities.tp_model.quantization_format} for '
+                    f'serialization {serialization_format} was used to export Pytorch model. Please see API for '
+                    f'supported formats.')  # pragma: no cover
 
         else:
             Logger.critical(
-                f'Unsupported
-                f'
+                f'Unsupported serialization {serialization_format} was used to export Pytorch model. Please see API '
+                f'for supported formats.')  # pragma: no cover
 
         exporter.export()
 
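The reworked facade keys exporter selection on the (serialization format, quantization format) pair via supported_serialization_quantization_export_dict. A minimal usage sketch, not part of the diff; it assumes this nightly re-exports pytorch_export_model and PytorchExportSerializationFormat under mct.exporter, that the PTQ facade is reachable as mct.ptq.pytorch_post_training_quantization_experimental, and that repr_dataset returns a list of batches per call, so exact entry points may differ in this transitional nightly:

import numpy as np
import torch
import model_compression_toolkit as mct

float_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

def repr_dataset():
    # One random batch; a real flow would return representative samples.
    return [np.random.randn(1, 3, 32, 32).astype(np.float32)]

tpc = mct.get_target_platform_capabilities('pytorch', 'default')
quantized_model, _ = mct.ptq.pytorch_post_training_quantization_experimental(
    float_model, repr_dataset, target_platform_capabilities=tpc)

# The TPC now travels into the exporter so the quantization format can be
# validated against the requested serialization format.
mct.exporter.pytorch_export_model(
    model=quantized_model,
    save_model_path='/tmp/qmodel.onnx',
    repr_dataset=repr_dataset,
    target_platform_capabilities=tpc,
    serialization_format=mct.exporter.PytorchExportSerializationFormat.ONNX)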
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
CHANGED

@@ -17,9 +17,10 @@ from typing import Tuple
 
 from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.
+from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.constants import FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
+from model_compression_toolkit.logger import Logger
 
 if FOUND_TF:
     import tensorflow as tf
@@ -34,6 +35,7 @@ if FOUND_TF:
         Args:
             n: A node of mct graph.
             layer: A keras layer
+            include_activation_quantizers: Whether to use the wrapper for the activation quantizer or not
 
         Returns: Wrapped layer with weights quantizers and activation quantizers
 
@@ -55,7 +57,7 @@ if FOUND_TF:
             Exportable Keras model and user information.
         """
         exportable_model, user_info = KerasModelBuilder(graph=graph,
-
+                                                        wrapper=_get_wrapper).build_model()
         exportable_model.trainable = False
         return exportable_model, user_info
 else:
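For context, the wrapper callable handed to KerasModelBuilder has the (node, layer) shape described in the docstring above. An illustrative sketch with hypothetical helper names, not the package's code:

from model_compression_toolkit import quantizers_infrastructure as qi

def example_wrapper(n, layer):
    # n: an MCT graph node; layer: the reconstructed Keras layer.
    # has_weight_quantizers and build_weight_quantizers are hypothetical
    # stand-ins for the logic the real _get_wrapper performs.
    if has_weight_quantizers(n):
        return qi.KerasQuantizationWrapper(layer,
                                           weights_quantizers=build_weight_quantizers(n))
    return layer  # layers without quantized weights pass through unwrapped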
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py
CHANGED

@@ -14,9 +14,11 @@
 # ==============================================================================
 from typing import Dict, Any
 
-from model_compression_toolkit.core.common import BaseNode
-from model_compression_toolkit.
-
+from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX, SIGNED, CLUSTER_CENTERS, SCALE_PER_CHANNEL
+
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
 from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import get_inferable_quantizer_class
 from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer
model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py
CHANGED

@@ -15,8 +15,8 @@
 from typing import Any
 
 
-from model_compression_toolkit.
-from model_compression_toolkit.
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import FOUND_TF
 
 from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import BaseInferableQuantizer
 
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
CHANGED

@@ -16,8 +16,9 @@
 
 from model_compression_toolkit import quantizers_infrastructure as qi
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.
+from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.constants import FOUND_TORCH
+from model_compression_toolkit.logger import Logger
 
 if FOUND_TORCH:
     import torch
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py
CHANGED

@@ -15,10 +15,11 @@
 
 from typing import Dict, Any
 
-from model_compression_toolkit.core.common import BaseNode
-from model_compression_toolkit.
+from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RANGE_MAX, \
     SCALE_PER_CHANNEL, CLUSTER_CENTERS
-from model_compression_toolkit.
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.quantizers_infrastructure import QuantizationTarget
 from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import \
     get_inferable_quantizer_class
model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py
CHANGED

@@ -14,8 +14,8 @@
 # ==============================================================================
 from typing import Any
 
-from model_compression_toolkit.
-from model_compression_toolkit.
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import FOUND_TORCH
 
 if FOUND_TORCH:
     from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
model_compression_toolkit/gptq/common/gptq_config.py
CHANGED

@@ -111,10 +111,8 @@ class GradientPTQConfig:
         self.regularization_factor = regularization_factor
         self.hessian_weights_config = hessian_weights_config
 
-
-
-        self.gptq_quantizer_params_override = {QUANT_PARAM_LEARNING_STR: False} \
-            if gptq_quantizer_params_override is None else gptq_quantizer_params_override
+        self.gptq_quantizer_params_override = {} if gptq_quantizer_params_override is None \
+            else gptq_quantizer_params_override
 
 
 class GradientPTQConfigV2(GradientPTQConfig):
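The net effect of this hunk: an unset gptq_quantizer_params_override now defaults to an empty dict instead of pre-seeding QUANT_PARAM_LEARNING_STR to False. A self-contained sketch of the new default logic (the dict key below is illustrative, not the constant's actual value):

def resolve_override(gptq_quantizer_params_override=None):
    # New behavior: no override means no injected defaults.
    return {} if gptq_quantizer_params_override is None else gptq_quantizer_params_override

assert resolve_override() == {}
assert resolve_override({'param_learning': True}) == {'param_learning': True}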
model_compression_toolkit/gptq/common/gptq_framework_implementation.py
ADDED

@@ -0,0 +1,32 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from abc import abstractmethod
+
+from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+
+
+class GPTQFrameworkImplemantation(FrameworkImplementation):
+    """
+    Class to implement framework related methods that are used in GPTQ
+    """
+
+    @abstractmethod
+    def get_gptq_trainer_obj(self):
+        """
+        Returns: GPTQTrainer object
+        """
+        raise NotImplemented(f'{self.__class__.__name__} have to implement the '
+                             f'framework\'s get_gptq_trainer method.')  # pragma: no cover
model_compression_toolkit/gptq/common/gptq_graph.py
CHANGED

@@ -14,8 +14,8 @@
 # ==============================================================================
 from typing import Tuple, List
 
-from model_compression_toolkit import FrameworkInfo
-from model_compression_toolkit.
+from model_compression_toolkit.core import FrameworkInfo
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 
model_compression_toolkit/gptq/common/gptq_training.py
CHANGED

@@ -17,12 +17,13 @@ from abc import ABC, abstractmethod
 import numpy as np
 from typing import Callable, List, Any
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
-from model_compression_toolkit.core.common import Graph,
+from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
+from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
 from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
 from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
+from model_compression_toolkit.logger import Logger
 
 
 class GPTQTrainer(ABC):
@@ -34,7 +35,7 @@ class GPTQTrainer(ABC):
                 graph_float: Graph,
                 graph_quant: Graph,
                 gptq_config: GradientPTQConfig,
-                fw_impl:
+                fw_impl: GPTQFrameworkImplemantation,
                 fw_info: FrameworkInfo):
        """
        Build two models from a graph: A teacher network (float model) and a student network (quantized model).
@@ -259,7 +260,7 @@ def gptq_training(graph_float: Graph,
                  graph_quant: Graph,
                  gptq_config: GradientPTQConfig,
                  representative_data_gen: Callable,
-                  fw_impl:
+                  fw_impl: GPTQFrameworkImplemantation,
                  fw_info: FrameworkInfo) -> Graph:
    """
    GPTQ training process using knowledge distillation with a teacher network (float model) and a student network (quantized model).
model_compression_toolkit/gptq/keras/gptq_keras_implementation.py
ADDED

@@ -0,0 +1,29 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Type
+
+from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
+from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
+from model_compression_toolkit.gptq.keras.gptq_training import KerasGPTQTrainer
+
+
+class GPTQKerasImplemantation(GPTQFrameworkImplemantation, KerasImplementation):
+
+    def get_gptq_trainer_obj(self) -> Type[KerasGPTQTrainer]:
+        """
+        Returns: Keras object of GPTQTrainer
+        """
+        return KerasGPTQTrainer
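Together, these two files let the shared GPTQ runner pick the framework trainer through the fw_impl object instead of importing it directly. A sketch of that dispatch, inferred from the GPTQTrainer.__init__ and update_graph signatures shown in these hunks (the train call is an assumption):

def run_gptq(graph_float, graph_quant, gptq_config, fw_impl, fw_info, repr_data_gen):
    # fw_impl is a GPTQFrameworkImplemantation subclass such as
    # GPTQKerasImplemantation; it resolves to the framework's trainer class.
    trainer_cls = fw_impl.get_gptq_trainer_obj()
    trainer = trainer_cls(graph_float, graph_quant, gptq_config, fw_impl, fw_info)
    trainer.train(repr_data_gen)  # assumed method name on the trainer
    return trainer.update_graph()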
model_compression_toolkit/gptq/keras/gptq_training.py
CHANGED

@@ -16,17 +16,18 @@ from typing import Callable, List, Tuple, Union
 
 import tensorflow as tf
 from keras import Model
+from packaging import version
 from tensorflow.keras.layers import Layer
 from tqdm import tqdm
 
 # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
-from packaging import version
-
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.gptq.keras.quantizer.quantization_builder import quantization_builder
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
+from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.activation_quantization_holder import ActivationQuantizationHolder
 
 if version.parse(tf.__version__) < version.parse("2.6"):
     from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
@@ -105,7 +106,7 @@ class KerasGPTQTrainer(GPTQTrainer):
             [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
 
         if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
-
+            Logger.error("Input scale mismatch between float and GPTQ networks")  # pragma: no cover
         else:
             self.input_scale = self.gptq_user_info.input_scale
 
@@ -113,8 +114,8 @@ class KerasGPTQTrainer(GPTQTrainer):
 
         self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
 
-    def
-
+    def _is_gptq_weights_trainable(self,
+                                   node: common.BaseNode) -> bool:
        """
        A function for deciding if a layer should be fine-tuned during GPTQ.
 
@@ -126,11 +127,13 @@ class KerasGPTQTrainer(GPTQTrainer):
        """
 
        if node.is_weights_quantization_enabled() and not self.fw_info.is_kernel_op(node.type):
-
+            Logger.error(f"GPTQ Error: Quantizing node {node.name} of type {node.type} "
                         f"without a kernel isn't supported")
        return node.is_weights_quantization_enabled()
 
-    def gptq_wrapper(self,
+    def gptq_wrapper(self,
+                     n: common.BaseNode,
+                     layer: Layer) -> Union[qi.KerasQuantizationWrapper, Layer]:
        """
        A function which takes a computational graph node and a keras layer and perform the quantization wrapping.
 
@@ -141,14 +144,37 @@ class KerasGPTQTrainer(GPTQTrainer):
        Returns: Wrapped layer if the layer should be wrap, otherwise returns the layer as is.
 
        """
-        if self.
-            weights_quantizers,
+        if self._is_gptq_weights_trainable(n):
+            weights_quantizers, _ = quantization_builder(n, self.gptq_config)  # TODO: split quantizers building into two functions: for weights and activations
            return qi.KerasQuantizationWrapper(layer,
-                                               weights_quantizers=weights_quantizers
-                                               activation_quantizers=activation_quantizers)
+                                               weights_quantizers=weights_quantizers)
        else:
            return layer
 
+    def get_activation_quantizer_holder(self, n: common.BaseNode) -> Union[None, Callable]:
+        """
+        Retrieve a ActivationQuantizationHolder layer to use for activation quantization for a node.
+        If the layer is not supposed to be wrapped with activation quantizers - return None.
+
+        Args:
+            n: Node to get ActivationQuantizationHolder to attach in its output.
+
+        Returns:
+            A ActivationQuantizationHolder layer for the node activation quantization.
+        """
+        _, activation_quantizers = quantization_builder(n, self.gptq_config)  # TODO: split quantizers building into two functions: for weights and activations
+
+        # Holder by definition uses a single quantizer for the activation quantization
+        # thus we make sure this is the only possible case (unless it's a node with no activation
+        # quantization, which in this case has an empty list).
+        if len(activation_quantizers) == 1:
+            return ActivationQuantizationHolder(activation_quantizers[0])
+
+        Logger.error(
+            f'ActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
+            f'were found for node {n}')
+
+
     def build_gptq_model(self) -> Tuple[Model, UserInformation]:
        """
        Build the GPTQ model with QuantizationWrappers
@@ -161,7 +187,8 @@ class KerasGPTQTrainer(GPTQTrainer):
                                                      append2output=self.compare_points,
                                                      fw_info=self.fw_info,
                                                      return_float_outputs=True,
-                                                      wrapper=self.gptq_wrapper
+                                                      wrapper=self.gptq_wrapper,
+                                                      get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder).build_model()
 
        return gptq_model, gptq_user_info
 
@@ -280,7 +307,7 @@ class KerasGPTQTrainer(GPTQTrainer):
                self.gptq_config.log_function(loss_value_step, grads[0], in_optimizer_with_param[0][-1],
                                              self.compare_points)
            self.loss_list.append(loss_value_step.numpy())
-
+            Logger.debug(f'last loss value: {self.loss_list[-1]}')
 
     def update_graph(self):
        """
@@ -297,7 +324,7 @@ class KerasGPTQTrainer(GPTQTrainer):
            if len(node) == 0 and isinstance(layer.layer, TensorFlowOpLayer):
                node = graph.find_node_by_name('_'.join(layer.layer.name.split('_')[3:]))
            if len(node) != 1:
-
+                Logger.error(f"Can't update GPTQ graph due to missing layer named: {layer.layer.name}")
            node = node[0]
            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
                                                                  fw_info=self.fw_info)
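The broader pattern in these hunks: weight quantizers stay inside KerasQuantizationWrapper, while activation quantization moves to a dedicated ActivationQuantizationHolder attached at a node's output, one quantizer per holder. A toy layer sketching that contract (not the package's implementation, which is the new 147-line activation_quantization_holder.py listed above):

import tensorflow as tf

class ToyActivationHolder(tf.keras.layers.Layer):
    """Holds exactly one activation quantizer and applies it to its input."""

    def __init__(self, quantizer, **kwargs):
        super().__init__(**kwargs)
        self.quantizer = quantizer  # single quantizer, per the holder invariant

    def call(self, inputs):
        # The holder's only job: quantize the activation flowing through it.
        return self.quantizer(inputs)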
model_compression_toolkit/gptq/keras/graph_info.py
CHANGED

@@ -20,6 +20,7 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from tensorflow.keras.models import Model
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
 from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 
@@ -50,6 +51,9 @@ def get_gptq_trainable_parameters(fxp_model: Model,
                                                              fw_info=DEFAULT_KERAS_INFO)
 
            # collect trainable weights per quantizer
+            if kernel_attribute not in layer.weights_quantizers:
+                Logger.error(f'{kernel_attribute} was not found in weight quantizers of layer {layer.layer}')
+
            quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
            quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
            trainable_weights.append(quantizer_trainable_weights)
@@ -16,21 +16,19 @@
 from typing import Callable, Tuple
 from packaging import version
 
-from model_compression_toolkit.
-from model_compression_toolkit.
-from model_compression_toolkit.core.common.constants import TENSORFLOW
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import
-
-from model_compression_toolkit import CoreConfig
+from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
 from model_compression_toolkit.gptq.runner import gptq_runner
 from model_compression_toolkit.core.exporter import export_model
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization
-from model_compression_toolkit.
+from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 
 LR_DEFAULT = 0.15
 LR_REST_DEFAULT = 1e-4
@@ -38,14 +36,14 @@ LR_BIAS_DEFAULT = 1e-4
 LR_QUANTIZATION_PARAM_DEFAULT = 1e-3
 GPTQ_MOMENTUM = 0.9
 
-if
+if FOUND_TF:
     import tensorflow as tf
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
-    from model_compression_toolkit.
+    from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
     from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from tensorflow.keras.models import Model
     from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss
-    from model_compression_toolkit.
+    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
     from model_compression_toolkit import get_target_platform_capabilities
 
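The rewritten imports key all TensorFlow-only definitions on the shared `FOUND_TF` constant from `model_compression_toolkit.constants`. A sketch of how such an optional-dependency guard typically works, assuming `FOUND_TF` is a boolean resolved once at import time (the fallback function here is illustrative):

```python
try:
    import tensorflow as tf  # noqa: F401
    FOUND_TF = True
except ImportError:
    FOUND_TF = False

if FOUND_TF:
    def get_keras_gptq_config(n_epochs: int) -> dict:
        # The real facade builds a GradientPTQConfigV2; a dict stands in here.
        return {"n_epochs": n_epochs}
else:
    def get_keras_gptq_config(n_epochs: int) -> dict:
        raise ImportError("TensorFlow is required to use the Keras GPTQ facade.")
```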
@@ -62,7 +60,8 @@ if common.constants.FOUND_TF:
                               optimizer: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=LR_DEFAULT),
                               optimizer_rest: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=LR_REST_DEFAULT),
                               loss: Callable = GPTQMultipleTensorsLoss(),
-                              log_function: Callable = None
+                              log_function: Callable = None,
+                              use_hessian_based_weights: bool = True) -> GradientPTQConfigV2:
         """
         Create a GradientPTQConfigV2 instance for Keras models.
 
@@ -72,6 +71,7 @@ if common.constants.FOUND_TF:
             optimizer_rest (OptimizerV2): Keras optimizer to use for fine-tuning of the bias variable.
             loss (Callable): loss to use during fine-tuning. should accept 4 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors, the 3rd is a list of quantized weights and the 4th is a list of float weights.
             log_function (Callable): Function to log information about the gptq process.
+            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
 
         returns:
             a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq.
@@ -94,9 +94,16 @@ if common.constants.FOUND_TF:
         The configuration can be passed to :func:`~model_compression_toolkit.keras_post_training_quantization` in order to quantize a keras model using gptq.
 
         """
-        bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
-
-
+        bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
+                                                 momentum=GPTQ_MOMENTUM)
+        return GradientPTQConfigV2(n_epochs,
+                                   optimizer,
+                                   optimizer_rest=optimizer_rest,
+                                   loss=loss,
+                                   log_function=log_function,
+                                   train_bias=True,
+                                   optimizer_bias=bias_optimizer,
+                                   use_hessian_based_weights=use_hessian_based_weights)
 
 
     def keras_gradient_post_training_quantization_experimental(in_model: Model,
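With the new `use_hessian_based_weights` flag defaulting to `True`, callers can opt out of Hessian-based loss weighting without touching the other defaults. A hedged usage sketch — the facade name `get_keras_gptq_config` is assumed from MCT's public API and is not visible in the hunk itself:

```python
import model_compression_toolkit as mct

# Assumed entry point; only `n_epochs` and `use_hessian_based_weights` are
# taken from the diff above, all other arguments keep their defaults.
gptq_config = mct.get_keras_gptq_config(n_epochs=50,
                                        use_hessian_based_weights=False)
```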
@@ -158,20 +165,20 @@ if common.constants.FOUND_TF:
 
         Create an MCT core config, containing the quantization configuration:
 
-        >>> config = mct.CoreConfig()
+        >>> config = mct.core.CoreConfig()
 
         If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model
         with different bitwidths for different layers.
         The candidates bitwidth for quantization should be defined in the target platform model:
 
-        >>> config = mct.CoreConfig(mixed_precision_config=mct.MixedPrecisionQuantizationConfigV2(num_of_images=1))
+        >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1))
 
         For mixed-precision set a target KPI object:
         Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
         that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
         while the bias will not):
 
-        >>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
+        >>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
 
         Create GPTQ config:
 
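Assembled, the updated docstring snippets correspond to a flow along these lines. A hedged end-to-end sketch: the keyword names in the final call (`gptq_config`, `target_kpi`, `core_config`), the top-level exposure of the two facades, and the representative-data-generator shape are assumptions, not verbatim from this diff:

```python
import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct

model = tf.keras.applications.MobileNetV2()

def representative_data_gen():
    # Return convention assumed; adapt to your real calibration dataset.
    yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]

config = mct.core.CoreConfig(
    mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1))
kpi = mct.core.KPI(model.count_params() * 0.75)  # ~0.75 of the 8-bit model size
gptq_config = mct.get_keras_gptq_config(n_epochs=50)  # facade name assumed

# Keyword names below are illustrative; check the facade's actual signature.
quantized_model, quantization_info = mct.keras_gradient_post_training_quantization_experimental(
    model,
    representative_data_gen,
    gptq_config=gptq_config,
    target_kpi=kpi,
    core_config=config)
```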
@@ -187,15 +194,15 @@ if common.constants.FOUND_TF:
 
         if core_config.mixed_precision_enable:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
-
+                Logger.error("Given quantization config to mixed-precision facade is not of type "
                              "MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
                              "API, or pass a valid mixed precision configuration.") # pragma: no cover
 
-
+            Logger.info("Using experimental mixed-precision quantization. "
                         "If you encounter an issue please file a bug.")
         tb_w = _init_tensorboard_writer(fw_info)
 
-        fw_impl =
+        fw_impl = GPTQKerasImplemantation()
 
         tg, bit_widths_config = core_runner(in_model=in_model,
                                             representative_data_gen=representative_data_gen,
@@ -15,3 +15,4 @@
 
 import model_compression_toolkit.gptq.keras.quantizer.ste_rounding.symmetric_ste
 import model_compression_toolkit.gptq.keras.quantizer.soft_rounding.symmetric_soft_quantizer
+import model_compression_toolkit.gptq.keras.quantizer.soft_rounding.uniform_soft_quantizer
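The added `uniform_soft_quantizer` import follows the same register-on-import convention as the two quantizer imports above it: the module is imported purely for its side effect of registering a quantizer class. A generic sketch of that convention (the registry and decorator below are illustrative, not MCT's actual mechanism):

```python
# Illustrative registry: importing a module that applies the decorator is
# enough to make its quantizer discoverable.
QUANTIZER_REGISTRY = {}

def register_quantizer(name: str):
    def decorator(cls):
        QUANTIZER_REGISTRY[name] = cls
        return cls
    return decorator

@register_quantizer("uniform_soft")
class UniformSoftQuantizer:
    """Placeholder quantizer; real ones implement quantize/train hooks."""

print("uniform_soft" in QUANTIZER_REGISTRY)  # True
```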
@@ -15,8 +15,8 @@
 from abc import abstractmethod
 from typing import Union, Dict, List
 
-from model_compression_toolkit.
-from model_compression_toolkit.
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.constants import FOUND_TF
 from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
 
 from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig, \