mct-nightly 2.2.0.20250113.527__py3-none-any.whl → 2.2.0.20250114.84821__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/RECORD +103 -105
- model_compression_toolkit/__init__.py +2 -2
- model_compression_toolkit/core/common/framework_info.py +1 -3
- model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
- model_compression_toolkit/core/common/graph/base_graph.py +20 -21
- model_compression_toolkit/core/common/graph/base_node.py +44 -17
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +26 -135
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +667 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +164 -470
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +30 -7
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/pruner.py +5 -3
- model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/graph_prep_runner.py +12 -11
- model_compression_toolkit/core/keras/data_util.py +24 -5
- model_compression_toolkit/core/keras/default_framework_info.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
- model_compression_toolkit/core/runner.py +33 -60
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/metadata.py +11 -10
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
- model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
- model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
- model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
- model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
- model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
- model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
- model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
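Taken together, the renames in this list flatten `target_platform/targetplatform2framework` into `targetplatform2framework`, rename the `tp_model.py` modules to `tpc.py`, and rename `TargetPlatformCapabilities` to `FrameworkQuantizationCapabilities` (abbreviated `fqc`/FQC throughout the hunks below). A hedged before/after import sketch for downstream code, assembled from the old and new paths that appear in this diff (not an official migration guide):

```python
# Old location (2.2.0.20250113.527); this module is removed in the new build:
# from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework \
#     import TargetPlatformCapabilities

# New location (2.2.0.20250114.84821), matching the import added in pruner.py below:
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import \
    FrameworkQuantizationCapabilities
```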
@@ -16,7 +16,7 @@
 import numpy as np
 from pulp import *
 from tqdm import tqdm
-from typing import Dict,
+from typing import Dict, Tuple

 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
@@ -218,13 +218,11 @@ def _add_set_of_ru_constraints(search_manager: MixedPrecisionSearchManager,
         np.sum(indicated_ru_matrix[i], axis=0) +  # sum of metric values over all configurations in a row
         search_manager.min_ru[target][i] for i in range(indicated_ru_matrix.shape[0])])

-
-
-
-        aggr_ru = search_manager.compute_ru_functions[target].aggregate_fn(ru_sum_vector)
-    else:
-        aggr_ru = search_manager.compute_ru_functions[target].aggregate_fn(np.concatenate([ru_sum_vector, non_conf_ru_vector]))
+    ru_vec = ru_sum_vector
+    if non_conf_ru_vector is not None and non_conf_ru_vector.size:
+        ru_vec = np.concatenate([ru_vec, non_conf_ru_vector])

+    aggr_ru = _aggregate_for_lp(ru_vec, target)
    for v in aggr_ru:
        if isinstance(v, float):
            if v > target_resource_utilization_value:
@@ -235,6 +233,31 @@ def _add_set_of_ru_constraints(search_manager: MixedPrecisionSearchManager,
             lp_problem += v <= target_resource_utilization_value


+def _aggregate_for_lp(ru_vec, target: RUTarget) -> list:
+    """
+    Aggregate resource utilization values for the LP.
+
+    Args:
+        ru_vec: a vector of resource utilization values.
+        target: resource utilization target.
+
+    Returns:
+        Aggregated resource utilization.
+    """
+    if target == RUTarget.TOTAL:
+        w = lpSum(v[0] for v in ru_vec)
+        return [w + v[1] for v in ru_vec]
+
+    if target in [RUTarget.WEIGHTS, RUTarget.BOPS]:
+        return [lpSum(ru_vec)]
+
+    if target == RUTarget.ACTIVATION:
+        # for max aggregation, each value constitutes a separate constraint
+        return list(ru_vec)
+
+    raise ValueError(f'Unexpected target {target}.')
+
+
 def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
                                     target_resource_utilization: ResourceUtilization,
                                     eps: float = EPS) -> Dict[int, Dict[int, float]]:
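The new `_aggregate_for_lp` helper replaces the old per-target `aggregate_fn` lookup that was removed together with `ru_functions_mapping.py` and `ru_aggregation_methods.py` (see the file list). A plain-number reading of its `RUTarget.TOTAL` branch, under the assumption (taken from the `v[0]`/`v[1]` indexing above) that each entry of `ru_vec` is a (weights, activation) pair:

```python
# Float analogue of the TOTAL branch; pulp's lpSum is replaced by sum().
ru_vec = [(10.0, 4.0), (6.0, 8.0)]               # (weights, activation) per graph cut
total_w = sum(v[0] for v in ru_vec)              # 16.0, mirrors w = lpSum(v[0] for v in ru_vec)
constraints = [total_w + v[1] for v in ru_vec]   # one LP constraint per cut
assert constraints == [20.0, 24.0]
```

For WEIGHTS and BOPS everything collapses into a single summed constraint, while for ACTIVATION each value stays separate, which is how a max-style bound is expressed as a set of linear constraints.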
@@ -113,11 +113,9 @@ class SensitivityEvaluation:
         # in the new built MP model.
         self.baseline_model, self.model_mp, self.conf_node2layers = self._build_models()

-        # Build images batches for inference comparison
-
-
-        # Casting images tensors to the framework tensor type.
-        self.images_batches = [self.fw_impl.to_tensor(img) for img in self.images_batches]
+        # Build images batches for inference comparison and cast to framework type
+        images_batches = self._get_images_batches(quant_config.num_of_images)
+        self.images_batches = [self.fw_impl.to_tensor(img) for img in images_batches]

         # Initiating baseline_tensors_list since it is not initiated in SensitivityEvaluationManager init.
         self.baseline_tensors_list = self._init_baseline_tensors_list()
@@ -80,8 +80,8 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
             updated_ru.append(node_updated_ru)

         # filter out new configs that don't hold the resource utilization restrictions
-        node_filtered_ru = [(node_idx, ru) for node_idx, ru in zip(valid_candidates, updated_ru)
-                            target_resource_utilization.
+        node_filtered_ru = [(node_idx, ru) for node_idx, ru in zip(valid_candidates, updated_ru)
+                            if target_resource_utilization.is_satisfied_by(ru)]

         if len(node_filtered_ru) > 0:
             sorted_by_ru = sorted(node_filtered_ru, key=lambda node_ru: (node_ru[1].total_memory,
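The filtering above now goes through `ResourceUtilization.is_satisfied_by`; that method lives in resource_utilization.py (+36/-62 in the file list), whose hunks are not shown here. A hedged sketch of what such a check plausibly does, reusing the `weights_memory` and `total_memory` field names visible elsewhere in this diff:

```python
from dataclasses import dataclass

@dataclass
class RUSketch:  # illustrative stand-in for ResourceUtilization, not the real class
    weights_memory: float
    activation_memory: float
    total_memory: float
    bops: float

    def is_satisfied_by(self, ru: 'RUSketch') -> bool:
        # Every measured metric must stay within its target counterpart.
        return (ru.weights_memory <= self.weights_memory and
                ru.activation_memory <= self.activation_memory and
                ru.total_memory <= self.total_memory and
                ru.bops <= self.bops)

target = RUSketch(100, 50, 150, 1e9)
assert target.is_satisfied_by(RUSketch(80, 40, 120, 5e8))
```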
@@ -24,7 +24,8 @@ from model_compression_toolkit.core.common.pruning.memory_calculator import MemoryCalculator
 from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation
 from model_compression_toolkit.core.common.pruning.mask.per_simd_group_mask import PerSIMDGroupMask
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
+    FrameworkQuantizationCapabilities


 class GreedyMaskCalculator:
@@ -42,7 +43,7 @@ class GreedyMaskCalculator:
                  target_resource_utilization: ResourceUtilization,
                  graph: Graph,
                  fw_impl: PruningFrameworkImplementation,
-
+                 fqc: FrameworkQuantizationCapabilities,
                  simd_groups_indices: Dict[BaseNode, List[List[int]]]):
         """
         Args:
@@ -52,7 +53,7 @@ class GreedyMaskCalculator:
             target_resource_utilization (ResourceUtilization): The target resource utilization to achieve.
             graph (Graph): The computational graph of the model.
             fw_impl (PruningFrameworkImplementation): Framework-specific implementation details.
-
+            fqc (FrameworkQuantizationCapabilities): Platform-specific constraints and capabilities.
             simd_groups_indices (Dict[BaseNode, List[List[int]]]): Indices of SIMD groups in each node.
         """
         self.prunable_nodes = prunable_nodes
@@ -60,7 +61,7 @@ class GreedyMaskCalculator:
         self.target_resource_utilization = target_resource_utilization
         self.graph = graph
         self.fw_impl = fw_impl
-        self.
+        self.fqc = fqc

         self.simd_groups_indices = simd_groups_indices
         self.simd_groups_scores = simd_groups_scores
@@ -90,7 +91,7 @@ class GreedyMaskCalculator:
         """
         # Iteratively unprune the graph while monitoring the memory footprint.
         current_memory = self.memory_calculator.get_pruned_graph_memory(masks=self.oc_pruning_mask.get_mask(),
-                                                                        include_padded_channels=self.
+                                                                        include_padded_channels=self.fqc.is_simd_padding)
         if current_memory > self.target_resource_utilization.weights_memory:
             Logger.critical(f"Insufficient memory for the target resource utilization: current memory {current_memory}, "
                             f"target memory {self.target_resource_utilization.weights_memory}.")
@@ -105,7 +106,7 @@ class GreedyMaskCalculator:
                                                        group_index=group_to_remain_idx,
                                                        mask_indicator=MaskIndicator.REMAINED)
             current_memory = self.memory_calculator.get_pruned_graph_memory(masks=self.oc_pruning_mask.get_mask(),
-                                                                            include_padded_channels=self.
+                                                                            include_padded_channels=self.fqc.is_simd_padding)

             # If the target memory is exceeded, revert the last addition.
             if current_memory > self.target_resource_utilization.weights_memory:
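A self-contained toy of the unprune-and-monitor loop in the two hunks above, under the revert-on-overshoot reading of the "revert the last addition" comment (group costs are invented; the real calculator scores SIMD groups and measures graph memory):

```python
# Greedily re-enable pruned groups until the weights-memory target would be exceeded.
target_weights_memory = 100
group_costs = [40, 35, 30, 20]    # memory added by keeping each group, best-scored first
current_memory, kept = 0, []
for cost in group_costs:
    current_memory += cost
    if current_memory > target_weights_memory:
        current_memory -= cost    # revert the last addition, as in the hunk above
        break
    kept.append(cost)
print(kept, current_memory)       # [40, 35] 75
```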
@@ -23,7 +23,6 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
 from model_compression_toolkit.core.common.pruning.memory_calculator import MemoryCalculator
 from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities

 class MaskIndicator(Enum):
     """
@@ -23,7 +23,6 @@ from model_compression_toolkit.core.common.pruning.mask.per_channel_mask import
 from model_compression_toolkit.core.common.pruning.memory_calculator import MemoryCalculator
 from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities

 class PerSIMDGroupMask:
     def __init__(self,
@@ -29,7 +29,9 @@ from model_compression_toolkit.core.common.pruning.pruning_framework_implementat
 from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo, \
     unroll_simd_scores_to_per_channel_scores
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import \
+    FrameworkQuantizationCapabilities
+

 class Pruner:
     """
@@ -43,7 +45,7 @@ class Pruner:
                  target_resource_utilization: ResourceUtilization,
                  representative_data_gen: Callable,
                  pruning_config: PruningConfig,
-                 target_platform_capabilities:
+                 target_platform_capabilities: FrameworkQuantizationCapabilities):
         """
         Args:
             float_graph (Graph): The floating-point representation of the model's computation graph.
@@ -52,7 +54,7 @@ class Pruner:
             target_resource_utilization (ResourceUtilization): The target resource utilization to be achieved after pruning.
             representative_data_gen (Callable): Generator function for representative dataset used in pruning analysis.
             pruning_config (PruningConfig): Configuration object specifying how pruning should be performed.
-            target_platform_capabilities (
+            target_platform_capabilities (FrameworkQuantizationCapabilities): Object encapsulating the capabilities of the target hardware platform.
         """
         self.float_graph = float_graph
         self.fw_info = fw_info
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from dataclasses import dataclass, field
 from typing import List, Union, Dict

 from model_compression_toolkit.core.common import Graph
@@ -19,6 +20,7 @@ from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher
 from model_compression_toolkit.logger import Logger


+@dataclass
 class ManualBitWidthSelection:
     """
     Class to encapsulate the manual bit width selection configuration for a specific filter.
@@ -27,13 +29,11 @@ class ManualBitWidthSelection:
         filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
         bit_width (int): The bit width to be applied to the selected nodes.
     """
-
-
-                 bit_width: int):
-        self.filter = filter
-        self.bit_width = bit_width
+    filter: BaseNodeMatcher
+    bit_width: int


+@dataclass
 class BitWidthConfig:
     """
     Class to manage manual bit-width configurations.
@@ -41,13 +41,7 @@ class BitWidthConfig:
     Attributes:
         manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects defining manual bit-width configurations.
     """
-
-                 manual_activation_bit_width_selection_list: List[ManualBitWidthSelection] = None):
-        self.manual_activation_bit_width_selection_list = [] if manual_activation_bit_width_selection_list is None else manual_activation_bit_width_selection_list
-
-    def __repr__(self):
-        # Used for debugging, thus no cover.
-        return str(self.__dict__)  # pragma: no cover
+    manual_activation_bit_width_selection_list: List[ManualBitWidthSelection] = field(default_factory=list)

     def set_manual_activation_bit_width(self,
                                         filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
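The dataclass rewrite above drops the hand-written `__init__`/`__repr__`. The detail worth noting is `field(default_factory=list)`: `dataclasses` rejects a bare mutable default like `= []`, and the factory gives each instance its own list, preserving the old `[] if ... is None else ...` behavior. A minimal sketch:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ConfigSketch:  # illustrative stand-in for BitWidthConfig
    selections: List[str] = field(default_factory=list)

a, b = ConfigSketch(), ConfigSketch()
a.selections.append('conv1_act_4bit')
assert b.selections == []  # each instance gets its own fresh list
```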
@@ -15,8 +15,7 @@
 import copy
 from typing import List

-from
-
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.constants import FLOAT_BITWIDTH
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
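This is the first of many hunks below with the same one-line swap: `QuantizationMethod` now comes from the `mct_quantizers` package instead of the removed `target_platform` module. For reference, the new import in isolation:

```python
# The enum moved here; its UNIFORM member is checked by
# apply_shift_negative_correction further down in this diff.
from mct_quantizers import QuantizationMethod

print(QuantizationMethod.UNIFORM)
```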
@@ -401,9 +401,9 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
         # therefore, we need to look for the attribute in the op_cfg that is contained in the node attribute's name.
         attrs_included_in_name = {k: v for k, v in op_cfg.attr_weights_configs_mapping.items() if k in attr}
         if len(attrs_included_in_name) > 1:  # pragma: no cover
-            Logger.critical(f"Found multiple attribute in
+            Logger.critical(f"Found multiple attribute in FQC OpConfig that are contained "
                             f"in the attribute name '{attr}'."
-                            f"Please fix the
+                            f"Please fix the FQC attribute names mapping such that each operator's attribute would "
                             f"have a unique matching name.")
         if len(attrs_included_in_name) == 0:
             attr_cfg = op_cfg.default_weight_attr_config
@@ -25,7 +25,7 @@ from model_compression_toolkit.constants import MIN_THRESHOLD
 class CustomOpsetLayers(NamedTuple):
     """
     This struct defines a set of operators from a specific framework, which will be used to configure a custom operator
-    set in the
+    set in the FQC.

     Args:
         operators: a list of framework operators to map to a certain custom opset name.
@@ -16,8 +16,8 @@
 from collections.abc import Callable
 from functools import partial

+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
 from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \
     symmetric_quantizer, uniform_quantizer
@@ -16,8 +16,8 @@
 from collections.abc import Callable
 from functools import partial

+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import \
     lut_kmeans_tensor, lut_kmeans_histogram
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import \
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py CHANGED
@@ -16,11 +16,11 @@ from copy import deepcopy
 from typing import Tuple, Callable, List, Iterable, Optional
 import numpy as np
 import model_compression_toolkit.core.common.quantization.quantization_config as qc
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, HessianScoresGranularity, \
     HessianInfoService
 from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_mae, compute_lp_norm
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.constants import FLOAT_32, NUM_QPARAM_HESSIAN_SAMPLES
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor, \
     reshape_tensor_for_per_channel_search
@@ -16,6 +16,7 @@ import numpy as np
 from typing import Union, Tuple, Dict

 import model_compression_toolkit.core.common.quantization.quantization_config as qc
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES, SIGNED
 from model_compression_toolkit.core.common.hessian import HessianInfoService
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \
@@ -23,7 +24,6 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import max_power_of_two, get_tensor_max
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
     get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor

@@ -15,7 +15,7 @@
 import numpy as np
 from typing import Dict, Union

-from
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
 from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
 from model_compression_toolkit.core.common.quantization import quantization_params_generation
@@ -25,7 +25,7 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
     qparams_symmetric_selection_histogram_search, kl_qparams_symmetric_selection_histogram_search
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import \
     get_tensor_max
-from
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor

@@ -24,7 +24,7 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
     get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import get_tensor_max, \
     get_tensor_min
-from
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor

@@ -33,9 +33,10 @@ from model_compression_toolkit.core.common.quantization.quantization_params_fn_s
 from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
     get_weights_quantization_fn
 from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
     QuantizationConfigOptions
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
+    FrameworkQuantizationCapabilities


 def set_quantization_configuration_to_graph(graph: Graph,
@@ -71,14 +72,14 @@ def set_quantization_configuration_to_graph(graph: Graph,
                                                  graph=graph,
                                                  quant_config=quant_config,
                                                  fw_info=graph.fw_info,
-
+                                                 fqc=graph.fqc,
                                                  mixed_precision_enable=mixed_precision_enable,
                                                  manual_bit_width_override=nodes_to_manipulate_bit_widths.get(n))
     return graph


 def filter_node_qco_by_graph(node: BaseNode,
-
+                             fqc: FrameworkQuantizationCapabilities,
                              graph: Graph,
                              node_qc_options: QuantizationConfigOptions
                              ) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
@@ -90,7 +91,7 @@ def filter_node_qco_by_graph(node: BaseNode,

     Args:
         node: Node for filtering.
-
+        fqc: FQC to extract the QuantizationConfigOptions for the next nodes.
         graph: Graph object.
         node_qc_options: Node's QuantizationConfigOptions.

@@ -108,7 +109,7 @@ def filter_node_qco_by_graph(node: BaseNode,
     next_nodes = []
     while len(_next_nodes):
         n = _next_nodes.pop(0)
-        qco = n.get_qco(
+        qco = n.get_qco(fqc)
         qp = [qc.quantization_preserving for qc in qco.quantization_configurations]
         if not all(qp) and any(qp):
             Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizaionConfigOptions in {n}.')
@@ -117,7 +118,7 @@ def filter_node_qco_by_graph(node: BaseNode,
             next_nodes.append(n)

     if len(next_nodes):
-        next_nodes_qc_options = [_node.get_qco(
+        next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
         next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
                                                    for qc_opts in next_nodes_qc_options
                                                    for op_cfg in qc_opts.quantization_configurations])
@@ -126,7 +127,7 @@ def filter_node_qco_by_graph(node: BaseNode,
         _node_qc_options = [_option for _option in _node_qc_options
                             if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
         if len(_node_qc_options) == 0:
-            Logger.critical(f"Graph doesn't match
+            Logger.critical(f"Graph doesn't match FQC bit configurations: {node} -> {next_nodes}.")

     # Verify base config match
     if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
@@ -136,9 +137,9 @@ def filter_node_qco_by_graph(node: BaseNode,
         if len(_node_qc_options) > 0:
             output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
             _base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
-            Logger.warning(f"Node {node} base quantization config changed to match Graph and
+            Logger.warning(f"Node {node} base quantization config changed to match Graph and FQC configuration.\nCause: {node} -> {next_nodes}.")
         else:
-            Logger.critical(f"Graph doesn't match
+            Logger.critical(f"Graph doesn't match FQC bit configurations: {node} -> {next_nodes}.")  # pragma: no cover

     return _base_config, _node_qc_options

@@ -147,7 +148,7 @@ def set_quantization_configs_to_node(node: BaseNode,
                                      graph: Graph,
                                      quant_config: QuantizationConfig,
                                      fw_info: FrameworkInfo,
-
+                                     fqc: FrameworkQuantizationCapabilities,
                                      mixed_precision_enable: bool = False,
                                      manual_bit_width_override: Optional[int] = None):
     """
@@ -158,12 +159,12 @@ def set_quantization_configs_to_node(node: BaseNode,
         graph (Graph): Model's internal representation graph.
         quant_config (QuantizationConfig): Quantization configuration to generate the node's configurations from.
         fw_info (FrameworkInfo): Information needed for quantization about the specific framework.
-
+        fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to get default OpQuantizationConfig.
         mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
         manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
     """
-    node_qc_options = node.get_qco(
-    base_config, node_qc_options_list = filter_node_qco_by_graph(node,
+    node_qc_options = node.get_qco(fqc)
+    base_config, node_qc_options_list = filter_node_qco_by_graph(node, fqc, graph, node_qc_options)

     # If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override,
     # and update base_config accordingly.
@@ -257,7 +258,7 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig,
     attrs_with_enabled_quantization = [attr for attr, cfg in op_cfg.attr_weights_configs_mapping.items()
                                        if cfg.enable_weights_quantization]
     if len(attrs_with_enabled_quantization) > 1:
-        Logger.warning(f"Multiple weights attributes quantization is enabled via the provided
+        Logger.warning(f"Multiple weights attributes quantization is enabled via the provided FQC."
                        f"Quantizing any attribute other than the kernel is experimental "
                        f"and may be subject to unstable behavior."
                        f"Attributes with enabled weights quantization: {attrs_with_enabled_quantization}.")
@@ -26,7 +26,7 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
-from
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig


@@ -22,7 +22,7 @@ from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.graph_matchers import EdgeMatcher, NodeOperationMatcher
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
-from
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX
 from model_compression_toolkit.logger import Logger

@@ -22,7 +22,7 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
 from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
-from
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
     set_quantization_configs_to_node
@@ -359,7 +359,7 @@ def shift_negative_function(graph: Graph,
                                     node=pad_node,
                                     graph=graph,
                                     quant_config=core_config.quantization_config,
-
+                                    fqc=graph.fqc,
                                     mixed_precision_enable=core_config.is_mixed_precision_enabled)

     for candidate_qc in pad_node.candidates_quantization_cfg:
@@ -376,7 +376,7 @@ def shift_negative_function(graph: Graph,
                                     node=add_node,
                                     graph=graph,
                                     quant_config=core_config.quantization_config,
-
+                                    fqc=graph.fqc,
                                     mixed_precision_enable=core_config.is_mixed_precision_enabled)

     original_non_linear_activation_nbits = non_linear_node_cfg_candidate.activation_n_bits
@@ -392,7 +392,7 @@ def shift_negative_function(graph: Graph,
         bypass_candidate_qc.activation_quantization_cfg.activation_quantization_params[SIGNED] = False
         graph.shift_stats_collector(bypass_node, np.array(shift_value))

-    add_node_qco = add_node.get_qco(graph.
+    add_node_qco = add_node.get_qco(graph.fqc).quantization_configurations
     for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg):
         for attr in add_node.get_node_weights_attributes():
             candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
@@ -533,7 +533,7 @@ def apply_shift_negative_correction(graph: Graph,
     nodes = list(graph.nodes())
     for n in nodes:
         # Skip substitution if QuantizationMethod is uniform.
-        node_qco = n.get_qco(graph.
+        node_qco = n.get_qco(graph.fqc)
         if any([op_qc.activation_quantization_method is QuantizationMethod.UNIFORM
                 for op_qc in node_qco.quantization_configurations]):
             continue
@@ -29,8 +29,9 @@ from model_compression_toolkit.core.common.quantization.set_node_quantization_co
 from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
 from model_compression_toolkit.core.common.substitutions.linear_collapsing_substitution import \
     linear_collapsing_substitute
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
+    FrameworkQuantizationCapabilities


 def graph_preparation_runner(in_model: Any,
@@ -38,7 +39,7 @@ def graph_preparation_runner(in_model: Any,
                              quantization_config: QuantizationConfig,
                              fw_info: FrameworkInfo,
                              fw_impl: FrameworkImplementation,
-
+                             fqc: FrameworkQuantizationCapabilities,
                              bit_width_config: BitWidthConfig = None,
                              tb_w: TensorboardWriter = None,
                              mixed_precision_enable: bool = False,
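A hedged call-site sketch for the updated signature above, kept entirely in comments because all the argument names are placeholders (`representative_data_gen` is inferred from `read_model_to_graph` below, and the name of the removed parameter is not visible in this diff):

```python
# Hypothetical caller after the rename; not MCT sample code.
#
# graph = graph_preparation_runner(in_model=model,
#                                  representative_data_gen=data_gen,
#                                  quantization_config=qcfg,
#                                  fw_info=fw_info,
#                                  fw_impl=fw_impl,
#                                  fqc=fqc)  # replaces the removed TargetPlatformCapabilities argument
```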
@@ -58,7 +59,7 @@ def graph_preparation_runner(in_model: Any,
         fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices,
             groups of layers by how they should be quantized, etc.).
         fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
-
+        fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities object that models the inference target platform and
             the attached framework operator's information.
         bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
         tb_w (TensorboardWriter): TensorboardWriter object for logging.
@@ -71,7 +72,7 @@ def graph_preparation_runner(in_model: Any,

     graph = read_model_to_graph(in_model,
                                 representative_data_gen,
-
+                                fqc,
                                 fw_info,
                                 fw_impl)

@@ -79,7 +80,7 @@ def graph_preparation_runner(in_model: Any,
         tb_w.add_graph(graph, 'initial_graph')

     transformed_graph = get_finalized_graph(graph,
-
+                                            fqc,
                                             quantization_config,
                                             bit_width_config,
                                             fw_info,
@@ -92,7 +93,7 @@ def graph_preparation_runner(in_model: Any,


 def get_finalized_graph(initial_graph: Graph,
-
+                        fqc: FrameworkQuantizationCapabilities,
                         quant_config: QuantizationConfig = DEFAULTCONFIG,
                         bit_width_config: BitWidthConfig = None,
                         fw_info: FrameworkInfo = None,
@@ -106,7 +107,7 @@ def get_finalized_graph(initial_graph: Graph,

     Args:
         initial_graph (Graph): Graph to apply the changes to.
-
+        fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities object that describes the desired inference target platform (includes fusing patterns MCT should handle).
         quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
             quantized.
         bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
@@ -160,7 +161,7 @@ def get_finalized_graph(initial_graph: Graph,
     ######################################
     # Layer fusing
     ######################################
-    transformed_graph = fusion(transformed_graph,
+    transformed_graph = fusion(transformed_graph, fqc)

     ######################################
     # Channel equalization
@@ -185,7 +186,7 @@ def get_finalized_graph(initial_graph: Graph,

 def read_model_to_graph(in_model: Any,
                         representative_data_gen: Callable,
-
+                        fqc: FrameworkQuantizationCapabilities,
                         fw_info: FrameworkInfo = None,
                         fw_impl: FrameworkImplementation = None) -> Graph:

@@ -195,7 +196,7 @@ def read_model_to_graph(in_model: Any,
     Args:
         in_model: Model to optimize and prepare for quantization.
         representative_data_gen: Dataset used for calibration.
-
+        fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
             the attached framework operator's information.
         fw_info: Information needed for quantization about the specific framework (e.g.,
             kernel channels indices, groups of layers by how they should be quantized, etc.)
@@ -207,5 +208,5 @@ def read_model_to_graph(in_model: Any,
     graph = fw_impl.model_reader(in_model,
                                  representative_data_gen)
     graph.set_fw_info(fw_info)
-    graph.
+    graph.set_fqc(fqc)
     return graph