mct-nightly 2.2.0.20250113.527__py3-none-any.whl → 2.2.0.20250114.84821__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/RECORD +103 -105
- model_compression_toolkit/__init__.py +2 -2
- model_compression_toolkit/core/common/framework_info.py +1 -3
- model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
- model_compression_toolkit/core/common/graph/base_graph.py +20 -21
- model_compression_toolkit/core/common/graph/base_node.py +44 -17
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +26 -135
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +667 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +164 -470
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +30 -7
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/pruner.py +5 -3
- model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/graph_prep_runner.py +12 -11
- model_compression_toolkit/core/keras/data_util.py +24 -5
- model_compression_toolkit/core/keras/default_framework_info.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
- model_compression_toolkit/core/runner.py +33 -60
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/metadata.py +11 -10
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
- model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
- model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
- model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
- model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
- model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
- model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
- model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
@@ -32,8 +32,9 @@ from model_compression_toolkit.core.common.collectors.statistics_collector impor
|
|
32
32
|
from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
|
33
33
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
34
34
|
from model_compression_toolkit.logger import Logger
|
35
|
-
from model_compression_toolkit.target_platform_capabilities.
|
36
|
-
|
35
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
|
36
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
|
37
|
+
FrameworkQuantizationCapabilities
|
37
38
|
|
38
39
|
OutTensor = namedtuple('OutTensor', 'node node_out_index')
|
39
40
|
|
@@ -86,29 +87,29 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
86
87
|
|
87
88
|
self.fw_info = fw_info
|
88
89
|
|
89
|
-
def
|
90
|
-
|
90
|
+
def set_fqc(self,
|
91
|
+
fqc: FrameworkQuantizationCapabilities):
|
91
92
|
"""
|
92
|
-
Set the graph's
|
93
|
+
Set the graph's FQC.
|
93
94
|
Args:
|
94
|
-
|
95
|
+
fqc: FrameworkQuantizationCapabilities object.
|
95
96
|
"""
|
96
|
-
# validate graph nodes are either from the framework or a custom layer defined in the
|
97
|
-
# Validate graph nodes are either built-in layers from the framework or custom layers defined in the
|
98
|
-
|
99
|
-
|
97
|
+
# validate graph nodes are either from the framework or a custom layer defined in the FQC
|
98
|
+
# Validate graph nodes are either built-in layers from the framework or custom layers defined in the FQC
|
99
|
+
fqc_layers = fqc.op_sets_to_layers.get_layers()
|
100
|
+
fqc_filtered_layers = [layer for layer in fqc_layers if isinstance(layer, LayerFilterParams)]
|
100
101
|
for n in self.nodes:
|
101
|
-
|
102
|
-
any([n.is_match_filter_params(filtered_layer) for filtered_layer in
|
102
|
+
is_node_in_fqc = any([n.is_match_type(_type) for _type in fqc_layers]) or \
|
103
|
+
any([n.is_match_filter_params(filtered_layer) for filtered_layer in fqc_filtered_layers])
|
103
104
|
if n.is_custom:
|
104
|
-
if not
|
105
|
+
if not is_node_in_fqc:
|
105
106
|
Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
|
106
|
-
' Please add the custom layer to
|
107
|
+
' Please add the custom layer to Framework Quantization Capabilities (FQC), or file a feature '
|
107
108
|
'request or an issue if you believe this should be supported.') # pragma: no cover
|
108
|
-
if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(
|
109
|
+
if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(fqc).quantization_configurations]):
|
109
110
|
Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.') # pragma: no cover
|
110
111
|
|
111
|
-
self.
|
112
|
+
self.fqc = fqc
|
112
113
|
|
113
114
|
def get_topo_sorted_nodes(self):
|
114
115
|
"""
|
@@ -544,10 +545,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
544
545
|
potential_conf_nodes = [n for n in list(self) if fw_info.is_kernel_op(n.type)]
|
545
546
|
|
546
547
|
def is_configurable(n):
|
547
|
-
|
548
|
-
return (n.
|
549
|
-
not n.is_all_weights_candidates_equal(kernel_attr) and
|
550
|
-
(not n.reuse or include_reused_nodes))
|
548
|
+
kernel_attrs = fw_info.get_kernel_op_attributes(n.type)
|
549
|
+
return any(n.is_configurable_weight(attr) for attr in kernel_attrs) and (not n.reuse or include_reused_nodes)
|
551
550
|
|
552
551
|
return [n for n in potential_conf_nodes if is_configurable(n)]
|
553
552
|
|
@@ -576,7 +575,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
576
575
|
Returns:
|
577
576
|
A list of nodes that their activation can be configured (namely, has one or more activation qc candidate).
|
578
577
|
"""
|
579
|
-
return [n for n in list(self) if n.
|
578
|
+
return [n for n in list(self) if n.has_configurable_activation()]
|
580
579
|
|
581
580
|
def get_sorted_activation_configurable_nodes(self) -> List[BaseNode]:
|
582
581
|
"""
|
@@ -25,7 +25,9 @@ from model_compression_toolkit.logger import Logger
|
|
25
25
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
|
26
26
|
OpQuantizationConfig
|
27
27
|
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
|
28
|
-
from model_compression_toolkit.target_platform_capabilities.
|
28
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
|
29
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
|
30
|
+
FrameworkQuantizationCapabilities
|
29
31
|
|
30
32
|
|
31
33
|
class BaseNode:
|
@@ -150,6 +152,27 @@ class BaseNode:
|
|
150
152
|
|
151
153
|
return False
|
152
154
|
|
155
|
+
def is_configurable_weight(self, attr_name: str) -> bool:
|
156
|
+
"""
|
157
|
+
Checks whether the specific weight attribute has a configurable quantization.
|
158
|
+
|
159
|
+
Args:
|
160
|
+
attr_name: weight attribute name.
|
161
|
+
|
162
|
+
Returns:
|
163
|
+
Whether the weight attribute is configurable.
|
164
|
+
"""
|
165
|
+
return self.is_weights_quantization_enabled(attr_name) and not self.is_all_weights_candidates_equal(attr_name)
|
166
|
+
|
167
|
+
def has_configurable_activation(self) -> bool:
|
168
|
+
"""
|
169
|
+
Checks whether the activation has a configurable quantization.
|
170
|
+
|
171
|
+
Returns:
|
172
|
+
Whether the activation has a configurable quantization.
|
173
|
+
"""
|
174
|
+
return self.is_activation_quantization_enabled() and not self.is_all_activation_candidates_equal()
|
175
|
+
|
153
176
|
def __repr__(self):
|
154
177
|
"""
|
155
178
|
|
@@ -420,11 +443,15 @@ class BaseNode:
|
|
420
443
|
|
421
444
|
Returns: Output size.
|
422
445
|
"""
|
423
|
-
|
446
|
+
# shape can be tuple or list, and multiple shapes can be packed in list or tuple
|
447
|
+
if self.output_shape and isinstance(self.output_shape[0], (tuple, list)):
|
448
|
+
output_shapes = self.output_shape
|
449
|
+
else:
|
450
|
+
output_shapes = [self.output_shape]
|
424
451
|
|
425
452
|
# remove batch size (first element) from output shape
|
426
453
|
output_shapes = [s[1:] for s in output_shapes]
|
427
|
-
|
454
|
+
# for scalar shape (None,) prod returns 1
|
428
455
|
return sum([np.prod([x for x in output_shape if x is not None]) for output_shape in output_shapes])
|
429
456
|
|
430
457
|
def find_min_candidates_indices(self) -> List[int]:
|
@@ -536,34 +563,34 @@ class BaseNode:
|
|
536
563
|
# the inner method would log an exception.
|
537
564
|
return [c.weights_quantization_cfg.get_attr_config(attr) for c in self.candidates_quantization_cfg]
|
538
565
|
|
539
|
-
def get_qco(self,
|
566
|
+
def get_qco(self, fqc: FrameworkQuantizationCapabilities) -> QuantizationConfigOptions:
|
540
567
|
"""
|
541
568
|
Get the QuantizationConfigOptions of the node according
|
542
|
-
to the mappings from layers/LayerFilterParams to the OperatorsSet in the
|
569
|
+
to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformCapabilities.
|
543
570
|
|
544
571
|
Args:
|
545
|
-
|
572
|
+
fqc: FQC to extract the QuantizationConfigOptions for the node.
|
546
573
|
|
547
574
|
Returns:
|
548
575
|
QuantizationConfigOptions of the node.
|
549
576
|
"""
|
550
577
|
|
551
|
-
if
|
552
|
-
Logger.critical(f'Can not retrieve QC options for None
|
578
|
+
if fqc is None:
|
579
|
+
Logger.critical(f'Can not retrieve QC options for None FQC') # pragma: no cover
|
553
580
|
|
554
|
-
for fl, qco in
|
581
|
+
for fl, qco in fqc.filterlayer2qco.items():
|
555
582
|
if self.is_match_filter_params(fl):
|
556
583
|
return qco
|
557
584
|
# Extract qco with is_match_type to overcome mismatch of function types in TF 2.15
|
558
|
-
matching_qcos = [_qco for _type, _qco in
|
585
|
+
matching_qcos = [_qco for _type, _qco in fqc.layer2qco.items() if self.is_match_type(_type)]
|
559
586
|
if matching_qcos:
|
560
587
|
if all([_qco == matching_qcos[0] for _qco in matching_qcos]):
|
561
588
|
return matching_qcos[0]
|
562
589
|
else:
|
563
590
|
Logger.critical(f"Found duplicate qco types for node '{self.name}' of type '{self.type}'!") # pragma: no cover
|
564
|
-
return tpc.
|
591
|
+
return fqc.tpc.default_qco
|
565
592
|
|
566
|
-
def filter_node_qco_by_graph(self,
|
593
|
+
def filter_node_qco_by_graph(self, fqc: FrameworkQuantizationCapabilities,
|
567
594
|
next_nodes: List, node_qc_options: QuantizationConfigOptions
|
568
595
|
) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
|
569
596
|
"""
|
@@ -573,7 +600,7 @@ class BaseNode:
|
|
573
600
|
filters out quantization config that don't comply to these attributes.
|
574
601
|
|
575
602
|
Args:
|
576
|
-
|
603
|
+
fqc: FQC to extract the QuantizationConfigOptions for the next nodes.
|
577
604
|
next_nodes: Output nodes of current node.
|
578
605
|
node_qc_options: Node's QuantizationConfigOptions.
|
579
606
|
|
@@ -584,7 +611,7 @@ class BaseNode:
|
|
584
611
|
_base_config = node_qc_options.base_config
|
585
612
|
_node_qc_options = node_qc_options.quantization_configurations
|
586
613
|
if len(next_nodes):
|
587
|
-
next_nodes_qc_options = [_node.get_qco(
|
614
|
+
next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
|
588
615
|
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
|
589
616
|
for qc_opts in next_nodes_qc_options
|
590
617
|
for op_cfg in qc_opts.quantization_configurations])
|
@@ -593,7 +620,7 @@ class BaseNode:
|
|
593
620
|
_node_qc_options = [_option for _option in _node_qc_options
|
594
621
|
if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
|
595
622
|
if len(_node_qc_options) == 0:
|
596
|
-
Logger.critical(f"Graph doesn't match
|
623
|
+
Logger.critical(f"Graph doesn't match FQC bit configurations: {self} -> {next_nodes}.") # pragma: no cover
|
597
624
|
|
598
625
|
# Verify base config match
|
599
626
|
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
|
@@ -603,9 +630,9 @@ class BaseNode:
|
|
603
630
|
if len(_node_qc_options) > 0:
|
604
631
|
output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
|
605
632
|
_base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
|
606
|
-
Logger.warning(f"Node {self} base quantization config changed to match Graph and
|
633
|
+
Logger.warning(f"Node {self} base quantization config changed to match Graph and FQC configuration.\nCause: {self} -> {next_nodes}.")
|
607
634
|
else:
|
608
|
-
Logger.critical(f"Graph doesn't match
|
635
|
+
Logger.critical(f"Graph doesn't match FQC bit configurations: {self} -> {next_nodes}.") # pragma: no cover
|
609
636
|
|
610
637
|
return _base_config, _node_qc_options
|
611
638
|
|
@@ -17,18 +17,19 @@ import numpy as np
|
|
17
17
|
from model_compression_toolkit.core import ResourceUtilization, FrameworkInfo
|
18
18
|
from model_compression_toolkit.core.common import Graph
|
19
19
|
from model_compression_toolkit.logger import Logger
|
20
|
-
from model_compression_toolkit.target_platform_capabilities.
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
|
21
|
+
FrameworkQuantizationCapabilities
|
21
22
|
|
22
23
|
|
23
24
|
def filter_candidates_for_mixed_precision(graph: Graph,
|
24
25
|
target_resource_utilization: ResourceUtilization,
|
25
26
|
fw_info: FrameworkInfo,
|
26
|
-
|
27
|
+
fqc: FrameworkQuantizationCapabilities):
|
27
28
|
"""
|
28
29
|
Filters out candidates in case of mixed precision search for only weights or activation compression.
|
29
30
|
For instance, if running only weights compression - filters out candidates of activation configurable nodes
|
30
31
|
such that only a single candidate would remain, with the bitwidth equal to the one defined in the matching layer's
|
31
|
-
base config in the
|
32
|
+
base config in the FQC.
|
32
33
|
|
33
34
|
Note: This function modifies the graph inplace!
|
34
35
|
|
@@ -36,7 +37,7 @@ def filter_candidates_for_mixed_precision(graph: Graph,
|
|
36
37
|
graph: A graph representation of the model to be quantized.
|
37
38
|
target_resource_utilization: The resource utilization of the target device.
|
38
39
|
fw_info: fw_info: Information needed for quantization about the specific framework.
|
39
|
-
|
40
|
+
fqc: FrameworkQuantizationCapabilities object that describes the desired inference target platform.
|
40
41
|
|
41
42
|
"""
|
42
43
|
|
@@ -50,7 +51,7 @@ def filter_candidates_for_mixed_precision(graph: Graph,
|
|
50
51
|
weights_conf = graph.get_weights_configurable_nodes(fw_info)
|
51
52
|
activation_configurable_nodes = [n for n in graph.get_activation_configurable_nodes() if n not in weights_conf]
|
52
53
|
for n in activation_configurable_nodes:
|
53
|
-
base_cfg_nbits = n.get_qco(
|
54
|
+
base_cfg_nbits = n.get_qco(fqc).base_config.activation_n_bits
|
54
55
|
filtered_conf = [c for c in n.candidates_quantization_cfg if
|
55
56
|
c.activation_quantization_cfg.enable_activation_quantization and
|
56
57
|
c.activation_quantization_cfg.activation_n_bits == base_cfg_nbits]
|
@@ -67,7 +68,7 @@ def filter_candidates_for_mixed_precision(graph: Graph,
|
|
67
68
|
weight_configurable_nodes = [n for n in graph.get_weights_configurable_nodes(fw_info) if n not in activation_conf]
|
68
69
|
for n in weight_configurable_nodes:
|
69
70
|
kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
|
70
|
-
base_cfg_nbits = n.get_qco(
|
71
|
+
base_cfg_nbits = n.get_qco(fqc).base_config.attr_weights_configs_mapping[kernel_attr].weights_n_bits
|
71
72
|
filtered_conf = [c for c in n.candidates_quantization_cfg if
|
72
73
|
c.weights_quantization_cfg.get_attr_config(kernel_attr).enable_weights_quantization and
|
73
74
|
c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == base_cfg_nbits]
|
@@ -22,7 +22,6 @@ from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
|
|
22
22
|
from model_compression_toolkit.core.common import Graph
|
23
23
|
from model_compression_toolkit.core.common.hessian import HessianInfoService
|
24
24
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
|
25
|
-
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import ru_functions_mapping
|
26
25
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
27
26
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager
|
28
27
|
from model_compression_toolkit.core.common.mixed_precision.search_methods.linear_programming import \
|
@@ -105,16 +104,11 @@ def search_bit_width(graph_to_search_cfg: Graph,
|
|
105
104
|
disable_activation_for_metric=disable_activation_for_metric,
|
106
105
|
hessian_info_service=hessian_info_service)
|
107
106
|
|
108
|
-
# Each pair of (resource utilization method, resource utilization aggregation) should match to a specific
|
109
|
-
# provided target resource utilization
|
110
|
-
ru_functions = ru_functions_mapping
|
111
|
-
|
112
107
|
# Instantiate a manager object
|
113
108
|
search_manager = MixedPrecisionSearchManager(graph,
|
114
109
|
fw_info,
|
115
110
|
fw_impl,
|
116
111
|
se,
|
117
|
-
ru_functions,
|
118
112
|
target_resource_utilization,
|
119
113
|
original_graph=graph_to_search_cfg)
|
120
114
|
|
@@ -13,23 +13,24 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from typing import Callable,
|
17
|
-
|
16
|
+
from typing import Callable, Dict, List
|
17
|
+
|
18
18
|
import numpy as np
|
19
19
|
|
20
20
|
from model_compression_toolkit.core.common import BaseNode
|
21
|
-
from model_compression_toolkit.logger import Logger
|
22
21
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
22
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
23
23
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
24
24
|
from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
|
25
25
|
VirtualSplitWeightsNode, VirtualSplitActivationNode
|
26
|
-
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import
|
27
|
-
|
28
|
-
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.
|
29
|
-
|
30
|
-
from model_compression_toolkit.core.common.
|
31
|
-
|
26
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
|
27
|
+
RUTarget, ResourceUtilization
|
28
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
|
29
|
+
ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode
|
30
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import \
|
31
|
+
MixedPrecisionRUHelper
|
32
32
|
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
|
33
|
+
from model_compression_toolkit.logger import Logger
|
33
34
|
|
34
35
|
|
35
36
|
class MixedPrecisionSearchManager:
|
@@ -42,7 +43,6 @@ class MixedPrecisionSearchManager:
|
|
42
43
|
fw_info: FrameworkInfo,
|
43
44
|
fw_impl: FrameworkImplementation,
|
44
45
|
sensitivity_evaluator: SensitivityEvaluation,
|
45
|
-
ru_functions: Dict[RUTarget, RuFunctions],
|
46
46
|
target_resource_utilization: ResourceUtilization,
|
47
47
|
original_graph: Graph = None):
|
48
48
|
"""
|
@@ -53,8 +53,6 @@ class MixedPrecisionSearchManager:
|
|
53
53
|
fw_impl: FrameworkImplementation object with specific framework methods implementation.
|
54
54
|
sensitivity_evaluator: A SensitivityEvaluation which provides a function that evaluates the sensitivity of
|
55
55
|
a bit-width configuration for the MP model.
|
56
|
-
ru_functions: A dictionary with pairs of (MpRuMethod, MpRuAggregationMethod) mapping a RUTarget to
|
57
|
-
a couple of resource utilization metric function and resource utilization aggregation function.
|
58
56
|
target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
|
59
57
|
original_graph: In case we have a search over a virtual graph (if we have BOPS utilization target), then this argument
|
60
58
|
will contain the original graph (for config reconstruction purposes).
|
@@ -69,29 +67,17 @@ class MixedPrecisionSearchManager:
|
|
69
67
|
self.compute_metric_fn = self.get_sensitivity_metric()
|
70
68
|
self._cuts = None
|
71
69
|
|
72
|
-
|
73
|
-
|
74
|
-
self.compute_ru_functions = {ru_target: ru_fn for ru_target, ru_fn in ru_functions.items() if ru_target in ru_types}
|
70
|
+
self.ru_metrics = target_resource_utilization.get_restricted_metrics()
|
71
|
+
self.ru_helper = MixedPrecisionRUHelper(graph, fw_info, fw_impl)
|
75
72
|
self.target_resource_utilization = target_resource_utilization
|
76
73
|
self.min_ru_config = self.graph.get_min_candidates_config(fw_info)
|
77
74
|
self.max_ru_config = self.graph.get_max_candidates_config(fw_info)
|
78
|
-
self.min_ru = self.
|
75
|
+
self.min_ru = self.ru_helper.compute_utilization(self.ru_metrics, self.min_ru_config)
|
79
76
|
self.non_conf_ru_dict = self._non_configurable_nodes_ru()
|
80
77
|
|
81
78
|
self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.graph,
|
82
79
|
original_graph=self.original_graph)
|
83
80
|
|
84
|
-
@property
|
85
|
-
def cuts(self) -> List[Cut]:
|
86
|
-
"""
|
87
|
-
Calculates graph cuts. Written as property, so it will only be calculated once and
|
88
|
-
only if cuts are needed.
|
89
|
-
|
90
|
-
"""
|
91
|
-
if self._cuts is None:
|
92
|
-
self._cuts = calc_graph_cuts(self.original_graph)
|
93
|
-
return self._cuts
|
94
|
-
|
95
81
|
def get_search_space(self) -> Dict[int, List[int]]:
|
96
82
|
"""
|
97
83
|
The search space is a mapping from a node's index to a list of integers (possible bitwidths candidates indeces
|
@@ -122,40 +108,6 @@ class MixedPrecisionSearchManager:
|
|
122
108
|
|
123
109
|
return self.sensitivity_evaluator.compute_metric
|
124
110
|
|
125
|
-
def _calc_ru_fn(self, ru_target, ru_fn, mp_cfg) -> np.ndarray:
|
126
|
-
"""
|
127
|
-
Computes a resource utilization for a certain mixed precision configuration.
|
128
|
-
The method computes a resource utilization vector for specific target resource utilization.
|
129
|
-
|
130
|
-
Returns: resource utilization value.
|
131
|
-
|
132
|
-
"""
|
133
|
-
# ru_fn is a pair of resource utilization computation method and
|
134
|
-
# resource utilization aggregation method (in this method we only need the first one)
|
135
|
-
if ru_target is RUTarget.ACTIVATION:
|
136
|
-
return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl, self.cuts)
|
137
|
-
else:
|
138
|
-
return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl)
|
139
|
-
|
140
|
-
def compute_min_ru(self) -> Dict[RUTarget, np.ndarray]:
|
141
|
-
"""
|
142
|
-
Computes a resource utilization vector with the values matching to the minimal mp configuration
|
143
|
-
(i.e., each node is configured with the quantization candidate that would give the minimal size of the
|
144
|
-
node's resource utilization).
|
145
|
-
The method computes the minimal resource utilization vector for each target resource utilization.
|
146
|
-
|
147
|
-
Returns: A dictionary mapping each target resource utilization to its respective minimal
|
148
|
-
resource utilization values.
|
149
|
-
|
150
|
-
"""
|
151
|
-
min_ru = {}
|
152
|
-
for ru_target, ru_fn in self.compute_ru_functions.items():
|
153
|
-
# ru_fns is a pair of resource utilization computation method and
|
154
|
-
# resource utilization aggregation method (in this method we only need the first one)
|
155
|
-
min_ru[ru_target] = self._calc_ru_fn(ru_target, ru_fn, self.min_ru_config)
|
156
|
-
|
157
|
-
return min_ru
|
158
|
-
|
159
111
|
def compute_resource_utilization_matrix(self, target: RUTarget) -> np.ndarray:
|
160
112
|
"""
|
161
113
|
Computes and builds a resource utilization matrix, to be used for the mixed-precision search problem formalization.
|
@@ -184,7 +136,8 @@ class MixedPrecisionSearchManager:
|
|
184
136
|
# always be 0 for all entries in the results vector.
|
185
137
|
candidate_rus = np.zeros(shape=self.min_ru[target].shape)
|
186
138
|
else:
|
187
|
-
candidate_rus = self.
|
139
|
+
candidate_rus = self.compute_node_ru_for_candidate(c, candidate_idx, target) - self.min_ru[target]
|
140
|
+
|
188
141
|
ru_matrix.append(np.asarray(candidate_rus))
|
189
142
|
|
190
143
|
# We need to transpose the calculated ru matrix to allow later multiplication with
|
@@ -195,40 +148,6 @@ class MixedPrecisionSearchManager:
|
|
195
148
|
np_ru_matrix = np.array(ru_matrix)
|
196
149
|
return np.moveaxis(np_ru_matrix, source=0, destination=len(np_ru_matrix.shape) - 1)
|
197
150
|
|
198
|
-
def compute_candidate_relative_ru(self,
|
199
|
-
conf_node_idx: int,
|
200
|
-
candidate_idx: int,
|
201
|
-
target: RUTarget) -> np.ndarray:
|
202
|
-
"""
|
203
|
-
Computes a resource utilization vector for a given candidates of a given configurable node,
|
204
|
-
i.e., the matching resource utilization vector which is obtained by computing the given target's
|
205
|
-
resource utilization function on a minimal configuration in which the given
|
206
|
-
layer's candidates is changed to the new given one.
|
207
|
-
The result is normalized by subtracting the target's minimal resource utilization vector.
|
208
|
-
|
209
|
-
Args:
|
210
|
-
conf_node_idx: The index of a node in a sorted configurable nodes list.
|
211
|
-
candidate_idx: The index of a node's quantization configuration candidate.
|
212
|
-
target: The target for which the resource utilization is calculated (a RUTarget value).
|
213
|
-
|
214
|
-
Returns: Normalized node's resource utilization vector
|
215
|
-
|
216
|
-
"""
|
217
|
-
return self.compute_node_ru_for_candidate(conf_node_idx, candidate_idx, target) - \
|
218
|
-
self.get_min_target_resource_utilization(target)
|
219
|
-
|
220
|
-
def get_min_target_resource_utilization(self, target: RUTarget) -> np.ndarray:
|
221
|
-
"""
|
222
|
-
Returns the minimal resource utilization vector (pre-calculated on initialization) of a specific target.
|
223
|
-
|
224
|
-
Args:
|
225
|
-
target: The target for which the resource utilization is calculated (a RUTarget value).
|
226
|
-
|
227
|
-
Returns: Minimal resource utilization vector.
|
228
|
-
|
229
|
-
"""
|
230
|
-
return self.min_ru[target]
|
231
|
-
|
232
151
|
def compute_node_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int, target: RUTarget) -> np.ndarray:
|
233
152
|
"""
|
234
153
|
Computes a resource utilization vector after replacing the given node's configuration candidate in the minimal
|
@@ -243,7 +162,8 @@ class MixedPrecisionSearchManager:
|
|
243
162
|
|
244
163
|
"""
|
245
164
|
cfg = self.replace_config_in_index(self.min_ru_config, conf_node_idx, candidate_idx)
|
246
|
-
|
165
|
+
# TODO compute for all targets at once. Currently the way up to add_set_of_ru_constraints is per target.
|
166
|
+
return self.ru_helper.compute_utilization({target}, cfg)[target]
|
247
167
|
|
248
168
|
@staticmethod
|
249
169
|
def replace_config_in_index(mp_cfg: List[int], idx: int, value: int) -> List[int]:
|
@@ -270,21 +190,10 @@ class MixedPrecisionSearchManager:
|
|
270
190
|
|
271
191
|
Returns: A mapping between a RUTarget and its non-configurable nodes' resource utilization vector.
|
272
192
|
"""
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
# compute for non-configurable nodes
|
278
|
-
if target == RUTarget.BOPS:
|
279
|
-
ru_vector = None
|
280
|
-
elif target == RUTarget.ACTIVATION:
|
281
|
-
ru_vector = ru_fns.metric_fn([], self.graph, self.fw_info, self.fw_impl, self.cuts)
|
282
|
-
else:
|
283
|
-
ru_vector = ru_fns.metric_fn([], self.graph, self.fw_info, self.fw_impl)
|
284
|
-
|
285
|
-
non_conf_ru_dict[target] = ru_vector
|
286
|
-
|
287
|
-
return non_conf_ru_dict
|
193
|
+
ru_metrics = self.ru_metrics - {RUTarget.BOPS}
|
194
|
+
ru = self.ru_helper.compute_utilization(ru_targets=ru_metrics, mp_cfg=None)
|
195
|
+
ru[RUTarget.BOPS] = None
|
196
|
+
return ru
|
288
197
|
|
289
198
|
def compute_resource_utilization_for_config(self, config: List[int]) -> ResourceUtilization:
|
290
199
|
"""
|
@@ -297,29 +206,11 @@ class MixedPrecisionSearchManager:
|
|
297
206
|
with the given config.
|
298
207
|
|
299
208
|
"""
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
if ru_target == RUTarget.BOPS:
|
306
|
-
configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.original_graph, self.fw_info, self.fw_impl, False)
|
307
|
-
elif ru_target == RUTarget.ACTIVATION:
|
308
|
-
configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.graph, self.fw_info, self.fw_impl, self.cuts)
|
309
|
-
else:
|
310
|
-
configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.original_graph, self.fw_info, self.fw_impl)
|
311
|
-
non_configurable_nodes_ru_vector = self.non_conf_ru_dict.get(ru_target)
|
312
|
-
if non_configurable_nodes_ru_vector is None or len(non_configurable_nodes_ru_vector) == 0:
|
313
|
-
ru_ru = self.compute_ru_functions[ru_target].aggregate_fn(configurable_nodes_ru_vector, False)
|
314
|
-
else:
|
315
|
-
ru_ru = self.compute_ru_functions[ru_target].aggregate_fn(
|
316
|
-
np.concatenate([configurable_nodes_ru_vector, non_configurable_nodes_ru_vector]), False)
|
317
|
-
|
318
|
-
ru_dict[ru_target] = ru_ru[0]
|
319
|
-
|
320
|
-
config_ru = ResourceUtilization()
|
321
|
-
config_ru.set_resource_utilization_by_target(ru_dict)
|
322
|
-
return config_ru
|
209
|
+
act_qcs, w_qcs = self.ru_helper.get_configurable_qcs(config)
|
210
|
+
ru = self.ru_helper.ru_calculator.compute_resource_utilization(
|
211
|
+
target_criterion=TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs,
|
212
|
+
w_qcs=w_qcs)
|
213
|
+
return ru
|
323
214
|
|
324
215
|
def finalize_distance_metric(self, layer_to_metrics_mapping: Dict[int, Dict[int, float]]):
|
325
216
|
"""
|