mct-nightly 1.11.0.20240321.357__py3-none-any.whl → 1.11.0.20240323.408__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/METADATA +17 -9
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/RECORD +152 -152
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/constants.py +1 -1
- model_compression_toolkit/core/__init__.py +3 -3
- model_compression_toolkit/core/common/collectors/base_collector.py +2 -2
- model_compression_toolkit/core/common/data_loader.py +3 -3
- model_compression_toolkit/core/common/graph/base_graph.py +10 -13
- model_compression_toolkit/core/common/graph/base_node.py +3 -3
- model_compression_toolkit/core/common/graph/edge.py +2 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +2 -4
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +2 -3
- model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +1 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +24 -23
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +110 -112
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +114 -0
- model_compression_toolkit/core/common/mixed_precision/{kpi_tools/kpi_data.py → resource_utilization_tools/resource_utilization_data.py} +19 -19
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +105 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +26 -0
- model_compression_toolkit/core/common/mixed_precision/{kpi_tools/kpi_methods.py → resource_utilization_tools/ru_methods.py} +61 -61
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +75 -71
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -4
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +34 -34
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/network_editors/actions.py +3 -3
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +12 -12
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +2 -2
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +2 -2
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +7 -7
- model_compression_toolkit/core/common/pruning/prune_graph.py +2 -3
- model_compression_toolkit/core/common/pruning/pruner.py +7 -7
- model_compression_toolkit/core/common/pruning/pruning_config.py +1 -1
- model_compression_toolkit/core/common/pruning/pruning_info.py +2 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +7 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +4 -6
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -4
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +8 -6
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +4 -7
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +3 -3
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +3 -3
- model_compression_toolkit/core/common/user_info.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +3 -3
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +2 -2
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +4 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +2 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py +3 -3
- model_compression_toolkit/core/keras/hessian/trace_hessian_calculator_keras.py +1 -2
- model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py +5 -6
- model_compression_toolkit/core/keras/keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +2 -4
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +7 -7
- model_compression_toolkit/core/keras/reader/common.py +2 -2
- model_compression_toolkit/core/keras/reader/node_builder.py +1 -1
- model_compression_toolkit/core/keras/{kpi_data_facade.py → resource_utilization_data_facade.py} +25 -24
- model_compression_toolkit/core/keras/tf_tensor_numpy.py +4 -2
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +3 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +6 -11
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/hessian/activation_trace_hessian_calculator_pytorch.py +3 -7
- model_compression_toolkit/core/pytorch/hessian/trace_hessian_calculator_pytorch.py +1 -2
- model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py +2 -2
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +3 -3
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +5 -7
- model_compression_toolkit/core/pytorch/reader/reader.py +2 -2
- model_compression_toolkit/core/pytorch/{kpi_data_facade.py → resource_utilization_data_facade.py} +24 -22
- model_compression_toolkit/core/pytorch/utils.py +3 -2
- model_compression_toolkit/core/runner.py +43 -42
- model_compression_toolkit/data_generation/common/data_generation.py +18 -18
- model_compression_toolkit/data_generation/common/model_info_exctractors.py +1 -1
- model_compression_toolkit/data_generation/keras/keras_data_generation.py +7 -10
- model_compression_toolkit/data_generation/keras/model_info_exctractors.py +2 -1
- model_compression_toolkit/data_generation/keras/optimization_functions/image_initilization.py +2 -1
- model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +2 -4
- model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +2 -1
- model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +8 -11
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -3
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +8 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +7 -8
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +19 -12
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +10 -11
- model_compression_toolkit/gptq/common/gptq_graph.py +3 -3
- model_compression_toolkit/gptq/common/gptq_training.py +14 -12
- model_compression_toolkit/gptq/keras/gptq_training.py +10 -8
- model_compression_toolkit/gptq/keras/graph_info.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +15 -17
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +1 -2
- model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -8
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +11 -13
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +1 -2
- model_compression_toolkit/logger.py +1 -13
- model_compression_toolkit/pruning/keras/pruning_facade.py +11 -12
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +11 -12
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -14
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -8
- model_compression_toolkit/qat/keras/quantization_facade.py +20 -22
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -3
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +12 -14
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -3
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +1 -1
- model_compression_toolkit/target_platform_capabilities/immutable.py +4 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +4 -8
- model_compression_toolkit/target_platform_capabilities/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py +43 -8
- model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +13 -18
- model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +2 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py +2 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/current_tpc.py +2 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +5 -5
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +1 -2
- model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py +13 -13
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +14 -7
- model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +5 -5
- model_compression_toolkit/trainable_infrastructure/keras/base_keras_quantizer.py +2 -3
- model_compression_toolkit/trainable_infrastructure/keras/load_model.py +4 -5
- model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py +3 -4
- model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi.py +0 -112
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_aggregation_methods.py +0 -105
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_functions_mapping.py +0 -26
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/common/mixed_precision/{kpi_tools → resource_utilization_tools}/__init__.py +0 -0
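The dominant change in this nightly is the rename of the mixed-precision "KPI" machinery to "resource utilization": `kpi_tools` becomes `resource_utilization_tools`, `kpi_data_facade.py` becomes `resource_utilization_data_facade.py`, and `kpi.py` is removed in favor of the new `resource_utilization.py`. For user code, that plausibly maps to a migration along these lines (a hedged sketch; the names are inferred from the file list above, and the released wheel's exact signatures should be checked):

```python
import model_compression_toolkit as mct

# Before (older nightlies), the target was expressed as a KPI object.
# After this rename, the equivalent object lives under ResourceUtilization
# (resource_utilization.py is added in this diff; exact exports are assumed).
target_ru = mct.core.ResourceUtilization(weights_memory=10_000_000)  # bytes

# The renamed data facade reports the utilization of the float model, e.g. for Keras:
# ru_data = mct.core.keras_resource_utilization_data(model, representative_data_gen)
```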
model_compression_toolkit/core/common/pruning/memory_calculator.py
CHANGED

@@ -207,13 +207,13 @@ class MemoryCalculator:
         kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)
         # Ensure only one kernel attribute exists for the given node.
         if len(kernel_attr) != 1:
-            Logger.…
+            Logger.critical(f"Expected a single attribute, but found {len(kernel_attr)} attributes for node '{node}'. Ensure the node configuration is correct.")
         kernel_attr = kernel_attr[0]

         # Retrieve and validate the axis index for the output channels.
         _, ic_axis = self.fw_info.kernel_channels_mapping.get(node.type)
         if ic_axis is None or int(ic_axis) != ic_axis:
-            Logger.…
+            Logger.critical(f"Invalid input channel axis type for node '{node}': expected integer but got '{ic_axis}'.")

         # Get the number of output channels based on the kernel attribute and axis.
         num_ic = node.get_weights_by_keys(kernel_attr).shape[ic_axis]

@@ -295,7 +295,7 @@
         for w_attr, w in node.weights.items():
             io_axis = [io_axis for attr, io_axis in attributes_and_oc_axis.items() if attr in w_attr]
             if len(io_axis) != 1:
-                Logger.…
+                Logger.critical(f"Each weight must correspond to exactly one IO (Input/Output) axis; however, the current configuration has '{io_axis}' axes.")
             out_axis, in_axis = io_axis[0]

             # Apply input and output masks to the weight tensor.

@@ -313,7 +313,7 @@
         # Get the node channel axis from framework info
         channel_axis = self.fw_info.out_channel_axis_mapping.get(node.type)
         if channel_axis is None:
-            Logger.…
+            Logger.critical(f"The channel axis is undefined. Please ensure the channel axis is explicitly defined for node {node.type} in the framework info.")

         # Check if node.output_shape is a list of lists.
         # In this case make sure all the out channels are the same value

@@ -322,7 +322,7 @@
             if all(len(sublist) > channel_axis and sublist[channel_axis] == compare_value for sublist in node.output_shape):
                 num_oc = compare_value
             else:
-                Logger.…
+                Logger.critical("The number of output channels must be the same across all outputs of the node.")
         else:
             num_oc = node.output_shape[channel_axis]

@@ -348,7 +348,7 @@
         """
         mask = np.ones(w.shape[axis], dtype=bool) if mask is None else mask.astype(bool)
         if w.shape[axis] != len(mask):
-            Logger.…
+            Logger.critical(f"Expected a mask length of {len(mask)}, but got {w.shape[axis]}. Ensure the mask aligns with the tensor shape.")
         pruned_w = np.take(w, np.where(mask)[0], axis=axis)
         return pruned_w

@@ -370,7 +370,7 @@
         The adjusted number of parameters considering padded channels.
         """
         if not (num_oc >= 1 and int(num_oc) == num_oc):
-            Logger.…
+            Logger.critical(f"Expected the number of output channels to be a non-negative integer, but received '{num_oc}'.")

         nparams_per_oc = node_nparams / num_oc
         if int(nparams_per_oc) != nparams_per_oc:
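The pruning path above removes channels with `np.take` over a boolean keep-mask. A minimal standalone sketch of that idea (plain NumPy with hypothetical shapes, not MCT's actual `MemoryCalculator` API):

```python
import numpy as np

# Toy kernel of shape (kh, kw, cin, cout) and a boolean keep-mask over output channels.
w = np.random.rand(3, 3, 8, 16)
mask = np.array([True] * 12 + [False] * 4)  # keep 12 of 16 output channels

# Same pattern as the hunk above: validate the mask length, then gather kept channels.
assert w.shape[-1] == len(mask)
pruned_w = np.take(w, np.where(mask)[0], axis=-1)
print(pruned_w.shape)  # (3, 3, 8, 12)
```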
model_compression_toolkit/core/common/pruning/prune_graph.py
CHANGED

@@ -50,8 +50,7 @@ def build_pruned_graph(graph: Graph,

     # Check that each entry node corresponds to a pruning section has an output-channel mask.
     if len(pruning_sections) != len(masks):
-        Logger.…
-        f"but {len(masks)} masks were given and found {len(pruning_sections)} pruning sections.")
+        Logger.critical(f"Expected to find the same number of masks as the number of pruning sections, but {len(masks)} masks were given for {len(pruning_sections)} pruning sections.")  # progmra: no cover

     # Apply the pruning masks to each pruning section.
     for pruning_section in pruning_sections:

@@ -59,7 +58,7 @@
         # Retrieve the corresponding mask using the node's name (since we use a graph's copy).
         mask = [v for k, v in masks.items() if k.name == pruning_section.entry_node.name]
         if len(mask) != 1:
-            Logger.…
+            Logger.critical(f"Expected to find a single node with name {pruning_section.entry_node.name} in masks dictionary, but found {len(mask)}.")
         mask = mask[0]

         # If the mask indicates that some channels are to be pruned, apply it.
model_compression_toolkit/core/common/pruning/pruner.py
CHANGED

@@ -18,7 +18,7 @@ from typing import Callable, List, Dict, Tuple

 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-from model_compression_toolkit.core.common.mixed_precision.…
+from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.pruning.greedy_mask_calculator import GreedyMaskCalculator
 from model_compression_toolkit.core.common.pruning.importance_metrics.importance_metric_factory import \
     get_importance_metric

@@ -33,14 +33,14 @@ from model_compression_toolkit.target_platform_capabilities.target_platform impo

 class Pruner:
     """
-    Pruner class responsible for applying pruning to a computational graph to meet a target…
+    Pruner class responsible for applying pruning to a computational graph to meet a target resource utilization.
     It identifies and prunes less significant channels based on importance scores, considering SIMD constraints.
     """
     def __init__(self,
                  float_graph: Graph,
                  fw_info: FrameworkInfo,
                  fw_impl: PruningFrameworkImplementation,
-                 …
+                 target_resource_utilization: ResourceUtilization,
                  representative_data_gen: Callable,
                  pruning_config: PruningConfig,
                  target_platform_capabilities: TargetPlatformCapabilities):

@@ -49,7 +49,7 @@ class Pruner:
             float_graph (Graph): The floating-point representation of the model's computation graph.
             fw_info (FrameworkInfo): Contains metadata and helper functions for the framework.
             fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.
-            …
+            target_resource_utilization (ResourceUtilization): The target resource utilization to be achieved after pruning.
             representative_data_gen (Callable): Generator function for representative dataset used in pruning analysis.
             pruning_config (PruningConfig): Configuration object specifying how pruning should be performed.
             target_platform_capabilities (TargetPlatformCapabilities): Object encapsulating the capabilities of the target hardware platform.

@@ -57,7 +57,7 @@
         self.float_graph = float_graph
         self.fw_info = fw_info
         self.fw_impl = fw_impl
-        self.…
+        self.target_resource_utilization = target_resource_utilization
         self.representative_data_gen = representative_data_gen
         self.pruning_config = pruning_config
         self.target_platform_capabilities = target_platform_capabilities

@@ -84,7 +84,7 @@
             mask_calculator = GreedyMaskCalculator(entry_nodes,
                                                    self.fw_info,
                                                    self.simd_scores,
-                                                   self.…
+                                                   self.target_resource_utilization,
                                                    self.float_graph,
                                                    self.fw_impl,
                                                    self.target_platform_capabilities,

@@ -92,7 +92,7 @@
             mask_calculator.compute_mask()
             self.per_oc_mask = mask_calculator.get_mask()
         else:
-            Logger.…
+            Logger.critical("Only GREEDY ChannelsFilteringStrategy is currently supported.")

         Logger.info("Start pruning graph...")
         _pruned_graph = build_pruned_graph(self.float_graph,
model_compression_toolkit/core/common/pruning/pruning_config.py
CHANGED

@@ -32,7 +32,7 @@ class ChannelsFilteringStrategy(Enum):
    """
    Enum for specifying the strategy used for filtering (pruning) channels:

-   GREEDY - Prune the least important channel groups up to allowed resources…
+   GREEDY - Prune the least important channel groups up to the allowed resources utilization limit (for now, only weights_memory is considered).

    """
    GREEDY = 0  # Greedy strategy for pruning channels based on importance metrics.
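In user code, this strategy is selected through the pruning facade touched in this diff (`pruning_facade.py`). A hedged sketch of how the renamed pieces might fit together; the names are assumptions based on this diff's file list, not a verified API:

```python
import model_compression_toolkit as mct

# Assumed exports: PruningConfig and ChannelsFilteringStrategy from pruning_config.py,
# ResourceUtilization from the new resource_utilization.py. Treat as a sketch.
pruning_config = mct.pruning.PruningConfig(
    channels_filtering_strategy=mct.pruning.ChannelsFilteringStrategy.GREEDY)

target_ru = mct.core.ResourceUtilization(weights_memory=10_000_000)  # bytes

pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(
    model=model,                              # a float Keras model, defined elsewhere
    target_resource_utilization=target_ru,
    representative_data_gen=representative_data_gen,
    pruning_config=pruning_config)
```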
model_compression_toolkit/core/common/pruning/pruning_info.py
CHANGED

@@ -75,8 +75,8 @@ def unroll_simd_scores_to_per_channel_scores(simd_scores: Dict[BaseNode, np.ndar
        Dict[BaseNode, np.ndarray]: Expanded scores for each individual channel.
    """
    if simd_scores is None or simd_groups_indices is None:
-        Logger.…
-        …
+        Logger.critical(f"Failed to find scores and indices to create unrolled scores for pruning information."
+                        f" Scores: {simd_scores}, Group indices: {simd_groups_indices}.")
    _scores = {}
    for node, groups_indices in simd_groups_indices.items():
        node_scores = simd_scores[node]
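The unrolling itself is simple: every channel inherits the importance score of its SIMD group. A toy illustration (plain NumPy, hypothetical sizes):

```python
import numpy as np

group_scores = np.array([0.9, 0.1])            # one importance score per SIMD group
groups_indices = [[0, 1, 2, 3], [4, 5, 6, 7]]  # channel indices of each group

per_channel = np.empty(8)
for score, idxs in zip(group_scores, groups_indices):
    per_channel[idxs] = score                  # each channel gets its group's score
print(per_channel)  # [0.9 0.9 0.9 0.9 0.1 0.1 0.1 0.1]
```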
model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py
CHANGED

@@ -59,8 +59,10 @@ class CandidateNodeQuantizationConfig(BaseNodeQuantizationConfig):
            self.activation_quantization_cfg = activation_quantization_cfg
        else:
            if any(v is None for v in (qc, op_cfg, activation_quantization_fn, activation_quantization_params_fn)):
-                Logger.…
-                …
+                Logger.critical(
+                    "Missing required arguments to initialize a node activation quantization configuration. "
+                    "Ensure QuantizationConfig, OpQuantizationConfig, activation quantization function, "
+                    "and parameters function are provided.")
            self.activation_quantization_cfg = (
                NodeActivationQuantizationConfig(qc=qc,
                                                 op_cfg=op_cfg,

@@ -71,8 +73,9 @@
            self.weights_quantization_cfg = weights_quantization_cfg
        else:
            if any(v is None for v in (qc, op_cfg, node_attrs_list)):
-                Logger.…
-                …
+                Logger.critical("Missing required arguments to initialize a node weights quantization configuration. "
+                                "Ensure QuantizationConfig, OpQuantizationConfig, weights quantization function, "
+                                "parameters function, and weights attribute quantization config are provided.")
            self.weights_quantization_cfg = NodeWeightsQuantizationConfig(qc=qc, op_cfg=op_cfg,
                                                                          weights_channels_axis=weights_channels_axis,
                                                                          node_attrs_list=node_attrs_list)
model_compression_toolkit/core/common/quantization/node_quantization_config.py
CHANGED

@@ -122,7 +122,9 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
                                           self.activation_quantization_params)

        if fake_quant is None:
-            Logger.…
+            Logger.critical(
+                "Layer is intended to be quantized, but the fake_quant function is None.")  # pragma: no cover
+
        return fake_quant(tensors)

    @property
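For context, a fake-quant function simulates quantization in float arithmetic: values are rounded onto the integer grid and immediately de-quantized. A minimal sketch of the idea (plain NumPy, not MCT's actual fake-quant builder):

```python
import numpy as np

def fake_quant(x, n_bits, x_min, x_max):
    # Quantize to an unsigned n_bits grid over [x_min, x_max], then de-quantize.
    scale = (x_max - x_min) / (2 ** n_bits - 1)
    q = np.clip(np.round((x - x_min) / scale), 0, 2 ** n_bits - 1)
    return q * scale + x_min

print(fake_quant(np.array([0.03, 0.51, 0.98]), n_bits=2, x_min=0.0, x_max=1.0))
# [0.         0.66666667 1.        ] -- values snap to the 4-level grid
```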
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py
CHANGED

@@ -16,6 +16,7 @@
 from collections.abc import Callable
 from functools import partial

+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
 from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \

@@ -42,6 +43,7 @@ def get_weights_quantization_fn(weights_quantization_method: QuantizationMethod)
    elif weights_quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER, QuantizationMethod.LUT_SYM_QUANTIZER]:
        quantizer_fn = lut_kmeans_quantizer
    else:
-        …
-        f…
+        Logger.critical(
+            f"No quantizer function found for the specified quantization method: {weights_quantization_method}")
+
    return quantizer_fn
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py
CHANGED

@@ -47,9 +47,8 @@ def get_activation_quantization_params_fn(activation_quantization_method: Quanti
    elif activation_quantization_method == QuantizationMethod.LUT_POT_QUANTIZER:
        params_fn = lut_kmeans_histogram
    else:
-        Logger.…
-        f…
-        f'quantization method {activation_quantization_method}')  # pragma: no cover
+        Logger.critical(
+            f"No parameter function found for the specified quantization method: {activation_quantization_method}")  # pragma: no cover
    return params_fn


@@ -74,7 +73,6 @@ def get_weights_quantization_params_fn(weights_quantization_method: Quantization
    elif weights_quantization_method == QuantizationMethod.LUT_SYM_QUANTIZER:
        params_fn = partial(lut_kmeans_tensor, is_symmetric=True)
    else:
-        Logger.…
-        f…
-        f'quantization method {weights_quantization_method}')  # pragma: no cover
+        Logger.critical(
+            f"No parameter function found for the specified quantization method: {weights_quantization_method}")  # pragma: no cover
    return params_fn
model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py
CHANGED

@@ -60,8 +60,7 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
        the thresholds per channel and the multiplier num bits.
    """
    if n_bits >= LUT_VALUES_BITWIDTH:
-        Logger.critical(f'Look-Up-Table bit configuration…
-        f'{LUT_VALUES_BITWIDTH}')  # pragma: no cover
+        Logger.critical(f'Look-Up-Table (LUT) bit configuration exceeds maximum: {n_bits} bits provided, must be less than {LUT_VALUES_BITWIDTH} bits.')  # pragma: no cover
    # TODO: need to set this externally
    if len(np.unique(tensor_data.flatten())) < 2 ** n_bits:
        n_clusters = len(np.unique(tensor_data.flatten()))

@@ -121,8 +120,7 @@ def lut_kmeans_histogram(bins: np.ndarray,
    """

    if n_bits >= LUT_VALUES_BITWIDTH:
-        Logger.critical(f'Look-Up-Table bit configuration…
-        f'{LUT_VALUES_BITWIDTH}')  # pragma: no cover
+        Logger.critical(f'Look-Up-Table (LUT) bit configuration exceeds maximum: {n_bits} bits provided, must be less than {LUT_VALUES_BITWIDTH} bits.')  # pragma: no cover

    bins_with_values = np.abs(bins)[1:][counts > 0]
    if len(np.unique(bins_with_values.flatten())) < 2 ** n_bits:
model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py
CHANGED

@@ -238,7 +238,7 @@ def get_tensor_max(tensor_data: np.ndarray,

    """
    if n_bits < 1:
-        Logger.…
+        Logger.critical(f"Parameter n_bits must be positive; however 'n_bits'={n_bits} was provided.")
    if is_uniform_quantization:
        expansion_factor = 1.0
    elif n_bits == 1:
model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py
CHANGED

@@ -52,11 +52,13 @@ def power_of_two_quantizer(tensor_data: np.ndarray,
    """
    threshold = quantization_params.get(THRESHOLD)
    if threshold is None:
-        Logger.…
+        Logger.critical(f"'{THRESHOLD}' parameter must be defined in 'quantization_params'")  # pragma: no cover
+
    if not threshold_is_power_of_two(threshold, per_channel):
-        Logger.…
+        Logger.critical(f"Expected '{THRESHOLD}' parameter to be a power of two, but received {threshold}.")  # pragma: no cover
+
    if (per_channel and (threshold <= 0).any()) or ((not per_channel) and threshold <= 0):
-        Logger.…
+        Logger.critical(f"'{THRESHOLD}' parameter must positive")  # pragma: no cover


    return quantize_tensor(tensor_data,

@@ -88,10 +90,10 @@ def symmetric_quantizer(tensor_data: np.ndarray,
    """
    threshold = quantization_params.get(THRESHOLD)
    if threshold is None:
-        Logger.…
+        Logger.critical(f"'{THRESHOLD}' parameter must be defined in 'quantization_params'")  # pragma: no cover

    if (per_channel and np.any(threshold <= 0)) or (not per_channel and threshold <= 0):
-        Logger.…
+        Logger.critical(f"'{THRESHOLD}' parameter must positive")  # pragma: no cover

    return quantize_tensor(tensor_data,
                           threshold,

@@ -122,6 +124,6 @@ def uniform_quantizer(tensor_data: np.ndarray,
    range_min = quantization_params.get(RANGE_MIN)
    range_max = quantization_params.get(RANGE_MAX)
    if range_min is None or range_max is None:
-        Logger.…
+        Logger.critical("'quantization range' parameters must be defined in 'quantization_params'")  # pragma: no cover

    return uniform_quantize_tensor(tensor_data, range_min, range_max, n_bits)
|
|
@@ -112,7 +112,7 @@ def create_node_activation_qc(qc: QuantizationConfig,
|
|
|
112
112
|
|
|
113
113
|
activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
|
|
114
114
|
if activation_quantization_fn is None:
|
|
115
|
-
Logger.critical('Unknown quantization method
|
|
115
|
+
Logger.critical('Unknown activation quantization method specified.') # pragma: no cover
|
|
116
116
|
|
|
117
117
|
activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
|
|
118
118
|
|
|
@@ -149,7 +149,7 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig,
|
|
|
149
149
|
# get parameters for activation quantization
|
|
150
150
|
activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
|
|
151
151
|
if activation_quantization_fn is None:
|
|
152
|
-
Logger.critical('Unknown quantization method
|
|
152
|
+
Logger.critical('Unknown activation quantization method specified.') # pragma: no cover
|
|
153
153
|
|
|
154
154
|
activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
|
|
155
155
|
|
model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py
CHANGED

@@ -190,13 +190,11 @@ def _get_bias_correction_term_of_node(input_channels_axis: int,
    """

    if output_channels_axis is None:
-        Logger.…
-        f'Unknown output channel axis for node…
-        f' please update channel mapping function')
+        Logger.critical(
+            f'Unknown output channel axis for node: {n.name}. Please update the channel mapping function.')
    if input_channels_axis is None:
-        Logger.…
-        f'Unknown input channel axis for node…
-        f' please update channel mapping function')
+        Logger.critical(
+            f'Unknown input channel axis for node: {n.name}. Please update the channel mapping function')
    # Compute the bias correction term.
    correction = _compute_bias_correction(n.get_weights_by_keys(fw_impl.constants.KERNEL),
                                          quantized_kernel,
model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py
CHANGED

@@ -103,15 +103,13 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
        # If the linear operator is part of a reused group (it is the "base" node, or a reused node),
        # we should skip the substitution.
        if source_node.is_reused():
-            Logger.…
-            "substitution and SMC feature")  # pragma: no cover
+            Logger.critical("BN folding substitution cannot proceed if the linear operator is part of a reused group.")  # pragma: no cover

        bn_node = edge_nodes[1]

        if len(graph.get_next_nodes(source_node)) > 1 or len(graph.get_prev_nodes(bn_node)) > 1:
-            Logger.…
-            "…
-            "skip the the BN folding substitution and SMC feature")  # pragma: no cover
+            Logger.critical(
+                "BN folding substitution cannot proceed if the linear operator has multiple outputs or the BN layer has multiple inputs.")  # pragma: no cover

        kernel = source_node.get_weights_by_keys(self.kernel_str)
        bias = source_node.get_weights_by_keys(self.bias_str)

@@ -199,5 +197,4 @@
            conv_bn_kernel_cfg.set_weights_quantization_param(corr_dict)

        else:
-            Logger.…
-            "quantization method of Power of 2")  # pragma: no cover
+            Logger.critical("Second moment statistics correction feature is not supported for weights quantization methods other than 'SYMMETRIC' and 'UNIFORM'.")  # pragma: no cover
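As background, BN folding (re-fusing) rewrites `BN(conv(x))` as a single convolution with a rescaled kernel and shifted bias. A minimal NumPy sketch of the standard algebra (illustrative; MCT's substitution additionally handles quantization configs and the second-moment correction):

```python
import numpy as np

def fold_bn_into_conv(kernel, bias, gamma, beta, mean, var, eps=1e-3):
    # BN(conv) = gamma * (conv(x) + bias - mean) / sqrt(var + eps) + beta
    # collapses into one conv with per-output-channel rescaling.
    scale = gamma / np.sqrt(var + eps)          # shape: (out_channels,)
    folded_kernel = kernel * scale              # broadcast over the last (out-channel) axis
    folded_bias = (bias - mean) * scale + beta
    return folded_kernel, folded_bias
```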
model_compression_toolkit/core/common/substitutions/shift_negative_activation.py
CHANGED

@@ -134,7 +134,7 @@ def insert_node_after_node(graph: Graph,

    last_nodes = graph.get_next_nodes(first_node)
    if len(last_nodes) != 1:
-        Logger.…
+        Logger.critical(f'Insertion requires exactly one successor node; {len(last_nodes)} successors found.')  # pragma: no cover
    last_node = last_nodes[0]
    insert_node_between_two_nodes(graph, node_to_insert, first_node, last_node)


@@ -156,7 +156,7 @@ def insert_node_before_node(graph: Graph,
    """
    first_nodes = graph.get_prev_nodes(last_node)
    if len(first_nodes) != 1:
-        Logger.…
+        Logger.critical('Insertion requires exactly one predecessor node; multiple or no predecessors found.')  # pragma: no cover
    first_node = first_nodes[0]
    insert_node_between_two_nodes(graph, node_to_insert, first_node, last_node)


@@ -235,7 +235,7 @@ def shift_negative_function(graph: Graph,
    min_to_correct, max_value2compare = graph.get_out_stats_collector(non_linear_node).get_min_max_values()

    if not non_linear_node.is_all_activation_candidates_equal():
-        Logger.…
+        Logger.critical("Shift negative correction is not supported for more than one activation quantization "
                        "configuration candidate")  # pragma: no cover

    # all candidates have same activation config, so taking the first candidate for calculations
model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py
CHANGED

@@ -48,7 +48,7 @@ class BaseVirtualActivationWeightsComposition(BaseSubstitution):

        if len(graph.out_edges(act_node)) > 1:
            Logger.warning(f"Node {act_node.name} has multiple outgoing edges, which is not supported with "
-                           f"mixed-precision bit-operations…
+                           f"mixed-precision bit-operations utilization, thus, edge {act_node.name} --> {weights_node.name} "
                           f"would not be counted in the bit-operations calculations.")
            return graph
model_compression_toolkit/core/common/substitutions/weights_activation_split.py
CHANGED

@@ -65,9 +65,9 @@ class BaseWeightsActivationSplit(BaseSubstitution):
                               for c in node.candidates_quantization_cfg]
        if not set(expected_candidates).issubset(all_candidates_bits):
            # Node is not composite, therefore, can't be split
-            Logger.critical(f"The…
-            f"…
-            f"all model layers…
+            Logger.critical(f"The node {node.name} cannot be split as it has non-composite candidates. "
+                            f"For mixed-precision search with BOPS target resource utilization, "
+                            f"all model layers must be composite.")  # pragma: no cover

        weights_node = VirtualSplitWeightsNode(node, kernel_attr)
        activation_node = VirtualSplitActivationNode(node, self.activation_layer_type, self.fw_attr)
model_compression_toolkit/core/keras/back2framework/factory_model_builder.py
CHANGED

@@ -38,9 +38,9 @@ def get_keras_model_builder(mode: ModelBuilderMode) -> type:
    """

    if not isinstance(mode, ModelBuilderMode):
-        Logger.…
+        Logger.critical(f"Expected a ModelBuilderMode type for 'mode', but received {type(mode)} instead.")
    if mode is None:
-        Logger.…
+        Logger.critical(f"get_keras_model_builder received 'mode' is None")
    if mode not in keras_model_builders.keys():
-        Logger.…
+        Logger.critical(f"'mode' {mode} is not recognized in the Keras model builders factory.")
    return keras_model_builders.get(mode)
@@ -88,8 +88,8 @@ def node_builder(n: common.BaseNode) -> Layer:
    try:
        node_instance = _layer_class.from_config(framework_attr)  # Build layer from node's configuration.
    except Exception as e:
-        …
-        Logger.…
+        Logger.info(e)  # pragma: no cover
+        Logger.critical(
            f"Keras can not de-serialize layer {_layer_class} in order to build a static graph representation. This is probably because "
            f"your model contains custom layers which MCT doesn't support. Please provide a model without custom layers.")  # pragma: no cover
    with tf.name_scope(n.name):
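The `from_config` round-trip that this error guards is standard Keras serialization; layer classes the deserializer does not know about (custom layers) fail it. A quick illustration:

```python
import tensorflow as tf

# Built-in layers round-trip through get_config()/from_config() cleanly.
cfg = tf.keras.layers.Dense(units=8, activation="relu").get_config()
rebuilt = tf.keras.layers.Dense.from_config(cfg)
print(rebuilt.units)  # 8

# A custom layer class unknown to the deserializer is what triggers the
# Logger.critical path above during MCT's graph reconstruction.
```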
model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py
CHANGED

@@ -104,8 +104,7 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
        # or single precision).
        node_weights_qc = n.get_unique_weights_candidates(kernel_attr)
        if not len(node_weights_qc) == 1:
-            Logger.…
-            f"but {len(node_weights_qc)} different configurations exist.")
+            Logger.critical(f"Expected a unique weights configuration for node {n.name}, but found {len(node_weights_qc)} configurations.")  # pragma: no cover

        quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights,
                                                          node_weights_qc[0].weights_quantization_cfg

@@ -143,8 +142,7 @@

        max_cfg_candidates = n.find_max_candidates_indices()
        if not len(max_cfg_candidates) == 1:
-            Logger.…
-            f"but some node have multiple potential maximal candidates")
+            Logger.critical(f"A maximal configuration candidate must be defined; found multiple potential maximal candidates.")  # pragma: no cover

        max_candidate_idx = max_cfg_candidates[0]


@@ -211,8 +209,7 @@
        if len(activation_quantizers) == 1:
            return KerasActivationQuantizationHolder(activation_quantizers[0])

-        Logger.…
-        f'{len(activation_quantizers)} quantizers were found for node {n}')
+        Logger.critical(f"'KerasActivationQuantizationHolder' supports only one quantizer, but found {len(activation_quantizers)} for node {n}")  # pragma: no cover

    def build_model(self) -> Tuple[Model, UserInformation,
                                   Dict[str, Union[KerasQuantizationWrapper, KerasActivationQuantizationHolder]]]:

@@ -292,6 +289,5 @@
        elif weights_quant and act_quant:
            return self._get_weights_quant_layers(n, layers_list) + self._get_activation_quant_layers(n, layers_list)
        else:
-            Logger.…
-            f"but both are disabled.")
+            Logger.critical(f"Expected node {n.name} to have either weights or activation quantization configured, but both are disabled.")  # pragma: no cover
model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py
CHANGED

@@ -25,6 +25,7 @@ from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.constants import THRESHOLD
 from model_compression_toolkit.core.keras.constants import KERNEL
+from model_compression_toolkit.logger import Logger

 input_node = NodeOperationMatcher(InputLayer)
 zeropad_node = NodeOperationMatcher(ZeroPadding2D)

@@ -80,8 +81,8 @@ class BaseInputScaling(common.BaseSubstitution):
        linear_layer = nodes_list[-1]

        if not input_layer.is_all_activation_candidates_equal():
-            …
-            "candidate")
+            Logger.critical("Input scaling is not supported for nodes with more than one activation quantization configuration "
+                            "candidate.")

        # all candidates have same activation config, so taking the first candidate for calculations
        threshold = input_layer.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params.get(THRESHOLD)
model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py
CHANGED

@@ -104,7 +104,7 @@ def conv2d_collapsing_fn(first_node: BaseNode,

        return kernel_collapsed, bias_collapsed
    else:
-        Logger.…
+        Logger.critical(f"Layer collapsing unsupported for combination: {first_node.type} and {second_node.type}.")


def keras_linear_collapsing() -> Conv2DCollapsing:

@@ -161,7 +161,7 @@ def op2d_add_const_collapsing_fn(op2d_node: BaseNode,
    # read constant from add node (either 1st or 2nd positional weight)
    const = add_node.weights.get(0, add_node.weights.get(1))
    if const is None:
-        Logger.…
+        Logger.critical(f'Failed to read constant from add node: {add_node.name}.')  # pragma: no cover

    # return new bias
    if bias is None:
model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py
CHANGED

@@ -66,7 +66,7 @@ class MatmulToDenseSubstitution(common.BaseSubstitution):
        # read const from matmul inputs
        w = matmul_node.weights.get(1)
        if w is None:
-            Logger.…
+            Logger.critical(f"Matmul substitution failed: Unable to locate weight for node {matmul_node.name}.")  # pragma: no cover

        if len(w.shape) != 2:
            # weight tensor should be of shape (Cin, Cout)
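The substitution rests on a simple identity: a matmul against a constant `(Cin, Cout)` matrix is exactly a bias-free Dense layer with that matrix as its kernel. A quick check of the equivalence (hypothetical shapes):

```python
import numpy as np
import tensorflow as tf

x = np.random.rand(4, 8).astype(np.float32)    # batch of activations
w = np.random.rand(8, 16).astype(np.float32)   # constant matmul operand, shape (Cin, Cout)

dense = tf.keras.layers.Dense(16, use_bias=False)
dense.build(x.shape)
dense.set_weights([w])                          # install the matmul constant as the kernel

print(np.allclose(tf.matmul(x, w), dense(x), atol=1e-6))  # True
```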
model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py
CHANGED

@@ -448,7 +448,7 @@ class MultiHeadAttentionDecomposition(common.BaseSubstitution):
        """

        if mha_node.reuse:
-            Logger.…
+            Logger.critical("Reuse of MultiHeadAttention layers is currently not supported.")  # pragma: no cover
        params = MHAParams(mha_node)

        mha_in_edges = graph.in_edges(mha_node)
model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py
CHANGED

@@ -62,7 +62,7 @@ def residual_collapsing_fn(first_node: BaseNode,

        return kernel
    else:
-        Logger.…
+        Logger.critical(f"Residual collapsing is unsupported for {first_node.type} node types.")


def keras_residual_collapsing() -> ResidualCollapsing:
model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py
CHANGED

@@ -64,7 +64,7 @@ class ActivationTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
        model_output_nodes = [ot.node for ot in self.graph.get_outputs()]

        if self.hessian_request.target_node in model_output_nodes:
-            Logger.…
+            Logger.critical("Trying to compute activation Hessian approximation with respect to the model output. "
                            "This operation is not supported. "
                            "Remove the output node from the set of node targets in the Hessian request.")


@@ -83,7 +83,7 @@
            outputs = model(*self.input_images)

            if len(outputs) != len(grad_model_outputs):
-                Logger.…
+                Logger.critical(
                    f"Model for computing activation Hessian approximation expects {len(grad_model_outputs)} "
                    f"outputs, but got {len(outputs)} output tensors.")


@@ -166,4 +166,4 @@
                return trace_approx_by_node.numpy().tolist()

        else:
-            Logger.…
+            Logger.critical(f"{self.hessian_request.granularity} is not supported for Keras activation hessian\'s trace approximation calculator.")
model_compression_toolkit/core/keras/hessian/trace_hessian_calculator_keras.py
CHANGED

@@ -74,7 +74,6 @@ class TraceHessianCalculatorKeras(TraceHessianCalculator):
        concat_axis_dim = [o.shape[0] for o in _r_tensors]
        if not all(d == concat_axis_dim[0] for d in concat_axis_dim):
            Logger.critical(
-                "…
-                "is not equal in all outputs.")
+                "Unable to concatenate tensors for gradient calculation due to mismatched shapes along the first axis.")  # pragma: no cover

        return tf.concat(_r_tensors, axis=1)
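These calculators estimate the trace of an activation or weights Hessian without ever materializing the matrix, which is typically done with Hutchinson-style probing over Hessian-vector products. A generic sketch of that estimator (illustrative of the technique, not MCT's exact implementation):

```python
import numpy as np

def hutchinson_trace(hvp, dim, n_samples=100, seed=0):
    # Estimate trace(H) from Hessian-vector products hvp(v) = H @ v,
    # using Rademacher probe vectors: E[v^T H v] = trace(H).
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(n_samples):
        v = rng.choice([-1.0, 1.0], size=dim)
        total += v @ hvp(v)
    return total / n_samples

# Sanity check against an explicit symmetric matrix.
H = np.diag([1.0, 2.0, 3.0])
print(hutchinson_trace(lambda v: H @ v, dim=3))  # ~6.0
```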