mct-nightly 1.11.0.20240321.357__py3-none-any.whl → 1.11.0.20240322.404__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/METADATA +17 -9
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/RECORD +152 -152
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/constants.py +1 -1
- model_compression_toolkit/core/__init__.py +3 -3
- model_compression_toolkit/core/common/collectors/base_collector.py +2 -2
- model_compression_toolkit/core/common/data_loader.py +3 -3
- model_compression_toolkit/core/common/graph/base_graph.py +10 -13
- model_compression_toolkit/core/common/graph/base_node.py +3 -3
- model_compression_toolkit/core/common/graph/edge.py +2 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +2 -4
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +2 -3
- model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +1 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +24 -23
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +110 -112
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +114 -0
- model_compression_toolkit/core/common/mixed_precision/{kpi_tools/kpi_data.py → resource_utilization_tools/resource_utilization_data.py} +19 -19
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +105 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +26 -0
- model_compression_toolkit/core/common/mixed_precision/{kpi_tools/kpi_methods.py → resource_utilization_tools/ru_methods.py} +61 -61
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +75 -71
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -4
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +34 -34
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/network_editors/actions.py +3 -3
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +12 -12
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +2 -2
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +2 -2
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +7 -7
- model_compression_toolkit/core/common/pruning/prune_graph.py +2 -3
- model_compression_toolkit/core/common/pruning/pruner.py +7 -7
- model_compression_toolkit/core/common/pruning/pruning_config.py +1 -1
- model_compression_toolkit/core/common/pruning/pruning_info.py +2 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +7 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +4 -6
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -4
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +8 -6
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +4 -7
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +3 -3
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +3 -3
- model_compression_toolkit/core/common/user_info.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +3 -3
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +2 -2
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +4 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +2 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py +3 -3
- model_compression_toolkit/core/keras/hessian/trace_hessian_calculator_keras.py +1 -2
- model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py +5 -6
- model_compression_toolkit/core/keras/keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +2 -4
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +7 -7
- model_compression_toolkit/core/keras/reader/common.py +2 -2
- model_compression_toolkit/core/keras/reader/node_builder.py +1 -1
- model_compression_toolkit/core/keras/{kpi_data_facade.py → resource_utilization_data_facade.py} +25 -24
- model_compression_toolkit/core/keras/tf_tensor_numpy.py +4 -2
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +3 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +6 -11
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/hessian/activation_trace_hessian_calculator_pytorch.py +3 -7
- model_compression_toolkit/core/pytorch/hessian/trace_hessian_calculator_pytorch.py +1 -2
- model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py +2 -2
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +3 -3
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +5 -7
- model_compression_toolkit/core/pytorch/reader/reader.py +2 -2
- model_compression_toolkit/core/pytorch/{kpi_data_facade.py → resource_utilization_data_facade.py} +24 -22
- model_compression_toolkit/core/pytorch/utils.py +3 -2
- model_compression_toolkit/core/runner.py +43 -42
- model_compression_toolkit/data_generation/common/data_generation.py +18 -18
- model_compression_toolkit/data_generation/common/model_info_exctractors.py +1 -1
- model_compression_toolkit/data_generation/keras/keras_data_generation.py +7 -10
- model_compression_toolkit/data_generation/keras/model_info_exctractors.py +2 -1
- model_compression_toolkit/data_generation/keras/optimization_functions/image_initilization.py +2 -1
- model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +2 -4
- model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +2 -1
- model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +8 -11
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -3
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +8 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +7 -8
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +19 -12
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +10 -11
- model_compression_toolkit/gptq/common/gptq_graph.py +3 -3
- model_compression_toolkit/gptq/common/gptq_training.py +14 -12
- model_compression_toolkit/gptq/keras/gptq_training.py +10 -8
- model_compression_toolkit/gptq/keras/graph_info.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +15 -17
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +1 -2
- model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -8
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +11 -13
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +1 -2
- model_compression_toolkit/logger.py +1 -13
- model_compression_toolkit/pruning/keras/pruning_facade.py +11 -12
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +11 -12
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -14
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -8
- model_compression_toolkit/qat/keras/quantization_facade.py +20 -22
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -3
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +12 -14
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -3
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +1 -1
- model_compression_toolkit/target_platform_capabilities/immutable.py +4 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +4 -8
- model_compression_toolkit/target_platform_capabilities/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py +43 -8
- model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +13 -18
- model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +2 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py +2 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/current_tpc.py +2 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +5 -5
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +1 -2
- model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py +13 -13
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +14 -7
- model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +5 -5
- model_compression_toolkit/trainable_infrastructure/keras/base_keras_quantizer.py +2 -3
- model_compression_toolkit/trainable_infrastructure/keras/load_model.py +4 -5
- model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py +3 -4
- model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi.py +0 -112
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_aggregation_methods.py +0 -105
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_functions_mapping.py +0 -26
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/common/mixed_precision/{kpi_tools → resource_utilization_tools}/__init__.py +0 -0
|
@@ -89,8 +89,8 @@ def fx_graph_module_generation(pytorch_model: torch.nn.Module,
|
|
|
89
89
|
try:
|
|
90
90
|
symbolic_traced = symbolic_trace(pytorch_model)
|
|
91
91
|
except torch.fx.proxy.TraceError as e:
|
|
92
|
-
Logger.
|
|
93
|
-
|
|
92
|
+
Logger.critical(f'Error parsing model with torch.fx\n'
|
|
93
|
+
f'fx error: {e}')
|
|
94
94
|
inputs = next(representative_data_gen())
|
|
95
95
|
input_for_shape_infer = [to_tensor(i) for i in inputs]
|
|
96
96
|
ShapeProp(symbolic_traced).propagate(*input_for_shape_infer)
|
model_compression_toolkit/core/pytorch/{kpi_data_facade.py → resource_utilization_data_facade.py}
RENAMED
|
@@ -18,9 +18,9 @@ from typing import Callable
|
|
|
18
18
|
from model_compression_toolkit.logger import Logger
|
|
19
19
|
from model_compression_toolkit.constants import PYTORCH
|
|
20
20
|
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
|
21
|
-
from model_compression_toolkit.core.common.mixed_precision.
|
|
21
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
|
22
22
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
23
|
-
from model_compression_toolkit.core.common.mixed_precision.
|
|
23
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
|
|
24
24
|
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
|
|
25
25
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
|
|
26
26
|
from model_compression_toolkit.constants import FOUND_TORCH
|
|
@@ -36,13 +36,14 @@ if FOUND_TORCH:
|
|
|
36
36
|
PYTORCH_DEFAULT_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
|
|
37
37
|
|
|
38
38
|
|
|
39
|
-
def
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
39
|
+
def pytorch_resource_utilization_data(in_model: Module,
|
|
40
|
+
representative_data_gen: Callable,
|
|
41
|
+
core_config: CoreConfig = CoreConfig(),
|
|
42
|
+
target_platform_capabilities: TargetPlatformCapabilities = PYTORCH_DEFAULT_TPC
|
|
43
|
+
) -> ResourceUtilization:
|
|
43
44
|
"""
|
|
44
|
-
Computes
|
|
45
|
-
Builds the computation graph from the given model and target platform capabilities, and uses it to compute the
|
|
45
|
+
Computes resource utilization data that can be used to calculate the desired target resource utilization for mixed-precision quantization.
|
|
46
|
+
Builds the computation graph from the given model and target platform capabilities, and uses it to compute the resource utilization data.
|
|
46
47
|
|
|
47
48
|
Args:
|
|
48
49
|
in_model (Model): PyTorch model to quantize.
|
|
@@ -52,7 +53,7 @@ if FOUND_TORCH:
|
|
|
52
53
|
|
|
53
54
|
Returns:
|
|
54
55
|
|
|
55
|
-
A
|
|
56
|
+
A ResourceUtilization object with total weights parameters sum and max activation tensor.
|
|
56
57
|
|
|
57
58
|
Examples:
|
|
58
59
|
|
|
@@ -66,29 +67,30 @@ if FOUND_TORCH:
|
|
|
66
67
|
>>> import numpy as np
|
|
67
68
|
>>> def repr_datagen(): yield [np.random.random((1, 3, 224, 224))]
|
|
68
69
|
|
|
69
|
-
Import mct and call for
|
|
70
|
+
Import mct and call for resource utilization data calculation:
|
|
70
71
|
|
|
71
72
|
>>> import model_compression_toolkit as mct
|
|
72
|
-
>>>
|
|
73
|
+
>>> ru_data = mct.core.pytorch_resource_utilization_data(module, repr_datagen)
|
|
73
74
|
|
|
74
75
|
"""
|
|
75
76
|
|
|
76
77
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
77
|
-
Logger.
|
|
78
|
-
|
|
78
|
+
Logger.critical("Resource utilization data computation requires a MixedPrecisionQuantizationConfig object. "
|
|
79
|
+
"The provided 'mixed_precision_config' is not of this type.")
|
|
79
80
|
|
|
80
81
|
fw_impl = PytorchImplementation()
|
|
81
82
|
|
|
82
|
-
return
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
83
|
+
return compute_resource_utilization_data(in_model,
|
|
84
|
+
representative_data_gen,
|
|
85
|
+
core_config,
|
|
86
|
+
target_platform_capabilities,
|
|
87
|
+
DEFAULT_PYTORCH_INFO,
|
|
88
|
+
fw_impl)
|
|
88
89
|
|
|
89
90
|
else:
|
|
90
91
|
# If torch is not installed,
|
|
91
92
|
# we raise an exception when trying to use this function.
|
|
92
|
-
def
|
|
93
|
-
Logger.critical(
|
|
94
|
-
'
|
|
93
|
+
def pytorch_resource_utilization_data(*args, **kwargs):
|
|
94
|
+
Logger.critical("PyTorch must be installed to use 'pytorch_resource_utilization_data'. "
|
|
95
|
+
"The 'torch' package is missing.") # pragma: no cover
|
|
96
|
+
|
|
@@ -16,6 +16,7 @@ import torch
|
|
|
16
16
|
import numpy as np
|
|
17
17
|
from typing import Union
|
|
18
18
|
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
|
|
19
|
+
from model_compression_toolkit.logger import Logger
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
def set_model(model: torch.nn.Module, train_mode: bool = False):
|
|
@@ -58,7 +59,7 @@ def to_torch_tensor(tensor):
|
|
|
58
59
|
elif isinstance(tensor, (int, float)):
|
|
59
60
|
return torch.from_numpy(np.array(tensor).astype(np.float32)).to(working_device)
|
|
60
61
|
else:
|
|
61
|
-
|
|
62
|
+
Logger.critical(f'Unsupported type for conversion to Torch.tensor: {type(tensor)}.')
|
|
62
63
|
|
|
63
64
|
|
|
64
65
|
def torch_tensor_to_numpy(tensor: Union[torch.Tensor, list, tuple]) -> Union[np.ndarray, list, tuple]:
|
|
@@ -79,4 +80,4 @@ def torch_tensor_to_numpy(tensor: Union[torch.Tensor, list, tuple]) -> Union[np.
|
|
|
79
80
|
elif isinstance(tensor, torch.Tensor):
|
|
80
81
|
return tensor.cpu().detach().contiguous().numpy()
|
|
81
82
|
else:
|
|
82
|
-
|
|
83
|
+
Logger.critical(f'Unsupported type for conversion to Numpy array: {type(tensor)}.')
|
|
@@ -26,10 +26,10 @@ from model_compression_toolkit.logger import Logger
|
|
|
26
26
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
27
27
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
28
28
|
from model_compression_toolkit.core.common.mixed_precision.bit_width_setter import set_bit_widths
|
|
29
|
-
from model_compression_toolkit.core.common.mixed_precision.
|
|
30
|
-
from model_compression_toolkit.core.common.mixed_precision.
|
|
31
|
-
from model_compression_toolkit.core.common.mixed_precision.
|
|
32
|
-
from model_compression_toolkit.core.common.mixed_precision.
|
|
29
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
|
|
30
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_aggregation_methods import MpRuAggregation
|
|
31
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import ru_functions_mapping
|
|
32
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric
|
|
33
33
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width
|
|
34
34
|
from model_compression_toolkit.core.common.network_editors.edit_network import edit_network_graph
|
|
35
35
|
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
|
|
@@ -47,7 +47,7 @@ def core_runner(in_model: Any,
|
|
|
47
47
|
fw_info: FrameworkInfo,
|
|
48
48
|
fw_impl: FrameworkImplementation,
|
|
49
49
|
tpc: TargetPlatformCapabilities,
|
|
50
|
-
|
|
50
|
+
target_resource_utilization: ResourceUtilization = None,
|
|
51
51
|
tb_w: TensorboardWriter = None):
|
|
52
52
|
"""
|
|
53
53
|
Quantize a trained model using post-training quantization.
|
|
@@ -67,7 +67,7 @@ def core_runner(in_model: Any,
|
|
|
67
67
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
|
68
68
|
tpc: TargetPlatformCapabilities object that models the inference target platform and
|
|
69
69
|
the attached framework operator's information.
|
|
70
|
-
|
|
70
|
+
target_resource_utilization: ResourceUtilization to constraint the search of the mixed-precision configuration for the model.
|
|
71
71
|
tb_w: TensorboardWriter object for logging
|
|
72
72
|
|
|
73
73
|
Returns:
|
|
@@ -84,9 +84,9 @@ def core_runner(in_model: Any,
|
|
|
84
84
|
' consider increasing the batch size')
|
|
85
85
|
|
|
86
86
|
# Checking whether to run mixed precision quantization
|
|
87
|
-
if
|
|
87
|
+
if target_resource_utilization is not None:
|
|
88
88
|
if core_config.mixed_precision_config is None:
|
|
89
|
-
Logger.critical("Provided an initialized
|
|
89
|
+
Logger.critical("Provided an initialized target_resource_utilization, that means that mixed precision quantization is "
|
|
90
90
|
"enabled, but the provided MixedPrecisionQuantizationConfig is None.")
|
|
91
91
|
core_config.mixed_precision_config.set_mixed_precision_enable()
|
|
92
92
|
|
|
@@ -119,7 +119,7 @@ def core_runner(in_model: Any,
|
|
|
119
119
|
bit_widths_config = search_bit_width(tg,
|
|
120
120
|
fw_info,
|
|
121
121
|
fw_impl,
|
|
122
|
-
|
|
122
|
+
target_resource_utilization,
|
|
123
123
|
core_config.mixed_precision_config,
|
|
124
124
|
representative_data_gen,
|
|
125
125
|
hessian_info_service=hessian_info_service)
|
|
@@ -139,11 +139,11 @@ def core_runner(in_model: Any,
|
|
|
139
139
|
# This is since some actions regard the final configuration and should be edited.
|
|
140
140
|
edit_network_graph(tg, fw_info, core_config.debug_config.network_editor)
|
|
141
141
|
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
142
|
+
_set_final_resource_utilization(graph=tg,
|
|
143
|
+
final_bit_widths_config=bit_widths_config,
|
|
144
|
+
ru_functions_dict=ru_functions_mapping,
|
|
145
|
+
fw_info=fw_info,
|
|
146
|
+
fw_impl=fw_impl)
|
|
147
147
|
|
|
148
148
|
if core_config.mixed_precision_enable:
|
|
149
149
|
# Retrieve lists of tuples (node, node's final weights/activation bitwidth)
|
|
@@ -164,49 +164,50 @@ def core_runner(in_model: Any,
|
|
|
164
164
|
return tg, bit_widths_config, hessian_info_service
|
|
165
165
|
|
|
166
166
|
|
|
167
|
-
def
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
167
|
+
def _set_final_resource_utilization(graph: Graph,
|
|
168
|
+
final_bit_widths_config: List[int],
|
|
169
|
+
ru_functions_dict: Dict[RUTarget, Tuple[MpRuMetric, MpRuAggregation]],
|
|
170
|
+
fw_info: FrameworkInfo,
|
|
171
|
+
fw_impl: FrameworkImplementation):
|
|
172
172
|
"""
|
|
173
|
-
Computing the
|
|
173
|
+
Computing the resource utilization of the model according to the final bit-width configuration,
|
|
174
174
|
and setting it (inplace) in the graph's UserInfo field.
|
|
175
175
|
|
|
176
176
|
Args:
|
|
177
|
-
graph: Graph to compute the
|
|
177
|
+
graph: Graph to compute the resource utilization for.
|
|
178
178
|
final_bit_widths_config: The final bit-width configuration to quantize the model accordingly.
|
|
179
|
-
|
|
179
|
+
ru_functions_dict: A mapping between a RUTarget and a pair of resource utilization method and resource utilization aggregation functions.
|
|
180
180
|
fw_info: A FrameworkInfo object.
|
|
181
181
|
fw_impl: FrameworkImplementation object with specific framework methods implementation.
|
|
182
182
|
|
|
183
183
|
"""
|
|
184
184
|
|
|
185
|
-
|
|
186
|
-
for
|
|
187
|
-
|
|
188
|
-
if
|
|
189
|
-
|
|
185
|
+
final_ru_dict = {}
|
|
186
|
+
for ru_target, ru_funcs in ru_functions_dict.items():
|
|
187
|
+
ru_method, ru_aggr = ru_funcs
|
|
188
|
+
if ru_target == RUTarget.BOPS:
|
|
189
|
+
final_ru_dict[ru_target] = \
|
|
190
|
+
ru_aggr(ru_method(final_bit_widths_config, graph, fw_info, fw_impl, False), False)[0]
|
|
190
191
|
else:
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
if len(final_bit_widths_config) > 0 and len(
|
|
194
|
-
|
|
195
|
-
elif len(final_bit_widths_config) > 0 and len(
|
|
196
|
-
|
|
197
|
-
elif len(final_bit_widths_config) == 0 and len(
|
|
192
|
+
non_conf_ru = ru_method([], graph, fw_info, fw_impl)
|
|
193
|
+
conf_ru = ru_method(final_bit_widths_config, graph, fw_info, fw_impl)
|
|
194
|
+
if len(final_bit_widths_config) > 0 and len(non_conf_ru) > 0:
|
|
195
|
+
final_ru_dict[ru_target] = ru_aggr(np.concatenate([conf_ru, non_conf_ru]), False)[0]
|
|
196
|
+
elif len(final_bit_widths_config) > 0 and len(non_conf_ru) == 0:
|
|
197
|
+
final_ru_dict[ru_target] = ru_aggr(conf_ru, False)[0]
|
|
198
|
+
elif len(final_bit_widths_config) == 0 and len(non_conf_ru) > 0:
|
|
198
199
|
# final_bit_widths_config == 0 ==> no configurable nodes,
|
|
199
|
-
# thus,
|
|
200
|
-
|
|
200
|
+
# thus, ru can be computed from non_conf_ru alone
|
|
201
|
+
final_ru_dict[ru_target] = ru_aggr(non_conf_ru, False)[0]
|
|
201
202
|
else:
|
|
202
203
|
# No relevant nodes have been quantized with affect on the given target - since we only consider
|
|
203
204
|
# in the model's final size the quantized layers size, this means that the final size for this target
|
|
204
205
|
# is zero.
|
|
205
|
-
Logger.warning(f"No relevant quantized layers for the
|
|
206
|
-
f"final
|
|
207
|
-
|
|
206
|
+
Logger.warning(f"No relevant quantized layers for the ru target {ru_target} were found, the recorded"
|
|
207
|
+
f"final ru for this target would be 0.")
|
|
208
|
+
final_ru_dict[ru_target] = 0
|
|
208
209
|
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
graph.user_info.
|
|
210
|
+
final_ru = ResourceUtilization()
|
|
211
|
+
final_ru.set_resource_utilization_by_target(final_ru_dict)
|
|
212
|
+
graph.user_info.final_resource_utilization = final_ru
|
|
212
213
|
graph.user_info.mixed_precision_cfg = final_bit_widths_config
|
|
@@ -64,53 +64,53 @@ def get_data_generation_classes(
|
|
|
64
64
|
|
|
65
65
|
# Check if the image pipeline type is valid
|
|
66
66
|
if image_pipeline is None:
|
|
67
|
-
Logger.
|
|
68
|
-
f'Invalid image_pipeline_type {data_generation_config.image_pipeline_type}.'
|
|
69
|
-
f'Please
|
|
67
|
+
Logger.critical(
|
|
68
|
+
f'Invalid image_pipeline_type {data_generation_config.image_pipeline_type}. '
|
|
69
|
+
f'Please select one from {ImagePipelineType.get_values()}.')
|
|
70
70
|
|
|
71
71
|
# Get the normalization values corresponding to the specified type
|
|
72
72
|
normalization = image_normalization_dict.get(data_generation_config.image_normalization_type)
|
|
73
73
|
|
|
74
74
|
# Check if the image normalization type is valid
|
|
75
75
|
if normalization is None:
|
|
76
|
-
Logger.
|
|
77
|
-
f'Invalid image_normalization_type {data_generation_config.image_normalization_type}.'
|
|
78
|
-
f'Please
|
|
76
|
+
Logger.critical(
|
|
77
|
+
f'Invalid image_normalization_type {data_generation_config.image_normalization_type}. '
|
|
78
|
+
f'Please select one from {ImageNormalizationType.get_values()}.')
|
|
79
79
|
|
|
80
80
|
# Get the layer weighting function corresponding to the specified type
|
|
81
81
|
bn_layer_weighting_fn = bn_layer_weighting_function_dict.get(data_generation_config.layer_weighting_type)
|
|
82
82
|
|
|
83
83
|
if bn_layer_weighting_fn is None:
|
|
84
|
-
Logger.
|
|
85
|
-
f'Invalid layer_weighting_type {data_generation_config.layer_weighting_type}.'
|
|
86
|
-
f'Please
|
|
84
|
+
Logger.critical(
|
|
85
|
+
f'Invalid layer_weighting_type {data_generation_config.layer_weighting_type}. '
|
|
86
|
+
f'Please select one from {BNLayerWeightingType.get_values()}.')
|
|
87
87
|
|
|
88
88
|
# Get the image initialization function corresponding to the specified type
|
|
89
89
|
image_initialization_fn = image_initialization_function_dict.get(data_generation_config.data_init_type)
|
|
90
90
|
|
|
91
91
|
# Check if the data initialization type is valid
|
|
92
92
|
if image_initialization_fn is None:
|
|
93
|
-
Logger.
|
|
94
|
-
f'Invalid data_init_type {data_generation_config.data_init_type}.'
|
|
95
|
-
f'Please
|
|
93
|
+
Logger.critical(
|
|
94
|
+
f'Invalid data_init_type {data_generation_config.data_init_type}. '
|
|
95
|
+
f'Please select one from {DataInitType.get_values()}.')
|
|
96
96
|
|
|
97
97
|
# Get the BatchNorm alignment loss function corresponding to the specified type
|
|
98
98
|
bn_alignment_loss_fn = bn_alignment_loss_function_dict.get(data_generation_config.bn_alignment_loss_type)
|
|
99
99
|
|
|
100
100
|
# Check if the BatchNorm alignment loss type is valid
|
|
101
101
|
if bn_alignment_loss_fn is None:
|
|
102
|
-
Logger.
|
|
103
|
-
f'Invalid bn_alignment_loss_type {data_generation_config.bn_alignment_loss_type}.'
|
|
104
|
-
f'Please
|
|
102
|
+
Logger.critical(
|
|
103
|
+
f'Invalid bn_alignment_loss_type {data_generation_config.bn_alignment_loss_type}. '
|
|
104
|
+
f'Please select one from {BatchNormAlignemntLossType.get_values()}.')
|
|
105
105
|
|
|
106
106
|
# Get the output loss function corresponding to the specified type
|
|
107
107
|
output_loss_fn = output_loss_function_dict.get(data_generation_config.output_loss_type)
|
|
108
108
|
|
|
109
109
|
# Check if the output loss type is valid
|
|
110
110
|
if output_loss_fn is None:
|
|
111
|
-
Logger.
|
|
112
|
-
f'Invalid output_loss_type {data_generation_config.output_loss_type}.'
|
|
113
|
-
f'Please
|
|
111
|
+
Logger.critical(
|
|
112
|
+
f'Invalid output_loss_type {data_generation_config.output_loss_type}. '
|
|
113
|
+
f'Please select one from {OutputLossType.get_values()}.')
|
|
114
114
|
|
|
115
115
|
# Initialize the dataset for data generation
|
|
116
116
|
init_dataset = image_initialization_fn(
|
|
@@ -35,7 +35,7 @@ class OriginalBNStatsHolder:
|
|
|
35
35
|
"""
|
|
36
36
|
self.bn_params = self.get_bn_params(model, bn_layer_types)
|
|
37
37
|
if self.get_num_bn_layers() == 0:
|
|
38
|
-
Logger.
|
|
38
|
+
Logger.critical(
|
|
39
39
|
f'Data generation requires a model with at least one BatchNorm layer.')
|
|
40
40
|
|
|
41
41
|
def get_bn_layer_names(self) -> List[str]:
|
|
@@ -181,18 +181,16 @@ if FOUND_TF:
|
|
|
181
181
|
output_loss_function_dict=output_loss_function_dict)
|
|
182
182
|
|
|
183
183
|
if not all(normalization[1]):
|
|
184
|
-
Logger.
|
|
185
|
-
f'Invalid normalization
|
|
186
|
-
f'will lead to division by zero., Please choose non-zero normalization std')
|
|
184
|
+
Logger.critical(
|
|
185
|
+
f'Invalid normalization standard deviation {normalization[1]} set to zero, which will lead to division by zero. Please select a non-zero normalization standard deviation.')
|
|
187
186
|
|
|
188
187
|
# Get the scheduler functions corresponding to the specified scheduler type
|
|
189
188
|
scheduler_get_fn = scheduler_step_function_dict.get(data_generation_config.scheduler_type)
|
|
190
189
|
|
|
191
190
|
# Check if the scheduler type is valid
|
|
192
191
|
if scheduler_get_fn is None:
|
|
193
|
-
Logger.
|
|
194
|
-
f'Invalid scheduler_type {data_generation_config.scheduler_type}.'
|
|
195
|
-
f'Please choose one of {SchedulerType.get_values()}')
|
|
192
|
+
Logger.critical(
|
|
193
|
+
f'Invalid scheduler_type {data_generation_config.scheduler_type}. Please select one from {SchedulerType.get_values()}.')
|
|
196
194
|
|
|
197
195
|
# Create a scheduler object with the specified number of iterations
|
|
198
196
|
scheduler = scheduler_get_fn(n_iter=data_generation_config.n_iter,
|
|
@@ -358,10 +356,9 @@ if FOUND_TF:
|
|
|
358
356
|
|
|
359
357
|
else:
|
|
360
358
|
def get_keras_data_generation_config(*args, **kwargs):
|
|
361
|
-
Logger.critical(
|
|
362
|
-
|
|
359
|
+
Logger.critical(
|
|
360
|
+
"Tensorflow must be installed to use get_tensorflow_data_generation_config. The 'tensorflow' package is missing.") # pragma: no cover
|
|
363
361
|
|
|
364
362
|
|
|
365
363
|
def keras_data_generation_experimental(*args, **kwargs):
|
|
366
|
-
Logger.critical(
|
|
367
|
-
'Could not find Tensorflow package.') # pragma: no cover
|
|
364
|
+
Logger.critical("Tensorflow must be installed to use tensorflow_data_generation_experimental. The 'tensorflow' package is missing.") # pragma: no cover
|
|
@@ -21,6 +21,7 @@ from tensorflow.keras.layers import BatchNormalization
|
|
|
21
21
|
from model_compression_toolkit.data_generation.common.enums import ImageGranularity
|
|
22
22
|
from model_compression_toolkit.data_generation.common.model_info_exctractors import OriginalBNStatsHolder, \
|
|
23
23
|
ActivationExtractor
|
|
24
|
+
from model_compression_toolkit.logger import Logger
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class KerasOriginalBNStatsHolder(OriginalBNStatsHolder):
|
|
@@ -95,7 +96,7 @@ class KerasActivationExtractor(ActivationExtractor):
|
|
|
95
96
|
self.bn_layer_names = [layer.name for layer in model.layers if isinstance(layer,
|
|
96
97
|
self.layer_types_to_extract_inputs)]
|
|
97
98
|
self.num_layers = len(self.bn_layer_names)
|
|
98
|
-
|
|
99
|
+
Logger.info(f'Number of layers = {self.num_layers}')
|
|
99
100
|
|
|
100
101
|
# Initialize stats containers
|
|
101
102
|
self.activations = {}
|
model_compression_toolkit/data_generation/keras/optimization_functions/image_initilization.py
CHANGED
|
@@ -21,6 +21,7 @@ from tensorflow.data import Dataset
|
|
|
21
21
|
|
|
22
22
|
from model_compression_toolkit.data_generation.common.constants import NUM_INPUT_CHANNELS
|
|
23
23
|
from model_compression_toolkit.data_generation.common.enums import DataInitType
|
|
24
|
+
from model_compression_toolkit.logger import Logger
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
# Define a function to generate a dataset of Gaussian noise images.
|
|
@@ -91,7 +92,7 @@ def get_random_gaussian_data(
|
|
|
91
92
|
Returns:
|
|
92
93
|
Tuple[int, Any]: A tuple containing the number of batches and a data loader iterator.
|
|
93
94
|
"""
|
|
94
|
-
|
|
95
|
+
Logger.info(f'Start generating random Gaussian data')
|
|
95
96
|
image_shape = size + (NUM_INPUT_CHANNELS,)
|
|
96
97
|
dataset = generate_gaussian_noise_images(num_samples=n_images, image_shape=image_shape,
|
|
97
98
|
mean=mean_factor, std=std_factor, batch_size=batch_size)
|
model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py
CHANGED
|
@@ -43,10 +43,8 @@ def regularized_min_max_diff(
|
|
|
43
43
|
tf.Tensor: The calculated loss.
|
|
44
44
|
"""
|
|
45
45
|
if activation_extractor.last_linear_layers is None:
|
|
46
|
-
Logger.
|
|
47
|
-
f'Cannot compute regularized min
|
|
48
|
-
f'requires linear layer without a following BatchNormalization layer. Please choose one of '
|
|
49
|
-
f'{OutputLossType.get_values()}.')
|
|
46
|
+
Logger.critical(
|
|
47
|
+
f'Cannot compute regularized min-max output loss for the input model. This loss requires a linear layer without a subsequent BatchNormalization layer. Please select one from {OutputLossType.get_values()}.')
|
|
50
48
|
|
|
51
49
|
with tape.stop_recording():
|
|
52
50
|
weights_last_layer = activation_extractor.last_linear_layers.get_weights()[0]
|
|
@@ -23,6 +23,7 @@ from model_compression_toolkit.data_generation.common.model_info_exctractors imp
|
|
|
23
23
|
ActivationExtractor
|
|
24
24
|
from model_compression_toolkit.data_generation.pytorch.constants import OUTPUT
|
|
25
25
|
from model_compression_toolkit.data_generation.common.constants import IMAGE_INPUT, NUM_INPUT_CHANNELS
|
|
26
|
+
from model_compression_toolkit.logger import Logger
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
class PytorchOriginalBNStatsHolder(OriginalBNStatsHolder):
|
|
@@ -125,7 +126,7 @@ class PytorchActivationExtractor(ActivationExtractor):
|
|
|
125
126
|
self.layer_types_to_extract_inputs = tuple(layer_types_to_extract_inputs)
|
|
126
127
|
self.last_layer_types_to_extract_inputs = tuple(last_layer_types_to_extract_inputs)
|
|
127
128
|
self.num_layers = sum([1 if isinstance(layer, tuple(layer_types_to_extract_inputs)) else 0 for layer in model.modules()])
|
|
128
|
-
|
|
129
|
+
Logger.info(f'Number of layers = {self.num_layers}')
|
|
129
130
|
self.hooks = {} # Dictionary to store InputHook instances by layer name
|
|
130
131
|
self.last_linear_layers_hooks = {} # Dictionary to store InputHook instances by layer name
|
|
131
132
|
self.hook_handles = [] # List to store hook handles
|
|
@@ -198,9 +198,8 @@ if FOUND_TORCH:
|
|
|
198
198
|
|
|
199
199
|
# Check if the scheduler type is valid
|
|
200
200
|
if scheduler_get_fn is None or scheduler_step_fn is None:
|
|
201
|
-
Logger.
|
|
202
|
-
|
|
203
|
-
f'{SchedulerType.get_values()}')
|
|
201
|
+
Logger.critical(f'Invalid output_loss_type {data_generation_config.scheduler_type}. '
|
|
202
|
+
f'Please select one from {SchedulerType.get_values()}.')
|
|
204
203
|
|
|
205
204
|
# Create a scheduler object with the specified number of iterations
|
|
206
205
|
scheduler = scheduler_get_fn(data_generation_config.n_iter)
|
|
@@ -218,8 +217,8 @@ if FOUND_TORCH:
|
|
|
218
217
|
# Create an orig_bn_stats_holder object to hold original BatchNorm statistics
|
|
219
218
|
orig_bn_stats_holder = PytorchOriginalBNStatsHolder(model, data_generation_config.bn_layer_types)
|
|
220
219
|
if orig_bn_stats_holder.get_num_bn_layers() == 0:
|
|
221
|
-
Logger.
|
|
222
|
-
f'Data generation requires a model with at least one
|
|
220
|
+
Logger.critical(
|
|
221
|
+
f'Data generation requires a model with at least one BatchNorm layer.')
|
|
223
222
|
|
|
224
223
|
# Create an ImagesOptimizationHandler object for handling optimization
|
|
225
224
|
all_imgs_opt_handler = PytorchImagesOptimizationHandler(model=model,
|
|
@@ -355,12 +354,10 @@ else:
|
|
|
355
354
|
# If torch is not installed,
|
|
356
355
|
# we raise an exception when trying to use these functions.
|
|
357
356
|
def get_pytorch_data_generation_config(*args, **kwargs):
|
|
358
|
-
Logger.critical('
|
|
359
|
-
'
|
|
360
|
-
'Could not find torch package.') # pragma: no cover
|
|
357
|
+
Logger.critical('PyTorch must be installed to use get_pytorch_data_generation_config. '
|
|
358
|
+
"The 'torch' package is missing.") # pragma: no cover
|
|
361
359
|
|
|
362
360
|
|
|
363
361
|
def pytorch_data_generation_experimental(*args, **kwargs):
|
|
364
|
-
Logger.critical(
|
|
365
|
-
'
|
|
366
|
-
'Could not find the torch package.') # pragma: no cover
|
|
362
|
+
Logger.critical("PyTorch must be installed to use 'pytorch_data_generation_experimental'. "
|
|
363
|
+
"The 'torch' package is missing.") # pragma: no cover
|
|
@@ -99,7 +99,7 @@ class FakelyQuantKerasExporter(BaseKerasExporter):
|
|
|
99
99
|
elif isinstance(layer.layer, (layers.Conv2D, layers.Dense, layers.Conv2DTranspose)):
|
|
100
100
|
weights_list.append(layer.get_quantized_weights()['kernel'])
|
|
101
101
|
else:
|
|
102
|
-
Logger.
|
|
102
|
+
Logger.critical(f'KerasQuantizationWrapper should wrap only DepthwiseConv2D, Conv2D, Dense'
|
|
103
103
|
f' and Conv2DTranspose layers but wrapped layer is {layer.layer}')
|
|
104
104
|
|
|
105
105
|
if layer.layer.bias is not None:
|
|
@@ -101,6 +101,5 @@ if FOUND_TF:
|
|
|
101
101
|
return exporter.get_custom_objects()
|
|
102
102
|
else:
|
|
103
103
|
def keras_export_model(*args, **kwargs):
|
|
104
|
-
Logger.
|
|
105
|
-
|
|
106
|
-
'Could not find some or all of TensorFlow packages.') # pragma: no cover
|
|
104
|
+
Logger.critical("Tensorflow must be installed to use keras_export_model. "
|
|
105
|
+
"The 'tensorflow' package is missing.") # pragma: no cover
|
|
@@ -103,6 +103,5 @@ if FOUND_TORCH:
|
|
|
103
103
|
|
|
104
104
|
else:
|
|
105
105
|
def pytorch_export_model(*args, **kwargs):
|
|
106
|
-
Logger.
|
|
107
|
-
|
|
108
|
-
'Could not find PyTorch packages.') # pragma: no cover
|
|
106
|
+
Logger.critical("PyTorch must be installed to use 'pytorch_export_model'. "
|
|
107
|
+
"The 'torch' package is missing.") # pragma: no cover
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
CHANGED
|
@@ -65,7 +65,7 @@ if FOUND_TF:
|
|
|
65
65
|
if len(activation_quantizers) == 1:
|
|
66
66
|
return KerasActivationQuantizationHolder(activation_quantizers[0])
|
|
67
67
|
|
|
68
|
-
Logger.
|
|
68
|
+
Logger.critical(
|
|
69
69
|
f'ActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
|
|
70
70
|
f'were found for node {node}')
|
|
71
71
|
|
|
@@ -89,9 +89,13 @@ if FOUND_TF:
|
|
|
89
89
|
get_activation_quantizer_holder(n,
|
|
90
90
|
fw_impl=C.keras.keras_implementation.KerasImplementation())).build_model()
|
|
91
91
|
exportable_model.trainable = False
|
|
92
|
+
|
|
93
|
+
Logger.info("Please run your accuracy evaluation on the exported quantized model to verify it's accuracy.\n"
|
|
94
|
+
"Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:\n"
|
|
95
|
+
"FAQ: https://github.com/sony/model_optimization/tree/main/FAQ.md"
|
|
96
|
+
"Quantization Troubleshooting: https://github.com/sony/model_optimization/tree/main/quantization_troubleshooting.md")
|
|
92
97
|
return exportable_model, user_info
|
|
93
98
|
else:
|
|
94
99
|
def get_exportable_keras_model(*args, **kwargs): # pragma: no cover
|
|
95
|
-
Logger.
|
|
96
|
-
|
|
97
|
-
'Could not find Tensorflow package.')
|
|
100
|
+
Logger.critical("Tensorflow must be installed to use get_exportable_keras_model. "
|
|
101
|
+
"The 'tensorflow' package is missing.") # pragma: no cover
|
|
@@ -44,7 +44,7 @@ def get_inferable_quantizer_kwargs(node_qc: BaseNodeQuantizationConfig,
|
|
|
44
44
|
|
|
45
45
|
if quantization_target == QuantizationTarget.Weights:
|
|
46
46
|
if not isinstance(node_qc, NodeWeightsQuantizationConfig):
|
|
47
|
-
Logger.
|
|
47
|
+
Logger.critical(f"Non-compatible node quantization config was given for quantization target Weights.") # pragma: no cover
|
|
48
48
|
|
|
49
49
|
if attr_name is None:
|
|
50
50
|
Logger.error(f"Attribute name was not specified for retrieving weights quantizer kwargs.")
|
|
@@ -84,7 +84,7 @@ def get_inferable_quantizer_kwargs(node_qc: BaseNodeQuantizationConfig,
|
|
|
84
84
|
|
|
85
85
|
elif quantization_target == QuantizationTarget.Activation:
|
|
86
86
|
if not isinstance(node_qc, NodeActivationQuantizationConfig):
|
|
87
|
-
Logger.
|
|
87
|
+
Logger.critical(f"Non-compatible node quantization config was given for quantization target Activation.") # pragma: no cover
|
|
88
88
|
|
|
89
89
|
quantization_method = node_qc.activation_quantization_method
|
|
90
90
|
|