mct-nightly 1.11.0.20240320.400__py3-none-any.whl → 1.11.0.20240322.404__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240320.400.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/METADATA +17 -9
- {mct_nightly-1.11.0.20240320.400.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/RECORD +152 -152
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/constants.py +1 -1
- model_compression_toolkit/core/__init__.py +3 -3
- model_compression_toolkit/core/common/collectors/base_collector.py +2 -2
- model_compression_toolkit/core/common/data_loader.py +3 -3
- model_compression_toolkit/core/common/graph/base_graph.py +10 -13
- model_compression_toolkit/core/common/graph/base_node.py +3 -3
- model_compression_toolkit/core/common/graph/edge.py +2 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +2 -4
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +2 -3
- model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +1 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +24 -23
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +110 -112
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +114 -0
- model_compression_toolkit/core/common/mixed_precision/{kpi_tools/kpi_data.py → resource_utilization_tools/resource_utilization_data.py} +19 -19
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +105 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +26 -0
- model_compression_toolkit/core/common/mixed_precision/{kpi_tools/kpi_methods.py → resource_utilization_tools/ru_methods.py} +61 -61
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +75 -71
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -4
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +34 -34
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/network_editors/actions.py +3 -3
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +12 -12
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +2 -2
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +2 -2
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +7 -7
- model_compression_toolkit/core/common/pruning/prune_graph.py +2 -3
- model_compression_toolkit/core/common/pruning/pruner.py +7 -7
- model_compression_toolkit/core/common/pruning/pruning_config.py +1 -1
- model_compression_toolkit/core/common/pruning/pruning_info.py +2 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +7 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +4 -6
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -4
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +8 -6
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +4 -7
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +3 -3
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +3 -3
- model_compression_toolkit/core/common/user_info.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +3 -3
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +2 -2
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +4 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +2 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py +3 -3
- model_compression_toolkit/core/keras/hessian/trace_hessian_calculator_keras.py +1 -2
- model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py +5 -6
- model_compression_toolkit/core/keras/keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +2 -4
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +7 -7
- model_compression_toolkit/core/keras/reader/common.py +2 -2
- model_compression_toolkit/core/keras/reader/node_builder.py +1 -1
- model_compression_toolkit/core/keras/{kpi_data_facade.py → resource_utilization_data_facade.py} +25 -24
- model_compression_toolkit/core/keras/tf_tensor_numpy.py +4 -2
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +3 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +6 -11
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/hessian/activation_trace_hessian_calculator_pytorch.py +3 -7
- model_compression_toolkit/core/pytorch/hessian/trace_hessian_calculator_pytorch.py +1 -2
- model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py +2 -2
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +3 -3
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +5 -7
- model_compression_toolkit/core/pytorch/reader/reader.py +2 -2
- model_compression_toolkit/core/pytorch/{kpi_data_facade.py → resource_utilization_data_facade.py} +24 -22
- model_compression_toolkit/core/pytorch/utils.py +3 -2
- model_compression_toolkit/core/runner.py +43 -42
- model_compression_toolkit/data_generation/common/data_generation.py +18 -18
- model_compression_toolkit/data_generation/common/model_info_exctractors.py +1 -1
- model_compression_toolkit/data_generation/keras/keras_data_generation.py +7 -10
- model_compression_toolkit/data_generation/keras/model_info_exctractors.py +2 -1
- model_compression_toolkit/data_generation/keras/optimization_functions/image_initilization.py +2 -1
- model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +2 -4
- model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +2 -1
- model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +8 -11
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -3
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +8 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +7 -8
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +19 -12
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +10 -11
- model_compression_toolkit/gptq/common/gptq_graph.py +3 -3
- model_compression_toolkit/gptq/common/gptq_training.py +14 -12
- model_compression_toolkit/gptq/keras/gptq_training.py +10 -8
- model_compression_toolkit/gptq/keras/graph_info.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +15 -17
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +1 -2
- model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -8
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +11 -13
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +1 -2
- model_compression_toolkit/logger.py +1 -13
- model_compression_toolkit/pruning/keras/pruning_facade.py +11 -12
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +11 -12
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -14
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -8
- model_compression_toolkit/qat/keras/quantization_facade.py +20 -22
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -3
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +12 -14
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -3
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +1 -1
- model_compression_toolkit/target_platform_capabilities/immutable.py +4 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +4 -8
- model_compression_toolkit/target_platform_capabilities/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py +43 -8
- model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +13 -18
- model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +2 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py +2 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/current_tpc.py +2 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +5 -5
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +1 -2
- model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py +13 -13
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +14 -7
- model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +5 -5
- model_compression_toolkit/trainable_infrastructure/keras/base_keras_quantizer.py +2 -3
- model_compression_toolkit/trainable_infrastructure/keras/load_model.py +4 -5
- model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py +3 -4
- model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi.py +0 -112
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_aggregation_methods.py +0 -105
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_functions_mapping.py +0 -26
- {mct_nightly-1.11.0.20240320.400.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240320.400.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240320.400.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/common/mixed_precision/{kpi_tools → resource_utilization_tools}/__init__.py +0 -0
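The bulk of this release is a rename of the mixed-precision "KPI" abstraction to "ResourceUtilization": `kpi_tools` becomes `resource_utilization_tools`, the per-framework `kpi_data_facade.py` modules become `resource_utilization_data_facade.py`, and the facade functions gain a `target_resource_utilization` parameter (see the quantization_facade hunks below). The other recurring theme is error handling: calls to the removed `Logger.exception` and to the no-longer-raising `Logger.error` are consolidated into `Logger.critical`. A minimal migration sketch follows; the old `KPI` class name is inferred from the removed `kpi_tools/kpi.py` and is an assumption, since removed lines are truncated in this viewer:

```python
# Before (removed in 1.11.0.20240322.404; class name assumed):
# from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI

# After (import path taken verbatim from the hunks below):
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization

# Weights memory is given in bytes per the facade docstring; the keyword name is assumed.
ru = ResourceUtilization(weights_memory=3_000_000)
```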
@@ -46,29 +46,29 @@ if FOUND_TF:
 
     valid_layer = isinstance(layer, Layer)
     if not valid_layer:
-        Logger.
+        Logger.critical(
             f'Exportable layer must be a Keras layer, but layer {layer.name} is of type '
             f'{type(layer)}') # pragma: no cover
 
     if isinstance(layer, KerasQuantizationWrapper):
         valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
         if not valid_weights_quantizers:
-            Logger.
+            Logger.critical(
                 f'KerasQuantizationWrapper must have a weights_quantizers but has a '
                 f'{type(layer.weights_quantizers)} object') # pragma: no cover
 
         if len(layer.weights_quantizers) == 0:
-            Logger.
+            Logger.critical(f'KerasQuantizationWrapper must have at least one weight quantizer, but found {len(layer.weights_quantizers)} quantizers. If layer is not quantized it should be a Keras layer.')
 
         for _, weights_quantizer in layer.weights_quantizers.items():
             if not isinstance(weights_quantizer, BaseInferableQuantizer):
-                Logger.
+                Logger.critical(
                     f'weights_quantizer must be a BaseInferableQuantizer object but has a '
                     f'{type(weights_quantizer)} object') # pragma: no cover
 
     if isinstance(layer, KerasActivationQuantizationHolder):
         if not isinstance(layer.activation_holder_quantizer, BaseInferableQuantizer):
-            Logger.
+            Logger.critical(
                 f'activation quantizer in KerasActivationQuantizationHolder'
                 f' must be a BaseInferableQuantizer object but has a '
                 f'{type(layer.activation_holder_quantizer)} object') # pragma: no cover

@@ -76,6 +76,5 @@ if FOUND_TF:
     return True
 else:
     def is_keras_layer_exportable(*args, **kwargs): # pragma: no cover
-        Logger.
-
-        'Could not find Tensorflow package.')
+        Logger.critical("Tensorflow must be installed to use is_keras_layer_exportable. "
+                        "The 'tensorflow' package is missing.") # pragma: no cover
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py CHANGED

@@ -60,7 +60,7 @@ if FOUND_TORCH:
     # quantization, which in this case has an empty list).
     if len(activation_quantizers) == 1:
         return PytorchActivationQuantizationHolder(activation_quantizers[0])
-    Logger.
+    Logger.critical(
         f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
         f'were found for node {node}')

@@ -74,16 +74,23 @@ if FOUND_TORCH:
     Returns:
         Fully quantized PyTorch model.
     """
-
-
-
-
-
-
-
+    exportable_model, user_info = PyTorchModelBuilder(graph=graph,
+                                                      wrapper=lambda n, m:
+                                                      fully_quantized_wrapper(n, m,
+                                                                              fw_impl=C.pytorch.pytorch_implementation.PytorchImplementation()),
+                                                      get_activation_quantizer_holder_fn=lambda n:
+                                                      get_activation_quantizer_holder(n,
+                                                                                      fw_impl=C.pytorch.pytorch_implementation.PytorchImplementation())).build_model()
+
+    Logger.info("Please run your accuracy evaluation on the exported quantized model to verify it's accuracy.\n"
+                "Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:\n"
+                "FAQ: https://github.com/sony/model_optimization/tree/main/FAQ.md"
+                "Quantization Troubleshooting: https://github.com/sony/model_optimization/tree/main/quantization_troubleshooting.md")
+
+    return exportable_model, user_info
+
 
 else:
-    def get_exportable_pytorch_model(*args, **kwargs):
-        Logger.
-
-        'Could not find PyTorch package.')
+    def get_exportable_pytorch_model(*args, **kwargs):
+        Logger.critical("PyTorch must be installed to use 'get_exportable_pytorch_model'. "
+                        "The 'torch' package is missing.") # pragma: no cover
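The new builder call above threads the same `PytorchImplementation()` through two callbacks. A condensed restatement for readability (a sketch, not the shipped code; it only hoists the repeated constructor into a local):

```python
# Condensed sketch of the new get_exportable_pytorch_model body shown above.
fw_impl = C.pytorch.pytorch_implementation.PytorchImplementation()
exportable_model, user_info = PyTorchModelBuilder(
    graph=graph,
    # callback that wraps each quantized node's module (via fully_quantized_wrapper)
    wrapper=lambda n, m: fully_quantized_wrapper(n, m, fw_impl=fw_impl),
    # callback that builds an activation-quantization holder per node
    get_activation_quantizer_holder_fn=lambda n: get_activation_quantizer_holder(n, fw_impl=fw_impl),
).build_model()
```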
@@ -44,7 +44,7 @@ def get_weights_inferable_quantizer_kwargs(node_qc: NodeWeightsQuantizationConfi
     """
 
     if not isinstance(node_qc, NodeWeightsQuantizationConfig):
-        Logger.
+        Logger.critical(
             f"Non-compatible node quantization config was given for quantization target Weights.") # pragma: no cover
 
     if attr_name is None:

@@ -97,7 +97,7 @@ def get_activation_inferable_quantizer_kwargs(node_qc: NodeActivationQuantizatio
     """
 
     if not isinstance(node_qc, NodeActivationQuantizationConfig):
-        Logger.
+        Logger.critical(
             f"Non-compatible node quantization config was given for quantization target Activation.") # pragma: no cover
 
     quantization_method = node_qc.activation_quantization_method
@@ -35,36 +35,35 @@ if FOUND_TORCH:
     Check whether a PyTorch layer is a valid exportable layer or not.
     """
     if not isinstance(layer, nn.Module):
-        Logger.
+        Logger.critical(f'Exportable layer must be a nn.Module layer, but layer {layer.name} is of type {type(layer)}.') # pragma: no cover
 
     if isinstance(layer, PytorchQuantizationWrapper):
         valid_weights_quantizers = isinstance(layer.weights_quantizers, dict)
         if not valid_weights_quantizers:
-            Logger.
+            Logger.critical(
                 f'PytorchQuantizationWrapper must have a weights_quantizers but has a '
-                f'{type(layer.weights_quantizers)} object') # pragma: no cover
+                f'{type(layer.weights_quantizers)} object.') # pragma: no cover
 
         if len(layer.weights_quantizers) == 0:
-            Logger.
+            Logger.critical(f'PytorchQuantizationWrapper must have at least one weight quantizer, but found {len(layer.weights_quantizers)} quantizers.'
                             f'If layer is not quantized it should be a Keras layer.')
 
         for _, weights_quantizer in layer.weights_quantizers.items():
             if not isinstance(weights_quantizer, BasePyTorchInferableQuantizer):
-                Logger.
+                Logger.critical(
                     f'weights_quantizer must be a BasePyTorchInferableQuantizer object but has a '
-                    f'{type(weights_quantizer)} object') # pragma: no cover
+                    f'{type(weights_quantizer)} object.') # pragma: no cover
 
     elif isinstance(layer, PytorchActivationQuantizationHolder):
         if not isinstance(layer.activation_holder_quantizer, BasePyTorchInferableQuantizer):
-            Logger.
+            Logger.critical(
                 f'activation quantizer in PytorchActivationQuantizationHolder'
                 f' must be a BasePyTorchInferableQuantizer object but has a '
-                f'{type(layer.activation_holder_quantizer)} object') # pragma: no cover
+                f'{type(layer.activation_holder_quantizer)} object.') # pragma: no cover
 
     return True
 
 else:
     def is_pytorch_layer_exportable(*args, **kwargs): # pragma: no cover
-        Logger.
-
-        'Could not find PyTorch package.')
+        Logger.critical("PyTorch must be installed to use 'is_pytorch_layer_exportable'. "
+                        "The 'torch' package is missing.") # pragma: no cover
@@ -62,7 +62,7 @@ def get_kernel_attribute_name_for_gptq(layer_type: type, fw_info: FrameworkInfo)
     """
     kernel_attribute = fw_info.get_kernel_op_attributes(layer_type)
     if len(kernel_attribute) != 1:
-        Logger.
-        f"In GPTQ training only the kernel weights attribute should be trained
-        f"attributes is {len(kernel_attribute)}.")
+        Logger.critical( # pragma: no cover
+            f"In GPTQ training, only the kernel weights attribute should be trained. "
+            f"However, the number of kernel attributes is {len(kernel_attribute)}.")
     return kernel_attribute[0]
@@ -75,8 +75,8 @@ class GPTQTrainer(ABC):
         self.fxp_model, self.gptq_user_info = self.build_gptq_model()
         if self.gptq_config.use_hessian_based_weights:
             if not isinstance(hessian_info_service, HessianInfoService):
-                Logger.
-
+                Logger.critical(f"When using Hessian-based approximations for sensitivity evaluation, "
+                                f"an 'HessianInfoService' object must be provided, but received: {hessian_info_service}.")
             self.hessian_service = hessian_info_service
 
     def get_optimizer_with_param(self,

@@ -106,8 +106,8 @@ class GPTQTrainer(ABC):
             else:
                 w2train_res.extend(flattened_bias_weights)
                 if self.gptq_config.optimizer_rest is None:
-                    Logger.
-
+                    Logger.critical("To enable bias micro-training, an additional optimizer is required. "
+                                    "Please define the 'optimizer_rest' parameter.")# pragma: no cover
         if quant_params_learning:
             if self.gptq_config.optimizer_quantization_parameter is not None: # Ability to override optimizer
                 optimizer_with_param.append((self.gptq_config.optimizer_quantization_parameter,

@@ -115,14 +115,16 @@ class GPTQTrainer(ABC):
             else:
                 w2train_res.extend(trainable_quantization_parameters)
                 if self.gptq_config.optimizer_rest is None:
-                    Logger.
-                    "To enable quantization parameters micro
+                    Logger.critical(
+                        "To enable quantization parameters micro-training, an additional optimizer is required. "
+                        "Please define the 'optimizer_rest' parameter.") # pragma: no cover
         if len(w2train_res) > 0:
             # Either bias or quantization parameters are trainable but did not provide a specific optimizer,
             # so we should use optimizer_rest to train them
             if self.gptq_config.optimizer_rest is None:
-                Logger.
-                "To enable
+                Logger.critical(
+                    "To enable bais or quantization parameters micro-training, an additional optimizer is required. "
+                    "Please define the 'optimizer_rest' parameter.") # pragma: no cover
             optimizer_with_param.append((self.gptq_config.optimizer_rest, w2train_res))
 
         return optimizer_with_param

@@ -236,11 +238,11 @@ class GPTQTrainer(ABC):
             trace_approx: Trace approximation to validate.
         """
         if not isinstance(trace_approx, list):
-            Logger.
+            Logger.critical(f"Trace approximation was expected to be a list but is of type: {type(trace_approx)}.")
         if len(trace_approx) != 1:
-            Logger.
-
-
+            Logger.critical(f"Trace approximation was expected to have a length of 1 "
+                            f"(for computations with granularity set to 'HessianInfoGranularity.PER_TENSOR') "
+                            f"but has a length of {len(trace_approx)}."
                             )
 
     @staticmethod
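All three new messages in the hunk above enforce the same contract: bias or quantization-parameter micro-training needs a fallback optimizer in `optimizer_rest`. A hedged sketch of satisfying it through the config object; only the `optimizer_rest` name is confirmed by these hunks, the remaining `GradientPTQConfig` arguments are assumptions:

```python
import tensorflow as tf
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig

# Provide both a main optimizer and 'optimizer_rest' so that bias and
# quantization-parameter micro-training have an optimizer to fall back on.
gptq_config = GradientPTQConfig(
    n_epochs=5,                                                   # assumed argument
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),       # assumed argument
    optimizer_rest=tf.keras.optimizers.Adam(learning_rate=1e-4),  # name confirmed above
)
```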
@@ -96,9 +96,9 @@ class KerasGPTQTrainer(GPTQTrainer):
 
         if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
                 self.fxp_weights_list)):
-
-
-
+            Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, "
+                            "and the number of float and quantized weights for loss calculation. "
+                            "Ensure all these elements align to proceed with GPTQ training.")
 
         flattened_trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights]
         flattened_bias_weights = [w for layer_weights in bias_weights for w in layer_weights]

@@ -110,7 +110,8 @@ class KerasGPTQTrainer(GPTQTrainer):
             [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
 
         if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
-            Logger.
+            Logger.critical("Input scale mismatch detected between the float model and the GPTQ model. "
+                            "Confirm that the input scales for both models are correctly configured and aligned.") # pragma: no cover
         else:
             self.input_scale = self.gptq_user_info.input_scale
 

@@ -177,9 +178,9 @@ class KerasGPTQTrainer(GPTQTrainer):
         if len(activation_quantizers) == 1:
             return KerasActivationQuantizationHolder(activation_quantizers[0])
 
-        Logger.
-
-
+        Logger.critical(f"'KerasActivationQuantizationHolder' is designed to support a single quantizer, "
+                        f"but {len(activation_quantizers)} quantizers were found for node '{n}'. "
+                        f"Ensure only one quantizer is configured for each node's activation.")
 
 
     def build_gptq_model(self) -> Tuple[Model, UserInformation]:

@@ -331,7 +332,8 @@ class KerasGPTQTrainer(GPTQTrainer):
             if len(node) == 0 and isinstance(layer.layer, TensorFlowOpLayer):
                 node = graph.find_node_by_name('_'.join(layer.layer.name.split('_')[3:]))
             if len(node) != 1:
-                Logger.
+                Logger.critical(f"Unable to update the GPTQ graph because the layer named '{layer.layer.name}' could not be found. "
+                                f"Verify that the layer names in the GPTQ model match those in the graph.")
             node = node[0]
             kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
                                                                   fw_info=self.fw_info)
@@ -52,7 +52,7 @@ def get_gptq_trainable_parameters(fxp_model: Model,
 
             # collect trainable weights per quantizer
             if kernel_attribute not in layer.weights_quantizers:
-                Logger.
+                Logger.critical(f"'{kernel_attribute}' was not found in the weight quantizers of layer '{layer.layer}'.")
 
             quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
             quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
@@ -22,7 +22,7 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
-from model_compression_toolkit.core.common.mixed_precision.
+from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core import CoreConfig

@@ -116,7 +116,7 @@ if FOUND_TF:
     def keras_gradient_post_training_quantization(in_model: Model, representative_data_gen: Callable,
                                                   gptq_config: GradientPTQConfig,
                                                   gptq_representative_data_gen: Callable = None,
-
+                                                  target_resource_utilization: ResourceUtilization = None,
                                                   core_config: CoreConfig = CoreConfig(),
                                                   target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
         """

@@ -129,7 +129,7 @@ if FOUND_TF:
         statistics. Then, if given a mixed precision config in the core_config, using an ILP solver we find
         a mixed-precision configuration, and set a bit-width for each layer. The model is then quantized
         (both coefficients and activations by default).
-        In order to limit the maximal model's size, a target
+        In order to limit the maximal model's size, a target resource utilization need to be passed after weights_memory
         is set (in bytes).
         Then, the quantized weights are optimized using gradient based post
         training quantization by comparing points between the float and quantized models, and minimizing the observed

@@ -140,7 +140,7 @@ if FOUND_TF:
             representative_data_gen (Callable): Dataset used for calibration.
             gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
             gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
-
+            target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired.
             core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
 

@@ -174,12 +174,12 @@ if FOUND_TF:
 
         >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))
 
-        For mixed-precision set a target
-        Create a
+        For mixed-precision set a target resource utilization object:
+        Create a resource utilization object to limit our returned model's size. Note that this value affects only coefficients
         that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
         while the bias will not):
 
-        >>>
+        >>> ru = mct.core.ResourceUtilization(model.count_params() * 0.75)  # About 0.75 of the model size when quantized with 8 bits.
 
         Create GPTQ config:
 

@@ -187,7 +187,7 @@ if FOUND_TF:
 
         Pass the model with the representative dataset generator to get a quantized model:
 
-        >>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(model, repr_datagen, gptq_config,
+        >>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(model, repr_datagen, gptq_config, target_resource_utilization=ru, core_config=config)
 
         """
         KerasModelValidation(model=in_model,

@@ -195,9 +195,9 @@ if FOUND_TF:
 
         if core_config.mixed_precision_enable:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
-                Logger.
-
-
+                Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
+                                "Ensure usage of the correct API for keras_post_training_quantization "
+                                "or provide a valid mixed-precision configuration.") # pragma: no cover
 
         tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)
 

@@ -232,12 +232,10 @@ else:
     # If tensorflow is not installed,
     # we raise an exception when trying to use these functions.
     def get_keras_gptq_config(*args, **kwargs):
-        Logger.critical(
-        '
-        'Could not find Tensorflow package.') # pragma: no cover
+        Logger.critical("Tensorflow must be installed to use get_keras_gptq_config. "
+                        "The 'tensorflow' package is missing.") # pragma: no cover
 
 
     def keras_gradient_post_training_quantization(*args, **kwargs):
-        Logger.critical(
-        '
-        'Could not find Tensorflow package.') # pragma: no cover
+        Logger.critical("Tensorflow must be installed to use keras_gradient_post_training_quantization. "
+                        "The 'tensorflow' package is missing.") # pragma: no cover
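Stitching the docstring fragments above into one runnable example (MobileNet and the random data generator are the placeholders the docstring itself uses, and `get_keras_gptq_config(n_epochs=...)` is assumed to keep its signature from earlier releases):

```python
import numpy as np
from tensorflow.keras.applications.mobilenet import MobileNet
import model_compression_toolkit as mct

model = MobileNet()

def repr_datagen():
    # Representative dataset generator used for calibration.
    yield [np.random.random((1, 224, 224, 3))]

config = mct.core.CoreConfig(
    mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))
# About 0.75 of the model size when quantized with 8 bits (per the docstring above).
ru = mct.core.ResourceUtilization(model.count_params() * 0.75)
gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=5)
quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(
    model, repr_datagen, gptq_config,
    target_resource_utilization=ru, core_config=config)
```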
@@ -61,8 +61,8 @@ if FOUND_TF:
             weights = {}
             for weight, quantizer_vars, quantizer in layer.get_weights_vars():
                 if not isinstance(quantizer, BaseTrainableQuantizer):
-                    Logger.
-
+                    Logger.critical(f"Expecting a GPTQ trainable quantizer for layer '{layer.name}', but received {type(quantizer)}. "
+                                    f"Ensure a trainable quantizer is used.") # pragma: no cover
                 weights.update({weight: quantizer(training=False, inputs=quantizer_vars)})
 
             quant_config = {WEIGHTS_QUANTIZATION_PARAMS: self.get_quant_config()}

@@ -105,6 +105,5 @@ if FOUND_TF:
 else:
     class BaseKerasGPTQTrainableQuantizer: # pragma: no cover
         def __init__(self, *args, **kwargs):
-            Logger.critical(
-            '
-            'Could not find Tensorflow package.') # pragma: no cover
+            Logger.critical("Tensorflow must be installed to use BaseKerasGPTQTrainableQuantizer. "
+                            "The 'tensorflow' package is missing.") # pragma: no cover
@@ -66,8 +66,7 @@ def quantization_builder(n: common.BaseNode,
     activation_quantizers = []
     if n.is_activation_quantization_enabled():
         if n.final_activation_quantization_cfg is None:
-            Logger.critical(f
-            # pragma: no cover
+            Logger.critical(f"Cannot set quantizer for a node without a final activation quantization configuration.") # pragma: no cover
 
         quant_method = n.final_activation_quantization_cfg.activation_quantization_method
 
@@ -76,7 +76,8 @@ class PytorchGPTQTrainer(GPTQTrainer):
         self.loss_list = []
         self.input_scale = 1
         if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
-            Logger.
+            Logger.critical("Input scale mismatch between float and GPTQ networks. "
+                            "Ensure both networks have matching input scales.") # pragma: no cover
         else:
             self.input_scale = self.gptq_user_info.input_scale
 

@@ -87,9 +88,9 @@ class PytorchGPTQTrainer(GPTQTrainer):
         self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
         if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
                 self.fxp_weights_list)):
-            Logger.
-
-
+            Logger.critical("GPTQ: Number of comparison points, layers with trainable weights, "
+                            "and float vs. quantized weights for loss calculation do not match. "
+                            "Verify consistency across these parameters for successful GPTQ training.")
 
         self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
                                                                   trainable_bias,

@@ -156,9 +157,9 @@ class PytorchGPTQTrainer(GPTQTrainer):
         # quantization, which in this case has an empty list).
         if len(activation_quantizers) == 1:
             return PytorchActivationQuantizationHolder(activation_quantizers[0])
-        Logger.
-
-
+        Logger.critical(f"'PytorchActivationQuantizationHolder' requires exactly one quantizer, "
+                        f"but {len(activation_quantizers)} were found for node {n.name}. "
+                        f"Ensure the node is configured with a single activation quantizer.")
 
     def build_gptq_model(self):
         """

@@ -278,7 +279,8 @@ class PytorchGPTQTrainer(GPTQTrainer):
             if isinstance(layer, PytorchQuantizationWrapper):
                 node = self.graph_quant.find_node_by_name(name)
                 if len(node) != 1:
-                    Logger.
+                    Logger.critical(f"Cannot update GPTQ graph: Layer with name '{name}' is missing or not unique. "
+                                    f"Ensure each layer has a unique name and exists within the graph for updates.")
                 node = node[0]
                 kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
                                                                       fw_info=self.fw_info)
@@ -48,7 +48,7 @@ def get_gptq_trainable_parameters(fxp_model: nn.Module,
 
             # collect trainable weights per quantizer
             if kernel_attribute not in layer.weights_quantizers:
-                Logger.
+                Logger.critical(f"'{kernel_attribute}' was not found in the weight quantizers of layer '{layer.layer}'.")
             quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
             quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
             trainable_aux_weights.extend(quantizer_trainable_weights)
@@ -21,7 +21,7 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import PYTORCH
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
-from model_compression_toolkit.core.common.mixed_precision.
+from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
 from model_compression_toolkit.gptq.runner import gptq_runner

@@ -94,7 +94,7 @@ if FOUND_TORCH:
 
     def pytorch_gradient_post_training_quantization(model: Module,
                                                     representative_data_gen: Callable,
-
+                                                    target_resource_utilization: ResourceUtilization = None,
                                                     core_config: CoreConfig = CoreConfig(),
                                                     gptq_config: GradientPTQConfig = None,
                                                     gptq_representative_data_gen: Callable = None,

@@ -118,7 +118,7 @@ if FOUND_TORCH:
         Args:
             model (Module): Pytorch model to quantize.
             representative_data_gen (Callable): Dataset used for calibration.
-
+            target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired.
             core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
             gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
             gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen

@@ -159,9 +159,9 @@ if FOUND_TORCH:
 
         if core_config.mixed_precision_enable:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
-                Logger.
-
-
+                Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
+                                "Ensure usage of the correct API for 'keras_post_training_quantization' "
+                                "or provide a valid mixed-precision configuration.") # pragma: no cover
 
         tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
 

@@ -176,7 +176,7 @@ if FOUND_TORCH:
                                                                          fw_info=DEFAULT_PYTORCH_INFO,
                                                                          fw_impl=fw_impl,
                                                                          tpc=target_platform_capabilities,
-
+                                                                         target_resource_utilization=target_resource_utilization,
                                                                          tb_w=tb_w)
 
         # ---------------------- #

@@ -202,12 +202,10 @@ else:
     # If torch is not installed,
     # we raise an exception when trying to use these functions.
    def get_pytorch_gptq_config(*args, **kwargs):
-        Logger.critical(
-        '
-        'Could not find torch package.') # pragma: no cover
+        Logger.critical("PyTorch must be installed to use 'get_pytorch_gptq_config'. "
+                        "The 'torch' package is missing.") # pragma: no cover
 
 
    def pytorch_gradient_post_training_quantization(*args, **kwargs):
-        Logger.critical(
-        '
-        'Could not find the torch package.') # pragma: no cover
+        Logger.critical("PyTorch must be installed to use 'pytorch_gradient_post_training_quantization'. "
+                        "The 'torch' package is missing.") # pragma: no cover
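The PyTorch facade mirrors the Keras one. A sketch assembled from the signature above; the torchvision model, the random generator, and the `n_epochs` argument of `get_pytorch_gptq_config` are placeholders and assumptions, not taken from this diff:

```python
import numpy as np
from torchvision.models.mobilenet import mobilenet_v2
import model_compression_toolkit as mct

model = mobilenet_v2()

def repr_datagen():
    # Representative dataset generator used for calibration.
    yield [np.random.random((1, 3, 224, 224))]

gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5)  # argument name assumed
quantized_model, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization(
    model, repr_datagen,
    target_resource_utilization=None,  # or a mct.core.ResourceUtilization to bound mixed precision
    gptq_config=gptq_config)
```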
@@ -63,7 +63,7 @@ if FOUND_TORCH:
             weights = {}
             for weight, quantizer_vars, quantizer in layer.get_weights_vars():
                 if not isinstance(quantizer, BaseTrainableQuantizer):
-                    Logger.
+                    Logger.critical(f"Expecting a GPTQ trainable quantizer, " # pragma: no cover
                                     f"but got {type(quantizer)} which is not callable.")
                 weights.update({weight: quantizer(training=False, inputs=quantizer_vars)})
 

@@ -87,6 +87,5 @@ if FOUND_TORCH:
 else:
     class BasePytorchGPTQTrainableQuantizer: # pragma: no cover
         def __init__(self, *args, **kwargs):
-            Logger.critical(
-            '
-            'Could not find torch package.') # pragma: no cover
+            Logger.critical("PyTorch must be installed to use 'BasePytorchGPTQTrainableQuantizer'. "
+                            "The 'torch' package is missing.") # pragma: no cover
@@ -65,8 +65,7 @@ def quantization_builder(n: common.BaseNode,
     activation_quantizers = []
     if n.is_activation_quantization_enabled():
         if n.final_activation_quantization_cfg is None:
-            Logger.critical(f
-            # pragma: no cover
+            Logger.critical(f"Cannot set quantizer for a node without a final activation quantization configuration.") # pragma: no cover
 
         quant_method = n.final_activation_quantization_cfg.activation_quantization_method
 
@@ -19,7 +19,7 @@ import os
 from datetime import datetime
 from pathlib import Path
 
-LOGGER_NAME = '
+LOGGER_NAME = 'Model Compression Toolkit'
 
 
 class Logger:

@@ -116,17 +116,6 @@ class Logger:
         Logger.get_logger().critical(msg)
         raise Exception(msg)
 
-    @staticmethod
-    def exception(msg: str):
-        """
-        Log a message at 'exception' severity and raise an exception.
-        Args:
-            msg: Message to log.
-
-        """
-        Logger.get_logger().exception(msg)
-        raise Exception(msg)
-
     @staticmethod
     def debug(msg: str):
         """

@@ -172,7 +161,6 @@ class Logger:
 
         """
         Logger.get_logger().error(msg)
-        raise Exception(msg)
 
 
     def set_log_folder(folder: str, level: int = logging.INFO):
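Net effect of the logger changes above: `Logger.exception` is removed, `Logger.error` now only logs (its trailing `raise` is gone), and `Logger.critical` remains the single log-and-raise path. A small sketch of the resulting contract:

```python
from model_compression_toolkit.logger import Logger

Logger.error("recoverable problem")           # logs at ERROR severity; execution continues
try:
    Logger.critical("unrecoverable problem")  # logs at CRITICAL severity, then raises
except Exception as exc:
    print(f"critical raised: {exc}")
```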
|