mct-nightly 2.2.0.20250113.134913__py3-none-any.whl → 2.2.0.20250114.84821__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/RECORD +102 -104
- model_compression_toolkit/__init__.py +2 -2
- model_compression_toolkit/core/common/framework_info.py +1 -3
- model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
- model_compression_toolkit/core/common/graph/base_graph.py +20 -21
- model_compression_toolkit/core/common/graph/base_node.py +44 -17
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +26 -135
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +667 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +164 -470
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +30 -7
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/pruner.py +5 -3
- model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/graph_prep_runner.py +12 -11
- model_compression_toolkit/core/keras/default_framework_info.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
- model_compression_toolkit/core/runner.py +33 -60
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/metadata.py +11 -10
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
- model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
- model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
- model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
- model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
- model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
- model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
- model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
@@ -17,7 +17,7 @@ from typing import Callable, Tuple
 
 from model_compression_toolkit import get_target_platform_capabilities
 from model_compression_toolkit.constants import TENSORFLOW
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
 from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.pruning.pruner import Pruner
@@ -26,17 +26,16 @@ from model_compression_toolkit.core.common.pruning.pruning_info import PruningIn
 from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
 from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
 
 if FOUND_TF:
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
+AttachTpcToKeras
 from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
 from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from tensorflow.keras.models import Model
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
-AttachTpcToKeras
 
 DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
@@ -44,7 +43,7 @@ if FOUND_TF:
 target_resource_utilization: ResourceUtilization,
 representative_data_gen: Callable,
 pruning_config: PruningConfig = PruningConfig(),
-target_platform_capabilities:
+target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
 """
 Perform structured pruning on a Keras model to meet a specified target resource utilization.
 This function prunes the provided model according to the target resource utilization by grouping and pruning
@@ -62,7 +61,7 @@ if FOUND_TF:
 target_resource_utilization (ResourceUtilization): The target Key Performance Indicators to be achieved through pruning.
 representative_data_gen (Callable): A function to generate representative data for pruning analysis.
 pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
-target_platform_capabilities (
+target_platform_capabilities (FrameworkQuantizationCapabilities): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
 
 Returns:
 Tuple[Model, PruningInfo]: A tuple containing the pruned Keras model and associated pruning information.
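For orientation, a minimal usage sketch of the Keras pruning facade whose signature changes above. The MobileNetV2 model, the random representative data, and the 50% weights-memory target are illustrative assumptions, not taken from the diff:

    import numpy as np
    import model_compression_toolkit as mct
    from tensorflow.keras.applications import MobileNetV2

    model = MobileNetV2()

    def representative_data_gen():
        # A couple of random batches stand in for real calibration data.
        for _ in range(2):
            yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]

    # Target roughly half of the dense model's weights memory (float32 = 4 bytes per parameter).
    dense_weights_memory = 4 * sum(int(np.prod(w.shape)) for w in model.trainable_weights)
    target_ru = mct.core.ResourceUtilization(weights_memory=0.5 * dense_weights_memory)

    pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(
        model=model,
        target_resource_utilization=target_ru,
        representative_data_gen=representative_data_gen)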
@@ -16,7 +16,7 @@
 from typing import Callable, Tuple
 from model_compression_toolkit import get_target_platform_capabilities
 from model_compression_toolkit.constants import PYTORCH
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
 from model_compression_toolkit.verify_packages import FOUND_TORCH
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.pruning.pruner import Pruner
@@ -25,7 +25,6 @@ from model_compression_toolkit.core.common.pruning.pruning_info import PruningIn
 from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
 from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
 
@@ -38,7 +37,7 @@ if FOUND_TORCH:
 PruningPytorchImplementation
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from torch.nn import Module
-from model_compression_toolkit.target_platform_capabilities.
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
 AttachTpcToPytorch
 
 # Set the default Target Platform Capabilities (TPC) for PyTorch.
@@ -48,7 +47,7 @@ if FOUND_TORCH:
 target_resource_utilization: ResourceUtilization,
 representative_data_gen: Callable,
 pruning_config: PruningConfig = PruningConfig(),
-target_platform_capabilities:
+target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
 Tuple[Module, PruningInfo]:
 """
 Perform structured pruning on a Pytorch model to meet a specified target resource utilization.
@@ -121,12 +120,12 @@ if FOUND_TORCH:
 
 # Attach TPC to framework
 attach2pytorch = AttachTpcToPytorch()
-
+framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)
 
 # Convert the original Pytorch model to an internal graph representation.
 float_graph = read_model_to_graph(model,
 representative_data_gen,
-
+framework_platform_capabilities,
 DEFAULT_PYTORCH_INFO,
 fw_impl)
 
@@ -143,7 +142,7 @@ if FOUND_TORCH:
 target_resource_utilization,
 representative_data_gen,
 pruning_config,
-
+framework_platform_capabilities)
 
 # Apply the pruning process.
 pruned_graph = pruner.prune_graph()
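The PyTorch hunks above show the facade now accepting the schema-level TargetPlatformCapabilities and attaching it to the framework internally. A rough sketch of that attach step in isolation, using only imports that appear in this diff (treat it as a peek at internals rather than a public API guarantee):

    from model_compression_toolkit import get_target_platform_capabilities
    from model_compression_toolkit.constants import PYTORCH
    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
    from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
        AttachTpcToPytorch

    # Framework-agnostic capabilities description (the schema object passed by the caller).
    tpc = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

    # Bind the description to concrete PyTorch layers; the result is the
    # FrameworkQuantizationCapabilities object the facades now pass around as `fqc`.
    fqc = AttachTpcToPytorch().attach(tpc)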
@@ -22,17 +22,18 @@ from model_compression_toolkit.core.common.quantization.quantize_graph_weights i
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
 from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
 MixedPrecisionQuantizationConfig
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 from model_compression_toolkit.metadata import create_model_metadata
 
 if FOUND_TF:
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
+AttachTpcToKeras
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
 from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
@@ -42,8 +43,6 @@ if FOUND_TF:
 
 from model_compression_toolkit import get_target_platform_capabilities
 from mct_quantizers.keras.metadata import add_metadata
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
-AttachTpcToKeras
 
 DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
@@ -52,7 +51,7 @@ if FOUND_TF:
 representative_data_gen: Callable,
 target_resource_utilization: ResourceUtilization = None,
 core_config: CoreConfig = CoreConfig(),
-target_platform_capabilities:
+target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
 """
 Quantize a trained Keras model using post-training quantization. The model is quantized using a
 symmetric constraint quantization thresholds (power of two).
@@ -139,7 +138,7 @@ if FOUND_TF:
 fw_impl = KerasImplementation()
 
 attach2keras = AttachTpcToKeras()
-
+framework_platform_capabilities = attach2keras.attach(
 target_platform_capabilities,
 custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
 
@@ -149,7 +148,7 @@ if FOUND_TF:
 core_config=core_config,
 fw_info=fw_info,
 fw_impl=fw_impl,
-
+fqc=framework_platform_capabilities,
 target_resource_utilization=target_resource_utilization,
 tb_w=tb_w)
 
@@ -177,9 +176,9 @@ if FOUND_TF:
 fw_info)
 
 exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction)
-if
+if framework_platform_capabilities.tpc.add_metadata:
 exportable_model = add_metadata(exportable_model,
-create_model_metadata(
+create_model_metadata(fqc=framework_platform_capabilities,
 scheduling_info=scheduling_info))
 return exportable_model, user_info
 
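A minimal call against the Keras PTQ facade changed above. The facade name keras_post_training_quantization is MCT's public entry point (the name itself is not visible in these hunks), and the model plus random representative data are placeholders:

    import numpy as np
    import model_compression_toolkit as mct
    from tensorflow.keras.applications import MobileNetV2

    def representative_data_gen():
        for _ in range(2):
            yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]

    # Uses the default Keras TPC (DEFAULT_KERAS_TPC) resolved inside the facade.
    quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
        MobileNetV2(),
        representative_data_gen)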
@@ -19,9 +19,8 @@ from typing import Callable
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import PYTORCH
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
 from model_compression_toolkit.verify_packages import FOUND_TORCH
-from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -40,7 +39,7 @@ if FOUND_TORCH:
 from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
 from model_compression_toolkit import get_target_platform_capabilities
 from mct_quantizers.pytorch.metadata import add_metadata
-from model_compression_toolkit.target_platform_capabilities.
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
 AttachTpcToPytorch
 
 DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
@@ -49,11 +48,11 @@ if FOUND_TORCH:
 representative_data_gen: Callable,
 target_resource_utilization: ResourceUtilization = None,
 core_config: CoreConfig = CoreConfig(),
-target_platform_capabilities:
+target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
 """
 Quantize a trained Pytorch module using post-training quantization.
 By default, the module is quantized using a symmetric constraint quantization thresholds
-(power of two) as defined in the default
+(power of two) as defined in the default FrameworkQuantizationCapabilities.
 The module is first optimized using several transformations (e.g. BatchNormalization folding to
 preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
 being collected for each layer's output (and input, depends on the quantization configuration).
@@ -112,7 +111,7 @@ if FOUND_TORCH:
 
 # Attach tpc model to framework
 attach2pytorch = AttachTpcToPytorch()
-
+framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
 core_config.quantization_config.custom_tpc_opset_to_layer)
 
 # Ignore hessian info service as it is not used here yet.
@@ -121,7 +120,7 @@ if FOUND_TORCH:
 core_config=core_config,
 fw_info=fw_info,
 fw_impl=fw_impl,
-
+fqc=framework_platform_capabilities,
 target_resource_utilization=target_resource_utilization,
 tb_w=tb_w)
 
@@ -149,9 +148,9 @@ if FOUND_TORCH:
 fw_info)
 
 exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction)
-if
+if framework_platform_capabilities.tpc.add_metadata:
 exportable_model = add_metadata(exportable_model,
-create_model_metadata(
+create_model_metadata(fqc=framework_platform_capabilities,
 scheduling_info=scheduling_info))
 return exportable_model, user_info
 
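The PyTorch facade mirrors the Keras flow; the sketch below additionally passes the schema-level TPC explicitly, which is the argument whose type changes in this diff. The torchvision model and random batches are illustrative, and pytorch_post_training_quantization is assumed to be the public facade name:

    import numpy as np
    import model_compression_toolkit as mct
    from torchvision.models import mobilenet_v2
    from model_compression_toolkit.constants import PYTORCH
    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL

    def representative_data_gen():
        for _ in range(2):
            yield [np.random.randn(1, 3, 224, 224).astype(np.float32)]

    # Explicitly fetch the same default schema-level TPC the facade would pick on its own.
    tpc = mct.get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

    quantized_module, quantization_info = mct.ptq.pytorch_post_training_quantization(
        mobilenet_v2(),
        representative_data_gen,
        target_platform_capabilities=tpc)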
@@ -19,13 +19,14 @@ from functools import partial
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
+AttachTpcToKeras
 from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
 MixedPrecisionQuantizationConfig
 from mct_quantizers import KerasActivationQuantizationHolder
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 
@@ -55,8 +56,6 @@ if FOUND_TF:
 from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
 get_activation_quantizer_holder
 from model_compression_toolkit.qat.common.qat_config import QATConfig
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
-AttachTpcToKeras
 
 DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
@@ -93,7 +92,7 @@ if FOUND_TF:
 target_resource_utilization: ResourceUtilization = None,
 core_config: CoreConfig = CoreConfig(),
 qat_config: QATConfig = QATConfig(),
-target_platform_capabilities:
+target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
 """
 Prepare a trained Keras model for quantization aware training. First the model quantization is optimized
 with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -200,7 +199,7 @@ if FOUND_TF:
 core_config=core_config,
 fw_info=DEFAULT_KERAS_INFO,
 fw_impl=fw_impl,
-
+fqc=target_platform_capabilities,
 target_resource_utilization=target_resource_utilization,
 tb_w=tb_w)
 
@@ -21,7 +21,7 @@ from tensorflow.python.framework.tensor_shape import TensorShape
 
 from model_compression_toolkit.trainable_infrastructure import TrainingMethod
 
-from
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
 from mct_quantizers import QuantizationTarget, mark_quantizer
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
@@ -22,7 +22,7 @@ from model_compression_toolkit.trainable_infrastructure.common.constants import
 
 from model_compression_toolkit.trainable_infrastructure import TrainingMethod
 
-from
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
 from mct_quantizers import QuantizationTarget, mark_quantizer
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import copy
 from typing import Callable
 from functools import partial
 
 from model_compression_toolkit.constants import PYTORCH
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
-from model_compression_toolkit.target_platform_capabilities.
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
 AttachTpcToPytorch
 from model_compression_toolkit.verify_packages import FOUND_TORCH
 
@@ -26,12 +25,9 @@ from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
 MixedPrecisionQuantizationConfig
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
-TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 
@@ -82,7 +78,7 @@ if FOUND_TORCH:
 target_resource_utilization: ResourceUtilization = None,
 core_config: CoreConfig = CoreConfig(),
 qat_config: QATConfig = QATConfig(),
-target_platform_capabilities:
+target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
 """
 Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized
 with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -159,7 +155,7 @@ if FOUND_TORCH:
 
 # Attach tpc model to framework
 attach2pytorch = AttachTpcToPytorch()
-
+framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
 core_config.quantization_config.custom_tpc_opset_to_layer)
 
 # Ignore hessian scores service as we do not use it here
@@ -168,7 +164,7 @@ if FOUND_TORCH:
 core_config=core_config,
 fw_info=DEFAULT_PYTORCH_INFO,
 fw_impl=fw_impl,
-
+fqc=framework_platform_capabilities,
 target_resource_utilization=target_resource_utilization,
 tb_w=tb_w)
 
@@ -18,7 +18,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from
+from mct_quantizers import QuantizationMethod
 from mct_quantizers import PytorchQuantizationWrapper
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
 from model_compression_toolkit import constants as C
@@ -28,7 +28,7 @@ from model_compression_toolkit.trainable_infrastructure.pytorch.quantizer_utils
 from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
 TrainableQuantizerWeightsConfig
-from
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
 from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer
@@ -18,7 +18,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from
+from mct_quantizers import QuantizationMethod
 from mct_quantizers import PytorchQuantizationWrapper
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
 from model_compression_toolkit import constants as C
@@ -20,7 +20,7 @@ from torch import Tensor
 from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
 from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
 
-from
+from mct_quantizers import QuantizationMethod
 from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
 
 from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer
@@ -12,3 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import AttributeFilter
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import (
+FrameworkQuantizationCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater,
+LayerFilterParams, OperationsToLayers, get_current_tpc)
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities, OperatorsSet, \
+OperatorSetGroup, Signedness, AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions, Fusing
+
+from mct_quantizers import QuantizationMethod
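With these re-exports (and the removal of the target_platform subpackage listed at the top of the diff), user code can import the schema and attach-layer names from the package root. A migration sketch; the old import path is taken from the removed lines elsewhere in this diff:

    # Old layout (2.2.0.20250113), removed in this release:
    #   from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
    # New layout (2.2.0.20250114):
    from model_compression_toolkit.target_platform_capabilities import (
        TargetPlatformCapabilities, OperatorsSet, OperatorSetGroup, OpQuantizationConfig,
        QuantizationConfigOptions, Fusing, LayerFilterParams, FrameworkQuantizationCapabilities)
    from mct_quantizers import QuantizationMethod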
@@ -7,6 +7,6 @@ OpQuantizationConfig = schema.OpQuantizationConfig
 QuantizationConfigOptions = schema.QuantizationConfigOptions
 OperatorsSetBase = schema.OperatorsSetBase
 OperatorsSet = schema.OperatorsSet
-
+OperatorSetGroup = schema.OperatorSetGroup
 Fusing = schema.Fusing
-
+TargetPlatformCapabilities = schema.TargetPlatformCapabilities
@@ -16,7 +16,7 @@ from logging import Logger
 from typing import Optional
 
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
-
+TargetPlatformCapabilities, QuantizationConfigOptions, OperatorsSetBase
 
 
 def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) -> int:
@@ -32,31 +32,31 @@ def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) ->
 return max(op_quantization_config.supported_input_activation_n_bits)
 
 
-def get_config_options_by_operators_set(
+def get_config_options_by_operators_set(tpc: TargetPlatformCapabilities,
 operators_set_name: str) -> QuantizationConfigOptions:
 """
 Get the QuantizationConfigOptions of an OperatorsSet by its name.
 
 Args:
-
+tpc (TargetPlatformCapabilities): The target platform model containing the operator sets and their configurations.
 operators_set_name (str): The name of the OperatorsSet whose quantization configuration options are to be retrieved.
 
 Returns:
 QuantizationConfigOptions: The quantization configuration options associated with the specified OperatorsSet,
 or the default quantization configuration options if the OperatorsSet is not found.
 """
-for op_set in
+for op_set in tpc.operator_set:
 if operators_set_name == op_set.name:
 return op_set.qc_options
-return
+return tpc.default_qco
 
 
-def get_default_op_quantization_config(
+def get_default_op_quantization_config(tpc: TargetPlatformCapabilities) -> OpQuantizationConfig:
 """
-Get the default OpQuantizationConfig of the
+Get the default OpQuantizationConfig of the TargetPlatformCapabilities.
 
 Args:
-
+tpc (TargetPlatformCapabilities): The target platform model containing the default quantization configuration.
 
 Returns:
 OpQuantizationConfig: The default quantization configuration.
@@ -64,32 +64,32 @@ def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuant
 Raises:
 AssertionError: If the default quantization configuration list contains more than one configuration option.
 """
-assert len(
+assert len(tpc.default_qco.quantization_configurations) == 1, \
 f"Default quantization configuration options must contain only one option, " \
-f"but found {len(
-return
+f"but found {len(tpc.default_qco.quantization_configurations)} configurations." # pragma: no cover
+return tpc.default_qco.quantization_configurations[0]
 
 
-def is_opset_in_model(
+def is_opset_in_model(tpc: TargetPlatformCapabilities, opset_name: str) -> bool:
 """
 Check whether an OperatorsSet is defined in the model.
 
 Args:
-
+tpc (TargetPlatformCapabilities): The target platform model containing the list of operator sets.
 opset_name (str): The name of the OperatorsSet to check for existence.
 
 Returns:
 bool: True if an OperatorsSet with the given name exists in the target platform model,
 otherwise False.
 """
-return
+return tpc.operator_set is not None and opset_name in [x.name for x in tpc.operator_set]
 
-def get_opset_by_name(
+def get_opset_by_name(tpc: TargetPlatformCapabilities, opset_name: str) -> Optional[OperatorsSetBase]:
 """
 Get an OperatorsSet object from the model by its name.
 
 Args:
-
+tpc (TargetPlatformCapabilities): The target platform model containing the list of operator sets.
 opset_name (str): The name of the OperatorsSet to be retrieved.
 
 Returns:
@@ -99,7 +99,7 @@ def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optiona
 Raises:
 A critical log message if multiple operator sets with the same name are found.
 """
-opset_list = [x for x in
+opset_list = [x for x in tpc.operator_set if x.name == opset_name]
 if len(opset_list) > 1:
-Logger.critical(f"Found more than one OperatorsSet in
+Logger.critical(f"Found more than one OperatorsSet in TargetPlatformCapabilities with the name {opset_name}.") # pragma: no cover
 return opset_list[0] if opset_list else None
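These helpers now take the schema-level TargetPlatformCapabilities directly, as the new signatures above show. A short query sketch; the default TPC lookup mirrors the facades, while the "Conv" opset name is an illustrative assumption:

    import model_compression_toolkit as mct
    from model_compression_toolkit.constants import TENSORFLOW
    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
    from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import (
        get_default_op_quantization_config, get_config_options_by_operators_set, is_opset_in_model)

    tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)

    # The single default OpQuantizationConfig enforced by the assert shown above.
    default_cfg = get_default_op_quantization_config(tpc)

    if is_opset_in_model(tpc, "Conv"):
        conv_options = get_config_options_by_operators_set(tpc, "Conv")
        print(len(conv_options.quantization_configurations))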
@@ -414,7 +414,7 @@ class QuantizationConfigOptions(BaseModel):
 
 class TargetPlatformModelComponent(BaseModel):
 """
-Component of
+Component of TargetPlatformCapabilities (Fusing, OperatorsSet, etc.).
 """
 class Config:
 frozen = True
@@ -433,7 +433,7 @@ class OperatorsSet(OperatorsSetBase):
 Set of operators that are represented by a unique label.
 
 Attributes:
-name (Union[str, OperatorSetNames]): The set's label (must be unique within a
+name (Union[str, OperatorSetNames]): The set's label (must be unique within a TargetPlatformCapabilities).
 qc_options (Optional[QuantizationConfigOptions]): Configuration options to use for this set of operations.
 If None, it represents a fusing set.
 type (Literal["OperatorsSet"]): Fixed type identifier.
@@ -457,7 +457,7 @@ class OperatorsSet(OperatorsSetBase):
 return {"name": self.name}
 
 
-class
+class OperatorSetGroup(OperatorsSetBase):
 """
 Concatenate a tuple of operator sets to treat them similarly in different places (like fusing).
 
@@ -469,7 +469,7 @@ class OperatorSetConcat(OperatorsSetBase):
 name: Optional[str] = None # Will be set in the validator if not given
 
 # Define a private attribute _type
-type: Literal["
+type: Literal["OperatorSetGroup"] = "OperatorSetGroup"
 
 class Config:
 frozen = True
@@ -518,11 +518,11 @@ class Fusing(TargetPlatformModelComponent):
 hence no quantization is applied between them.
 
 Attributes:
-operator_groups (Tuple[Union[OperatorsSet,
-each being either an
+operator_groups (Tuple[Union[OperatorsSet, OperatorSetGroup], ...]): A tuple of operator groups,
+each being either an OperatorSetGroup or an OperatorsSet.
 name (Optional[str]): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
 """
-operator_groups: Tuple[Annotated[Union[OperatorsSet,
+operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetGroup], Field(discriminator='type')], ...]
 name: Optional[str] = None # Will be set in the validator if not given.
 
 class Config:
@@ -591,7 +591,7 @@ class Fusing(TargetPlatformModelComponent):
 for i in range(len(self.operator_groups) - len(other.operator_groups) + 1):
 for j in range(len(other.operator_groups)):
 if self.operator_groups[i + j] != other.operator_groups[j] and not (
-isinstance(self.operator_groups[i + j],
+isinstance(self.operator_groups[i + j], OperatorSetGroup) and (
 other.operator_groups[j] in self.operator_groups[i + j].operators_set)):
 break
 else:
@@ -621,7 +621,7 @@ class Fusing(TargetPlatformModelComponent):
 for x in self.operator_groups
 ])
 
-class
+class TargetPlatformCapabilities(BaseModel):
 """
 Represents the hardware configuration used for quantized model inference.
 
@@ -644,7 +644,7 @@ class TargetPlatformModel(BaseModel):
 tpc_patch_version: Optional[int]
 tpc_platform_type: Optional[str]
 add_metadata: bool = True
-name: Optional[str] = "
+name: Optional[str] = "default_tpc"
 is_simd_padding: bool = False
 
 SCHEMA_VERSION: int = 1
@@ -682,10 +682,10 @@ class TargetPlatformModel(BaseModel):
 
 def get_info(self) -> Dict[str, Any]:
 """
-Get a dictionary summarizing the
+Get a dictionary summarizing the TargetPlatformCapabilities properties.
 
 Returns:
-Dict[str, Any]: Summary of the
+Dict[str, Any]: Summary of the TargetPlatformCapabilities properties.
 """
 return {
 "Model name": self.name,
@@ -695,6 +695,6 @@ class TargetPlatformModel(BaseModel):
 
 def show(self):
 """
-Display the
+Display the TargetPlatformCapabilities.
 """
 pprint.pprint(self.get_info(), sort_dicts=False)
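Across schema/v1.py, OperatorSetConcat becomes OperatorSetGroup and TargetPlatformModel becomes TargetPlatformCapabilities. A sketch of declaring a fusing pattern against the renamed types; the opset names are illustrative, and the construction relies only on the fields visible in the hunks above (operator_groups with a type discriminator, operators_set on the group):

    from model_compression_toolkit.target_platform_capabilities import (
        OperatorsSet, OperatorSetGroup, Fusing)

    conv = OperatorsSet(name="Conv")
    relu = OperatorsSet(name="ReLU")
    swish = OperatorsSet(name="Swish")

    # Group the two activations so a single fusing rule covers both of them.
    any_act = OperatorSetGroup(operators_set=(relu, swish))

    # Conv followed by either activation is treated as one fused op,
    # hence no quantization is applied between them.
    conv_act_fusing = Fusing(operator_groups=(conv, any_act))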