mct-nightly 2.2.0.20250113.527__py3-none-any.whl → 2.2.0.20250114.84821__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/RECORD +103 -105
- model_compression_toolkit/__init__.py +2 -2
- model_compression_toolkit/core/common/framework_info.py +1 -3
- model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
- model_compression_toolkit/core/common/graph/base_graph.py +20 -21
- model_compression_toolkit/core/common/graph/base_node.py +44 -17
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +26 -135
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +667 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +164 -470
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +30 -7
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/pruner.py +5 -3
- model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/graph_prep_runner.py +12 -11
- model_compression_toolkit/core/keras/data_util.py +24 -5
- model_compression_toolkit/core/keras/default_framework_info.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
- model_compression_toolkit/core/runner.py +33 -60
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/metadata.py +11 -10
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
- model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
- model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
- model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
- model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
- model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
- model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
- model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
@@ -18,33 +18,34 @@ from typing import Dict, Any
|
|
18
18
|
from model_compression_toolkit.constants import OPERATORS_SCHEDULING, FUSED_NODES_MAPPING, CUTS, MAX_CUT, OP_ORDER, \
|
19
19
|
OP_RECORD, SHAPE, NODE_OUTPUT_INDEX, NODE_NAME, TOTAL_SIZE, MEM_ELEMENTS
|
20
20
|
from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import SchedulerInfo
|
21
|
-
from model_compression_toolkit.target_platform_capabilities.
|
21
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
|
22
|
+
FrameworkQuantizationCapabilities
|
22
23
|
|
23
24
|
|
24
|
-
def create_model_metadata(
|
25
|
+
def create_model_metadata(fqc: FrameworkQuantizationCapabilities,
|
25
26
|
scheduling_info: SchedulerInfo = None) -> Dict:
|
26
27
|
"""
|
27
28
|
Creates and returns a metadata dictionary for the model, including version information
|
28
29
|
and optional scheduling information.
|
29
30
|
|
30
31
|
Args:
|
31
|
-
|
32
|
+
fqc: A FQC object to get the version.
|
32
33
|
scheduling_info: An object containing scheduling details and metadata. Default is None.
|
33
34
|
|
34
35
|
Returns:
|
35
36
|
Dict: A dictionary containing the model's version information and optional scheduling information.
|
36
37
|
"""
|
37
|
-
_metadata = get_versions_dict(
|
38
|
+
_metadata = get_versions_dict(fqc)
|
38
39
|
if scheduling_info:
|
39
40
|
scheduler_metadata = get_scheduler_metadata(scheduler_info=scheduling_info)
|
40
41
|
_metadata['scheduling_info'] = scheduler_metadata
|
41
42
|
return _metadata
|
42
43
|
|
43
44
|
|
44
|
-
def get_versions_dict(
|
45
|
+
def get_versions_dict(fqc) -> Dict:
|
45
46
|
"""
|
46
47
|
|
47
|
-
Returns: A dictionary with
|
48
|
+
Returns: A dictionary with FQC, MCT and FQC-Schema versions.
|
48
49
|
|
49
50
|
"""
|
50
51
|
# imported inside to avoid circular import error
|
@@ -53,10 +54,10 @@ def get_versions_dict(tpc) -> Dict:
|
|
53
54
|
@dataclass
|
54
55
|
class TPCVersions:
|
55
56
|
mct_version: str
|
56
|
-
tpc_minor_version: str = f'{tpc.
|
57
|
-
tpc_patch_version: str = f'{tpc.
|
58
|
-
tpc_platform_type: str = f'{tpc.
|
59
|
-
tpc_schema: str = f'{tpc.
|
57
|
+
tpc_minor_version: str = f'{fqc.tpc.tpc_minor_version}'
|
58
|
+
tpc_patch_version: str = f'{fqc.tpc.tpc_patch_version}'
|
59
|
+
tpc_platform_type: str = f'{fqc.tpc.tpc_platform_type}'
|
60
|
+
tpc_schema: str = f'{fqc.tpc.SCHEMA_VERSION}'
|
60
61
|
|
61
62
|
return asdict(TPCVersions(mct_version))
|
62
63
|
|
@@ -17,7 +17,7 @@ from typing import Callable, Tuple
|
|
17
17
|
|
18
18
|
from model_compression_toolkit import get_target_platform_capabilities
|
19
19
|
from model_compression_toolkit.constants import TENSORFLOW
|
20
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
21
21
|
from model_compression_toolkit.verify_packages import FOUND_TF
|
22
22
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
23
23
|
from model_compression_toolkit.core.common.pruning.pruner import Pruner
|
@@ -26,17 +26,16 @@ from model_compression_toolkit.core.common.pruning.pruning_info import PruningIn
|
|
26
26
|
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
|
27
27
|
from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
|
28
28
|
from model_compression_toolkit.logger import Logger
|
29
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
30
29
|
from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
|
31
30
|
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
32
31
|
|
33
32
|
if FOUND_TF:
|
33
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
|
34
|
+
AttachTpcToKeras
|
34
35
|
from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
|
35
36
|
from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
|
36
37
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
37
38
|
from tensorflow.keras.models import Model
|
38
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
|
39
|
-
AttachTpcToKeras
|
40
39
|
|
41
40
|
DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
|
42
41
|
|
@@ -44,7 +43,7 @@ if FOUND_TF:
|
|
44
43
|
target_resource_utilization: ResourceUtilization,
|
45
44
|
representative_data_gen: Callable,
|
46
45
|
pruning_config: PruningConfig = PruningConfig(),
|
47
|
-
target_platform_capabilities:
|
46
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
|
48
47
|
"""
|
49
48
|
Perform structured pruning on a Keras model to meet a specified target resource utilization.
|
50
49
|
This function prunes the provided model according to the target resource utilization by grouping and pruning
|
@@ -62,7 +61,7 @@ if FOUND_TF:
|
|
62
61
|
target_resource_utilization (ResourceUtilization): The target Key Performance Indicators to be achieved through pruning.
|
63
62
|
representative_data_gen (Callable): A function to generate representative data for pruning analysis.
|
64
63
|
pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
|
65
|
-
target_platform_capabilities (
|
64
|
+
target_platform_capabilities (FrameworkQuantizationCapabilities): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
|
66
65
|
|
67
66
|
Returns:
|
68
67
|
Tuple[Model, PruningInfo]: A tuple containing the pruned Keras model and associated pruning information.
|
@@ -16,7 +16,7 @@
|
|
16
16
|
from typing import Callable, Tuple
|
17
17
|
from model_compression_toolkit import get_target_platform_capabilities
|
18
18
|
from model_compression_toolkit.constants import PYTORCH
|
19
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
19
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
20
20
|
from model_compression_toolkit.verify_packages import FOUND_TORCH
|
21
21
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
22
22
|
from model_compression_toolkit.core.common.pruning.pruner import Pruner
|
@@ -25,7 +25,6 @@ from model_compression_toolkit.core.common.pruning.pruning_info import PruningIn
|
|
25
25
|
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
|
26
26
|
from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
|
27
27
|
from model_compression_toolkit.logger import Logger
|
28
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
29
28
|
from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
|
30
29
|
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
31
30
|
|
@@ -38,7 +37,7 @@ if FOUND_TORCH:
|
|
38
37
|
PruningPytorchImplementation
|
39
38
|
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
40
39
|
from torch.nn import Module
|
41
|
-
from model_compression_toolkit.target_platform_capabilities.
|
40
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
|
42
41
|
AttachTpcToPytorch
|
43
42
|
|
44
43
|
# Set the default Target Platform Capabilities (TPC) for PyTorch.
|
@@ -48,7 +47,7 @@ if FOUND_TORCH:
|
|
48
47
|
target_resource_utilization: ResourceUtilization,
|
49
48
|
representative_data_gen: Callable,
|
50
49
|
pruning_config: PruningConfig = PruningConfig(),
|
51
|
-
target_platform_capabilities:
|
50
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
|
52
51
|
Tuple[Module, PruningInfo]:
|
53
52
|
"""
|
54
53
|
Perform structured pruning on a Pytorch model to meet a specified target resource utilization.
|
@@ -121,12 +120,12 @@ if FOUND_TORCH:
|
|
121
120
|
|
122
121
|
# Attach TPC to framework
|
123
122
|
attach2pytorch = AttachTpcToPytorch()
|
124
|
-
|
123
|
+
framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)
|
125
124
|
|
126
125
|
# Convert the original Pytorch model to an internal graph representation.
|
127
126
|
float_graph = read_model_to_graph(model,
|
128
127
|
representative_data_gen,
|
129
|
-
|
128
|
+
framework_platform_capabilities,
|
130
129
|
DEFAULT_PYTORCH_INFO,
|
131
130
|
fw_impl)
|
132
131
|
|
@@ -143,7 +142,7 @@ if FOUND_TORCH:
|
|
143
142
|
target_resource_utilization,
|
144
143
|
representative_data_gen,
|
145
144
|
pruning_config,
|
146
|
-
|
145
|
+
framework_platform_capabilities)
|
147
146
|
|
148
147
|
# Apply the pruning process.
|
149
148
|
pruned_graph = pruner.prune_graph()
|
@@ -22,17 +22,18 @@ from model_compression_toolkit.core.common.quantization.quantize_graph_weights i
|
|
22
22
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
23
23
|
from model_compression_toolkit.logger import Logger
|
24
24
|
from model_compression_toolkit.constants import TENSORFLOW
|
25
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
25
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
26
26
|
from model_compression_toolkit.verify_packages import FOUND_TF
|
27
27
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
28
28
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
29
29
|
MixedPrecisionQuantizationConfig
|
30
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
31
30
|
from model_compression_toolkit.core.runner import core_runner
|
32
31
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
33
32
|
from model_compression_toolkit.metadata import create_model_metadata
|
34
33
|
|
35
34
|
if FOUND_TF:
|
35
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
|
36
|
+
AttachTpcToKeras
|
36
37
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
37
38
|
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
|
38
39
|
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
|
@@ -42,8 +43,6 @@ if FOUND_TF:
|
|
42
43
|
|
43
44
|
from model_compression_toolkit import get_target_platform_capabilities
|
44
45
|
from mct_quantizers.keras.metadata import add_metadata
|
45
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
|
46
|
-
AttachTpcToKeras
|
47
46
|
|
48
47
|
DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
|
49
48
|
|
@@ -52,7 +51,7 @@ if FOUND_TF:
|
|
52
51
|
representative_data_gen: Callable,
|
53
52
|
target_resource_utilization: ResourceUtilization = None,
|
54
53
|
core_config: CoreConfig = CoreConfig(),
|
55
|
-
target_platform_capabilities:
|
54
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
|
56
55
|
"""
|
57
56
|
Quantize a trained Keras model using post-training quantization. The model is quantized using a
|
58
57
|
symmetric constraint quantization thresholds (power of two).
|
@@ -139,7 +138,7 @@ if FOUND_TF:
|
|
139
138
|
fw_impl = KerasImplementation()
|
140
139
|
|
141
140
|
attach2keras = AttachTpcToKeras()
|
142
|
-
|
141
|
+
framework_platform_capabilities = attach2keras.attach(
|
143
142
|
target_platform_capabilities,
|
144
143
|
custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
|
145
144
|
|
@@ -149,7 +148,7 @@ if FOUND_TF:
|
|
149
148
|
core_config=core_config,
|
150
149
|
fw_info=fw_info,
|
151
150
|
fw_impl=fw_impl,
|
152
|
-
|
151
|
+
fqc=framework_platform_capabilities,
|
153
152
|
target_resource_utilization=target_resource_utilization,
|
154
153
|
tb_w=tb_w)
|
155
154
|
|
@@ -177,9 +176,9 @@ if FOUND_TF:
|
|
177
176
|
fw_info)
|
178
177
|
|
179
178
|
exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction)
|
180
|
-
if
|
179
|
+
if framework_platform_capabilities.tpc.add_metadata:
|
181
180
|
exportable_model = add_metadata(exportable_model,
|
182
|
-
create_model_metadata(
|
181
|
+
create_model_metadata(fqc=framework_platform_capabilities,
|
183
182
|
scheduling_info=scheduling_info))
|
184
183
|
return exportable_model, user_info
|
185
184
|
|
@@ -19,9 +19,8 @@ from typing import Callable
|
|
19
19
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
20
20
|
from model_compression_toolkit.logger import Logger
|
21
21
|
from model_compression_toolkit.constants import PYTORCH
|
22
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
23
23
|
from model_compression_toolkit.verify_packages import FOUND_TORCH
|
24
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
25
24
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
26
25
|
from model_compression_toolkit.core import CoreConfig
|
27
26
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
@@ -40,7 +39,7 @@ if FOUND_TORCH:
|
|
40
39
|
from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
|
41
40
|
from model_compression_toolkit import get_target_platform_capabilities
|
42
41
|
from mct_quantizers.pytorch.metadata import add_metadata
|
43
|
-
from model_compression_toolkit.target_platform_capabilities.
|
42
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
|
44
43
|
AttachTpcToPytorch
|
45
44
|
|
46
45
|
DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
|
@@ -49,11 +48,11 @@ if FOUND_TORCH:
|
|
49
48
|
representative_data_gen: Callable,
|
50
49
|
target_resource_utilization: ResourceUtilization = None,
|
51
50
|
core_config: CoreConfig = CoreConfig(),
|
52
|
-
target_platform_capabilities:
|
51
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
|
53
52
|
"""
|
54
53
|
Quantize a trained Pytorch module using post-training quantization.
|
55
54
|
By default, the module is quantized using a symmetric constraint quantization thresholds
|
56
|
-
(power of two) as defined in the default
|
55
|
+
(power of two) as defined in the default FrameworkQuantizationCapabilities.
|
57
56
|
The module is first optimized using several transformations (e.g. BatchNormalization folding to
|
58
57
|
preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
|
59
58
|
being collected for each layer's output (and input, depends on the quantization configuration).
|
@@ -112,7 +111,7 @@ if FOUND_TORCH:
|
|
112
111
|
|
113
112
|
# Attach tpc model to framework
|
114
113
|
attach2pytorch = AttachTpcToPytorch()
|
115
|
-
|
114
|
+
framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
|
116
115
|
core_config.quantization_config.custom_tpc_opset_to_layer)
|
117
116
|
|
118
117
|
# Ignore hessian info service as it is not used here yet.
|
@@ -121,7 +120,7 @@ if FOUND_TORCH:
|
|
121
120
|
core_config=core_config,
|
122
121
|
fw_info=fw_info,
|
123
122
|
fw_impl=fw_impl,
|
124
|
-
|
123
|
+
fqc=framework_platform_capabilities,
|
125
124
|
target_resource_utilization=target_resource_utilization,
|
126
125
|
tb_w=tb_w)
|
127
126
|
|
@@ -149,9 +148,9 @@ if FOUND_TORCH:
|
|
149
148
|
fw_info)
|
150
149
|
|
151
150
|
exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction)
|
152
|
-
if
|
151
|
+
if framework_platform_capabilities.tpc.add_metadata:
|
153
152
|
exportable_model = add_metadata(exportable_model,
|
154
|
-
create_model_metadata(
|
153
|
+
create_model_metadata(fqc=framework_platform_capabilities,
|
155
154
|
scheduling_info=scheduling_info))
|
156
155
|
return exportable_model, user_info
|
157
156
|
|
@@ -19,13 +19,14 @@ from functools import partial
|
|
19
19
|
from model_compression_toolkit.core import CoreConfig
|
20
20
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
21
21
|
from model_compression_toolkit.logger import Logger
|
22
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
22
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
23
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
|
24
|
+
AttachTpcToKeras
|
23
25
|
from model_compression_toolkit.verify_packages import FOUND_TF
|
24
26
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
25
27
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
26
28
|
MixedPrecisionQuantizationConfig
|
27
29
|
from mct_quantizers import KerasActivationQuantizationHolder
|
28
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
29
30
|
from model_compression_toolkit.core.runner import core_runner
|
30
31
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
31
32
|
|
@@ -55,8 +56,6 @@ if FOUND_TF:
|
|
55
56
|
from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
|
56
57
|
get_activation_quantizer_holder
|
57
58
|
from model_compression_toolkit.qat.common.qat_config import QATConfig
|
58
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
|
59
|
-
AttachTpcToKeras
|
60
59
|
|
61
60
|
DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
|
62
61
|
|
@@ -93,7 +92,7 @@ if FOUND_TF:
|
|
93
92
|
target_resource_utilization: ResourceUtilization = None,
|
94
93
|
core_config: CoreConfig = CoreConfig(),
|
95
94
|
qat_config: QATConfig = QATConfig(),
|
96
|
-
target_platform_capabilities:
|
95
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
|
97
96
|
"""
|
98
97
|
Prepare a trained Keras model for quantization aware training. First the model quantization is optimized
|
99
98
|
with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
|
@@ -200,7 +199,7 @@ if FOUND_TF:
|
|
200
199
|
core_config=core_config,
|
201
200
|
fw_info=DEFAULT_KERAS_INFO,
|
202
201
|
fw_impl=fw_impl,
|
203
|
-
|
202
|
+
fqc=target_platform_capabilities,
|
204
203
|
target_resource_utilization=target_resource_utilization,
|
205
204
|
tb_w=tb_w)
|
206
205
|
|
@@ -21,7 +21,7 @@ from tensorflow.python.framework.tensor_shape import TensorShape
|
|
21
21
|
|
22
22
|
from model_compression_toolkit.trainable_infrastructure import TrainingMethod
|
23
23
|
|
24
|
-
from
|
24
|
+
from mct_quantizers import QuantizationMethod
|
25
25
|
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
|
26
26
|
from mct_quantizers import QuantizationTarget, mark_quantizer
|
27
27
|
from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
|
@@ -22,7 +22,7 @@ from model_compression_toolkit.trainable_infrastructure.common.constants import
|
|
22
22
|
|
23
23
|
from model_compression_toolkit.trainable_infrastructure import TrainingMethod
|
24
24
|
|
25
|
-
from
|
25
|
+
from mct_quantizers import QuantizationMethod
|
26
26
|
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
|
27
27
|
from mct_quantizers import QuantizationTarget, mark_quantizer
|
28
28
|
from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
|
@@ -12,13 +12,12 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
import copy
|
16
15
|
from typing import Callable
|
17
16
|
from functools import partial
|
18
17
|
|
19
18
|
from model_compression_toolkit.constants import PYTORCH
|
20
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
21
|
-
from model_compression_toolkit.target_platform_capabilities.
|
19
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
|
22
21
|
AttachTpcToPytorch
|
23
22
|
from model_compression_toolkit.verify_packages import FOUND_TORCH
|
24
23
|
|
@@ -26,12 +25,9 @@ from model_compression_toolkit.core import CoreConfig
|
|
26
25
|
from model_compression_toolkit.core import common
|
27
26
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
28
27
|
from model_compression_toolkit.logger import Logger
|
29
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
30
28
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
31
29
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
32
30
|
MixedPrecisionQuantizationConfig
|
33
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
|
34
|
-
TargetPlatformCapabilities
|
35
31
|
from model_compression_toolkit.core.runner import core_runner
|
36
32
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
37
33
|
|
@@ -82,7 +78,7 @@ if FOUND_TORCH:
|
|
82
78
|
target_resource_utilization: ResourceUtilization = None,
|
83
79
|
core_config: CoreConfig = CoreConfig(),
|
84
80
|
qat_config: QATConfig = QATConfig(),
|
85
|
-
target_platform_capabilities:
|
81
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
|
86
82
|
"""
|
87
83
|
Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized
|
88
84
|
with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
|
@@ -159,7 +155,7 @@ if FOUND_TORCH:
|
|
159
155
|
|
160
156
|
# Attach tpc model to framework
|
161
157
|
attach2pytorch = AttachTpcToPytorch()
|
162
|
-
|
158
|
+
framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
|
163
159
|
core_config.quantization_config.custom_tpc_opset_to_layer)
|
164
160
|
|
165
161
|
# Ignore hessian scores service as we do not use it here
|
@@ -168,7 +164,7 @@ if FOUND_TORCH:
|
|
168
164
|
core_config=core_config,
|
169
165
|
fw_info=DEFAULT_PYTORCH_INFO,
|
170
166
|
fw_impl=fw_impl,
|
171
|
-
|
167
|
+
fqc=framework_platform_capabilities,
|
172
168
|
target_resource_utilization=target_resource_utilization,
|
173
169
|
tb_w=tb_w)
|
174
170
|
|
@@ -18,7 +18,7 @@ import numpy as np
|
|
18
18
|
import torch
|
19
19
|
import torch.nn as nn
|
20
20
|
|
21
|
-
from
|
21
|
+
from mct_quantizers import QuantizationMethod
|
22
22
|
from mct_quantizers import PytorchQuantizationWrapper
|
23
23
|
from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
|
24
24
|
from model_compression_toolkit import constants as C
|
@@ -28,7 +28,7 @@ from model_compression_toolkit.trainable_infrastructure.pytorch.quantizer_utils
|
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
|
29
29
|
from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
|
30
30
|
TrainableQuantizerWeightsConfig
|
31
|
-
from
|
31
|
+
from mct_quantizers import QuantizationMethod
|
32
32
|
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
|
33
33
|
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
|
34
34
|
from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer
|
@@ -18,7 +18,7 @@ import numpy as np
|
|
18
18
|
import torch
|
19
19
|
import torch.nn as nn
|
20
20
|
|
21
|
-
from
|
21
|
+
from mct_quantizers import QuantizationMethod
|
22
22
|
from mct_quantizers import PytorchQuantizationWrapper
|
23
23
|
from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
|
24
24
|
from model_compression_toolkit import constants as C
|
@@ -20,7 +20,7 @@ from torch import Tensor
|
|
20
20
|
from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
|
21
21
|
from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
|
22
22
|
|
23
|
-
from
|
23
|
+
from mct_quantizers import QuantizationMethod
|
24
24
|
from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
|
25
25
|
|
26
26
|
from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer
|
@@ -12,3 +12,12 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
|
16
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import AttributeFilter
|
17
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import (
|
18
|
+
FrameworkQuantizationCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater,
|
19
|
+
LayerFilterParams, OperationsToLayers, get_current_tpc)
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities, OperatorsSet, \
|
21
|
+
OperatorSetGroup, Signedness, AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions, Fusing
|
22
|
+
|
23
|
+
from mct_quantizers import QuantizationMethod
|
@@ -7,6 +7,6 @@ OpQuantizationConfig = schema.OpQuantizationConfig
|
|
7
7
|
QuantizationConfigOptions = schema.QuantizationConfigOptions
|
8
8
|
OperatorsSetBase = schema.OperatorsSetBase
|
9
9
|
OperatorsSet = schema.OperatorsSet
|
10
|
-
|
10
|
+
OperatorSetGroup = schema.OperatorSetGroup
|
11
11
|
Fusing = schema.Fusing
|
12
|
-
|
12
|
+
TargetPlatformCapabilities = schema.TargetPlatformCapabilities
|
@@ -16,7 +16,7 @@ from logging import Logger
|
|
16
16
|
from typing import Optional
|
17
17
|
|
18
18
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
|
19
|
-
|
19
|
+
TargetPlatformCapabilities, QuantizationConfigOptions, OperatorsSetBase
|
20
20
|
|
21
21
|
|
22
22
|
def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) -> int:
|
@@ -32,31 +32,31 @@ def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) ->
|
|
32
32
|
return max(op_quantization_config.supported_input_activation_n_bits)
|
33
33
|
|
34
34
|
|
35
|
-
def get_config_options_by_operators_set(
|
35
|
+
def get_config_options_by_operators_set(tpc: TargetPlatformCapabilities,
|
36
36
|
operators_set_name: str) -> QuantizationConfigOptions:
|
37
37
|
"""
|
38
38
|
Get the QuantizationConfigOptions of an OperatorsSet by its name.
|
39
39
|
|
40
40
|
Args:
|
41
|
-
|
41
|
+
tpc (TargetPlatformCapabilities): The target platform model containing the operator sets and their configurations.
|
42
42
|
operators_set_name (str): The name of the OperatorsSet whose quantization configuration options are to be retrieved.
|
43
43
|
|
44
44
|
Returns:
|
45
45
|
QuantizationConfigOptions: The quantization configuration options associated with the specified OperatorsSet,
|
46
46
|
or the default quantization configuration options if the OperatorsSet is not found.
|
47
47
|
"""
|
48
|
-
for op_set in
|
48
|
+
for op_set in tpc.operator_set:
|
49
49
|
if operators_set_name == op_set.name:
|
50
50
|
return op_set.qc_options
|
51
|
-
return
|
51
|
+
return tpc.default_qco
|
52
52
|
|
53
53
|
|
54
|
-
def get_default_op_quantization_config(
|
54
|
+
def get_default_op_quantization_config(tpc: TargetPlatformCapabilities) -> OpQuantizationConfig:
|
55
55
|
"""
|
56
|
-
Get the default OpQuantizationConfig of the
|
56
|
+
Get the default OpQuantizationConfig of the TargetPlatformCapabilities.
|
57
57
|
|
58
58
|
Args:
|
59
|
-
|
59
|
+
tpc (TargetPlatformCapabilities): The target platform model containing the default quantization configuration.
|
60
60
|
|
61
61
|
Returns:
|
62
62
|
OpQuantizationConfig: The default quantization configuration.
|
@@ -64,32 +64,32 @@ def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuant
|
|
64
64
|
Raises:
|
65
65
|
AssertionError: If the default quantization configuration list contains more than one configuration option.
|
66
66
|
"""
|
67
|
-
assert len(
|
67
|
+
assert len(tpc.default_qco.quantization_configurations) == 1, \
|
68
68
|
f"Default quantization configuration options must contain only one option, " \
|
69
|
-
f"but found {len(
|
70
|
-
return
|
69
|
+
f"but found {len(tpc.default_qco.quantization_configurations)} configurations." # pragma: no cover
|
70
|
+
return tpc.default_qco.quantization_configurations[0]
|
71
71
|
|
72
72
|
|
73
|
-
def is_opset_in_model(
|
73
|
+
def is_opset_in_model(tpc: TargetPlatformCapabilities, opset_name: str) -> bool:
|
74
74
|
"""
|
75
75
|
Check whether an OperatorsSet is defined in the model.
|
76
76
|
|
77
77
|
Args:
|
78
|
-
|
78
|
+
tpc (TargetPlatformCapabilities): The target platform model containing the list of operator sets.
|
79
79
|
opset_name (str): The name of the OperatorsSet to check for existence.
|
80
80
|
|
81
81
|
Returns:
|
82
82
|
bool: True if an OperatorsSet with the given name exists in the target platform model,
|
83
83
|
otherwise False.
|
84
84
|
"""
|
85
|
-
return
|
85
|
+
return tpc.operator_set is not None and opset_name in [x.name for x in tpc.operator_set]
|
86
86
|
|
87
|
-
def get_opset_by_name(
|
87
|
+
def get_opset_by_name(tpc: TargetPlatformCapabilities, opset_name: str) -> Optional[OperatorsSetBase]:
|
88
88
|
"""
|
89
89
|
Get an OperatorsSet object from the model by its name.
|
90
90
|
|
91
91
|
Args:
|
92
|
-
|
92
|
+
tpc (TargetPlatformCapabilities): The target platform model containing the list of operator sets.
|
93
93
|
opset_name (str): The name of the OperatorsSet to be retrieved.
|
94
94
|
|
95
95
|
Returns:
|
@@ -99,7 +99,7 @@ def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optiona
|
|
99
99
|
Raises:
|
100
100
|
A critical log message if multiple operator sets with the same name are found.
|
101
101
|
"""
|
102
|
-
opset_list = [x for x in
|
102
|
+
opset_list = [x for x in tpc.operator_set if x.name == opset_name]
|
103
103
|
if len(opset_list) > 1:
|
104
|
-
Logger.critical(f"Found more than one OperatorsSet in
|
104
|
+
Logger.critical(f"Found more than one OperatorsSet in TargetPlatformCapabilities with the name {opset_name}.") # pragma: no cover
|
105
105
|
return opset_list[0] if opset_list else None
|