mct-nightly 2.2.0.20250113.134913__py3-none-any.whl → 2.2.0.20250114.134534__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/RECORD +102 -104
- model_compression_toolkit/__init__.py +2 -2
- model_compression_toolkit/core/common/framework_info.py +1 -3
- model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
- model_compression_toolkit/core/common/graph/base_graph.py +20 -21
- model_compression_toolkit/core/common/graph/base_node.py +44 -17
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +187 -0
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +35 -162
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +668 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +74 -51
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
- model_compression_toolkit/core/common/pruning/pruner.py +5 -3
- model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
- model_compression_toolkit/core/graph_prep_runner.py +12 -11
- model_compression_toolkit/core/keras/default_framework_info.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
- model_compression_toolkit/core/runner.py +33 -60
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/metadata.py +11 -10
- model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
- model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
- model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
- model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
- model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
- model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
- model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
- model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
- model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
- model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
- model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
- model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
- model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +0 -528
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
@@ -12,44 +12,37 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from collections import namedtuple
|
16
15
|
|
17
16
|
import copy
|
18
|
-
|
19
|
-
from typing import Callable, Tuple, Any, List, Dict
|
20
|
-
|
21
|
-
import numpy as np
|
17
|
+
from typing import Callable, Any, List
|
22
18
|
|
23
19
|
from model_compression_toolkit.core.common import FrameworkInfo
|
20
|
+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
24
21
|
from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser
|
25
|
-
|
22
|
+
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
26
23
|
from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import compute_graph_max_cut, \
|
27
24
|
SchedulerInfo
|
28
25
|
from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph
|
29
26
|
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
|
27
|
+
from model_compression_toolkit.core.common.mixed_precision.bit_width_setter import set_bit_widths
|
30
28
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_candidates_filter import \
|
31
29
|
filter_candidates_for_mixed_precision
|
30
|
+
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width
|
31
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
|
32
|
+
ResourceUtilization
|
33
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
|
34
|
+
ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode
|
32
35
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import \
|
33
36
|
requires_mixed_precision
|
34
|
-
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
|
35
|
-
from model_compression_toolkit.core.quantization_prep_runner import quantization_preparation_runner
|
36
|
-
from model_compression_toolkit.logger import Logger
|
37
|
-
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
38
|
-
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
39
|
-
from model_compression_toolkit.core.common.mixed_precision.bit_width_setter import set_bit_widths
|
40
|
-
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
|
41
|
-
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_aggregation_methods import MpRuAggregation
|
42
|
-
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import ru_functions_mapping
|
43
|
-
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric
|
44
|
-
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width
|
45
37
|
from model_compression_toolkit.core.common.network_editors.edit_network import edit_network_graph
|
46
38
|
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
|
47
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
48
|
-
from model_compression_toolkit.core.common.visualization.final_config_visualizer import \
|
49
|
-
WeightsFinalBitwidthConfigVisualizer, \
|
50
|
-
ActivationFinalBitwidthConfigVisualizer
|
51
39
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter, \
|
52
40
|
finalize_bitwidth_in_tb
|
41
|
+
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
|
42
|
+
from model_compression_toolkit.core.quantization_prep_runner import quantization_preparation_runner
|
43
|
+
from model_compression_toolkit.logger import Logger
|
44
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
|
45
|
+
FrameworkQuantizationCapabilities
|
53
46
|
|
54
47
|
|
55
48
|
def core_runner(in_model: Any,
|
@@ -57,7 +50,7 @@ def core_runner(in_model: Any,
|
|
57
50
|
core_config: CoreConfig,
|
58
51
|
fw_info: FrameworkInfo,
|
59
52
|
fw_impl: FrameworkImplementation,
|
60
|
-
|
53
|
+
fqc: FrameworkQuantizationCapabilities,
|
61
54
|
target_resource_utilization: ResourceUtilization = None,
|
62
55
|
running_gptq: bool = False,
|
63
56
|
tb_w: TensorboardWriter = None):
|
@@ -77,7 +70,7 @@ def core_runner(in_model: Any,
|
|
77
70
|
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
78
71
|
groups of layers by how they should be quantized, etc.).
|
79
72
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
80
|
-
|
73
|
+
fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
|
81
74
|
the attached framework operator's information.
|
82
75
|
target_resource_utilization: ResourceUtilization to constraint the search of the mixed-precision configuration for the model.
|
83
76
|
tb_w: TensorboardWriter object for logging
|
@@ -88,7 +81,7 @@ def core_runner(in_model: Any,
|
|
88
81
|
"""
|
89
82
|
|
90
83
|
# Warn is representative dataset has batch-size == 1
|
91
|
-
batch_data = iter(representative_data_gen())
|
84
|
+
batch_data = next(iter(representative_data_gen()))
|
92
85
|
if isinstance(batch_data, list):
|
93
86
|
batch_data = batch_data[0]
|
94
87
|
if batch_data.shape[0] == 1:
|
@@ -96,7 +89,7 @@ def core_runner(in_model: Any,
|
|
96
89
|
' consider increasing the batch size')
|
97
90
|
|
98
91
|
# Checking whether to run mixed precision quantization
|
99
|
-
if target_resource_utilization is not None:
|
92
|
+
if target_resource_utilization is not None and target_resource_utilization.is_any_restricted():
|
100
93
|
if core_config.mixed_precision_config is None:
|
101
94
|
Logger.critical("Provided an initialized target_resource_utilization, that means that mixed precision quantization is "
|
102
95
|
"enabled, but the provided MixedPrecisionQuantizationConfig is None.")
|
@@ -105,7 +98,7 @@ def core_runner(in_model: Any,
|
|
105
98
|
target_resource_utilization,
|
106
99
|
representative_data_gen,
|
107
100
|
core_config,
|
108
|
-
|
101
|
+
fqc,
|
109
102
|
fw_info,
|
110
103
|
fw_impl):
|
111
104
|
core_config.mixed_precision_config.set_mixed_precision_enable()
|
@@ -116,7 +109,7 @@ def core_runner(in_model: Any,
|
|
116
109
|
core_config.quantization_config,
|
117
110
|
fw_info,
|
118
111
|
fw_impl,
|
119
|
-
|
112
|
+
fqc,
|
120
113
|
core_config.bit_width_config,
|
121
114
|
tb_w,
|
122
115
|
mixed_precision_enable=core_config.is_mixed_precision_enabled,
|
@@ -138,7 +131,7 @@ def core_runner(in_model: Any,
|
|
138
131
|
if core_config.is_mixed_precision_enabled:
|
139
132
|
if core_config.mixed_precision_config.configuration_overwrite is None:
|
140
133
|
|
141
|
-
filter_candidates_for_mixed_precision(graph, target_resource_utilization, fw_info,
|
134
|
+
filter_candidates_for_mixed_precision(graph, target_resource_utilization, fw_info, fqc)
|
142
135
|
bit_widths_config = search_bit_width(tg,
|
143
136
|
fw_info,
|
144
137
|
fw_impl,
|
@@ -177,7 +170,6 @@ def core_runner(in_model: Any,
|
|
177
170
|
|
178
171
|
_set_final_resource_utilization(graph=tg,
|
179
172
|
final_bit_widths_config=bit_widths_config,
|
180
|
-
ru_functions_dict=ru_functions_mapping,
|
181
173
|
fw_info=fw_info,
|
182
174
|
fw_impl=fw_impl)
|
183
175
|
|
@@ -215,7 +207,6 @@ def core_runner(in_model: Any,
|
|
215
207
|
|
216
208
|
def _set_final_resource_utilization(graph: Graph,
|
217
209
|
final_bit_widths_config: List[int],
|
218
|
-
ru_functions_dict: Dict[RUTarget, Tuple[MpRuMetric, MpRuAggregation]],
|
219
210
|
fw_info: FrameworkInfo,
|
220
211
|
fw_impl: FrameworkImplementation):
|
221
212
|
"""
|
@@ -225,39 +216,21 @@ def _set_final_resource_utilization(graph: Graph,
|
|
225
216
|
Args:
|
226
217
|
graph: Graph to compute the resource utilization for.
|
227
218
|
final_bit_widths_config: The final bit-width configuration to quantize the model accordingly.
|
228
|
-
ru_functions_dict: A mapping between a RUTarget and a pair of resource utilization method and resource utilization aggregation functions.
|
229
219
|
fw_info: A FrameworkInfo object.
|
230
220
|
fw_impl: FrameworkImplementation object with specific framework methods implementation.
|
231
221
|
|
232
222
|
"""
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
elif len(final_bit_widths_config) > 0 and len(non_conf_ru) == 0:
|
246
|
-
final_ru_dict[ru_target] = ru_aggr(conf_ru, False)[0]
|
247
|
-
elif len(final_bit_widths_config) == 0 and len(non_conf_ru) > 0:
|
248
|
-
# final_bit_widths_config == 0 ==> no configurable nodes,
|
249
|
-
# thus, ru can be computed from non_conf_ru alone
|
250
|
-
final_ru_dict[ru_target] = ru_aggr(non_conf_ru, False)[0]
|
251
|
-
else:
|
252
|
-
# No relevant nodes have been quantized with affect on the given target - since we only consider
|
253
|
-
# in the model's final size the quantized layers size, this means that the final size for this target
|
254
|
-
# is zero.
|
255
|
-
Logger.warning(f"No relevant quantized layers for the ru target {ru_target} were found, the recorded "
|
256
|
-
f"final ru for this target would be 0.")
|
257
|
-
final_ru_dict[ru_target] = 0
|
258
|
-
|
259
|
-
final_ru = ResourceUtilization()
|
260
|
-
final_ru.set_resource_utilization_by_target(final_ru_dict)
|
261
|
-
print(final_ru)
|
223
|
+
w_qcs = {n: n.final_weights_quantization_cfg for n in graph.nodes}
|
224
|
+
a_qcs = {n: n.final_activation_quantization_cfg for n in graph.nodes}
|
225
|
+
ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
|
226
|
+
final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QCustom,
|
227
|
+
act_qcs=a_qcs, w_qcs=w_qcs)
|
228
|
+
|
229
|
+
for ru_target, ru in final_ru.get_resource_utilization_dict().items():
|
230
|
+
if ru == 0:
|
231
|
+
Logger.warning(f"No relevant quantized layers for the resource utilization target {ru_target} were found, "
|
232
|
+
f"the recorded final ru for this target would be 0.")
|
233
|
+
|
234
|
+
Logger.info(f'Resource utilization (of quantized targets):\n {str(final_ru)}.')
|
262
235
|
graph.user_info.final_resource_utilization = final_ru
|
263
236
|
graph.user_info.mixed_precision_cfg = final_bit_widths_config
|
@@ -20,7 +20,7 @@ from model_compression_toolkit.core.common.quantization.node_quantization_config
|
|
20
20
|
NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
|
21
21
|
|
22
22
|
from model_compression_toolkit.logger import Logger
|
23
|
-
from
|
23
|
+
from mct_quantizers import QuantizationMethod
|
24
24
|
from mct_quantizers import QuantizationTarget
|
25
25
|
from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
|
26
26
|
from mct_quantizers.keras.quantizers import BaseKerasInferableQuantizer
|
@@ -21,7 +21,7 @@ from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RA
|
|
21
21
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
|
22
22
|
NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
|
23
23
|
from model_compression_toolkit.logger import Logger
|
24
|
-
from
|
24
|
+
from mct_quantizers import QuantizationMethod
|
25
25
|
from mct_quantizers import QuantizationTarget
|
26
26
|
from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
|
27
27
|
from mct_quantizers import \
|
@@ -22,7 +22,9 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR
|
|
22
22
|
LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA
|
23
23
|
from model_compression_toolkit.logger import Logger
|
24
24
|
from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
|
25
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
25
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
26
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
|
27
|
+
AttachTpcToKeras
|
26
28
|
from model_compression_toolkit.verify_packages import FOUND_TF
|
27
29
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
28
30
|
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig, \
|
@@ -33,7 +35,6 @@ from model_compression_toolkit.core import CoreConfig
|
|
33
35
|
from model_compression_toolkit.core.runner import core_runner
|
34
36
|
from model_compression_toolkit.gptq.runner import gptq_runner
|
35
37
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
36
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
37
38
|
from model_compression_toolkit.metadata import create_model_metadata
|
38
39
|
|
39
40
|
|
@@ -48,8 +49,6 @@ if FOUND_TF:
|
|
48
49
|
from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
|
49
50
|
from model_compression_toolkit import get_target_platform_capabilities
|
50
51
|
from mct_quantizers.keras.metadata import add_metadata
|
51
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
|
52
|
-
AttachTpcToKeras
|
53
52
|
|
54
53
|
# As from TF2.9 optimizers package is changed
|
55
54
|
if version.parse(tf.__version__) < version.parse("2.9"):
|
@@ -157,7 +156,7 @@ if FOUND_TF:
|
|
157
156
|
gptq_representative_data_gen: Callable = None,
|
158
157
|
target_resource_utilization: ResourceUtilization = None,
|
159
158
|
core_config: CoreConfig = CoreConfig(),
|
160
|
-
target_platform_capabilities:
|
159
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
|
161
160
|
"""
|
162
161
|
Quantize a trained Keras model using post-training quantization. The model is quantized using a
|
163
162
|
symmetric constraint quantization thresholds (power of two).
|
@@ -244,7 +243,7 @@ if FOUND_TF:
|
|
244
243
|
|
245
244
|
# Attach tpc model to framework
|
246
245
|
attach2keras = AttachTpcToKeras()
|
247
|
-
|
246
|
+
framework_platform_capabilities = attach2keras.attach(
|
248
247
|
target_platform_capabilities,
|
249
248
|
custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
|
250
249
|
|
@@ -253,7 +252,7 @@ if FOUND_TF:
|
|
253
252
|
core_config=core_config,
|
254
253
|
fw_info=DEFAULT_KERAS_INFO,
|
255
254
|
fw_impl=fw_impl,
|
256
|
-
|
255
|
+
fqc=framework_platform_capabilities,
|
257
256
|
target_resource_utilization=target_resource_utilization,
|
258
257
|
tb_w=tb_w,
|
259
258
|
running_gptq=True)
|
@@ -281,9 +280,9 @@ if FOUND_TF:
|
|
281
280
|
DEFAULT_KERAS_INFO)
|
282
281
|
|
283
282
|
exportable_model, user_info = get_exportable_keras_model(tg_gptq)
|
284
|
-
if
|
283
|
+
if framework_platform_capabilities.tpc.add_metadata:
|
285
284
|
exportable_model = add_metadata(exportable_model,
|
286
|
-
create_model_metadata(
|
285
|
+
create_model_metadata(fqc=framework_platform_capabilities,
|
287
286
|
scheduling_info=scheduling_info))
|
288
287
|
return exportable_model, user_info
|
289
288
|
|
@@ -18,7 +18,7 @@ import numpy as np
|
|
18
18
|
|
19
19
|
from model_compression_toolkit.gptq import RoundingType
|
20
20
|
from model_compression_toolkit.core.common import max_power_of_two
|
21
|
-
from
|
21
|
+
from mct_quantizers import QuantizationMethod
|
22
22
|
from mct_quantizers import QuantizationTarget
|
23
23
|
from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
|
24
24
|
SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
|
@@ -18,7 +18,7 @@ import numpy as np
|
|
18
18
|
|
19
19
|
from model_compression_toolkit.gptq import RoundingType
|
20
20
|
from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
|
21
|
-
from
|
21
|
+
from mct_quantizers import QuantizationMethod
|
22
22
|
from mct_quantizers import QuantizationTarget
|
23
23
|
from model_compression_toolkit.gptq.common.gptq_constants import \
|
24
24
|
SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
|
@@ -19,7 +19,7 @@ import numpy as np
|
|
19
19
|
import tensorflow as tf
|
20
20
|
|
21
21
|
from model_compression_toolkit.gptq import RoundingType
|
22
|
-
from
|
22
|
+
from mct_quantizers import QuantizationMethod
|
23
23
|
from mct_quantizers import QuantizationTarget
|
24
24
|
from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD
|
25
25
|
from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
|
@@ -31,8 +31,7 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR
|
|
31
31
|
from model_compression_toolkit.gptq.runner import gptq_runner
|
32
32
|
from model_compression_toolkit.logger import Logger
|
33
33
|
from model_compression_toolkit.metadata import create_model_metadata
|
34
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
35
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
34
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
36
35
|
from model_compression_toolkit.verify_packages import FOUND_TORCH
|
37
36
|
|
38
37
|
|
@@ -48,7 +47,7 @@ if FOUND_TORCH:
|
|
48
47
|
from torch.optim import Adam, Optimizer
|
49
48
|
from model_compression_toolkit import get_target_platform_capabilities
|
50
49
|
from mct_quantizers.pytorch.metadata import add_metadata
|
51
|
-
from model_compression_toolkit.target_platform_capabilities.
|
50
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
|
52
51
|
AttachTpcToPytorch
|
53
52
|
|
54
53
|
DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
|
@@ -146,11 +145,11 @@ if FOUND_TORCH:
|
|
146
145
|
core_config: CoreConfig = CoreConfig(),
|
147
146
|
gptq_config: GradientPTQConfig = None,
|
148
147
|
gptq_representative_data_gen: Callable = None,
|
149
|
-
target_platform_capabilities:
|
148
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
|
150
149
|
"""
|
151
150
|
Quantize a trained Pytorch module using post-training quantization.
|
152
151
|
By default, the module is quantized using a symmetric constraint quantization thresholds
|
153
|
-
(power of two) as defined in the default
|
152
|
+
(power of two) as defined in the default FrameworkQuantizationCapabilities.
|
154
153
|
The module is first optimized using several transformations (e.g. BatchNormalization folding to
|
155
154
|
preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
|
156
155
|
being collected for each layer's output (and input, depends on the quantization configuration).
|
@@ -217,7 +216,7 @@ if FOUND_TORCH:
|
|
217
216
|
|
218
217
|
# Attach tpc model to framework
|
219
218
|
attach2pytorch = AttachTpcToPytorch()
|
220
|
-
|
219
|
+
framework_quantization_capabilities = attach2pytorch.attach(target_platform_capabilities,
|
221
220
|
core_config.quantization_config.custom_tpc_opset_to_layer)
|
222
221
|
|
223
222
|
# ---------------------- #
|
@@ -228,7 +227,7 @@ if FOUND_TORCH:
|
|
228
227
|
core_config=core_config,
|
229
228
|
fw_info=DEFAULT_PYTORCH_INFO,
|
230
229
|
fw_impl=fw_impl,
|
231
|
-
|
230
|
+
fqc=framework_quantization_capabilities,
|
232
231
|
target_resource_utilization=target_resource_utilization,
|
233
232
|
tb_w=tb_w,
|
234
233
|
running_gptq=True)
|
@@ -257,9 +256,9 @@ if FOUND_TORCH:
|
|
257
256
|
DEFAULT_PYTORCH_INFO)
|
258
257
|
|
259
258
|
exportable_model, user_info = get_exportable_pytorch_model(graph_gptq)
|
260
|
-
if
|
259
|
+
if framework_quantization_capabilities.tpc.add_metadata:
|
261
260
|
exportable_model = add_metadata(exportable_model,
|
262
|
-
create_model_metadata(
|
261
|
+
create_model_metadata(fqc=framework_quantization_capabilities,
|
263
262
|
scheduling_info=scheduling_info))
|
264
263
|
return exportable_model, user_info
|
265
264
|
|
@@ -18,7 +18,7 @@ from typing import Dict
|
|
18
18
|
import numpy as np
|
19
19
|
|
20
20
|
from model_compression_toolkit.core.common import max_power_of_two
|
21
|
-
from
|
21
|
+
from mct_quantizers import QuantizationMethod
|
22
22
|
from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
|
23
23
|
from model_compression_toolkit.gptq.common.gptq_config import RoundingType
|
24
24
|
from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
|
@@ -18,7 +18,7 @@ from typing import Dict
|
|
18
18
|
import numpy as np
|
19
19
|
|
20
20
|
from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
|
21
|
-
from
|
21
|
+
from mct_quantizers import QuantizationMethod
|
22
22
|
from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
|
23
23
|
from model_compression_toolkit.gptq.common.gptq_config import RoundingType
|
24
24
|
from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
|
@@ -18,7 +18,7 @@ from typing import Dict
|
|
18
18
|
import numpy as np
|
19
19
|
from model_compression_toolkit.defaultdict import DefaultDict
|
20
20
|
|
21
|
-
from
|
21
|
+
from mct_quantizers import QuantizationMethod
|
22
22
|
from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
|
23
23
|
from model_compression_toolkit.gptq.common.gptq_config import RoundingType
|
24
24
|
from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
|
@@ -18,33 +18,34 @@ from typing import Dict, Any
|
|
18
18
|
from model_compression_toolkit.constants import OPERATORS_SCHEDULING, FUSED_NODES_MAPPING, CUTS, MAX_CUT, OP_ORDER, \
|
19
19
|
OP_RECORD, SHAPE, NODE_OUTPUT_INDEX, NODE_NAME, TOTAL_SIZE, MEM_ELEMENTS
|
20
20
|
from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import SchedulerInfo
|
21
|
-
from model_compression_toolkit.target_platform_capabilities.
|
21
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
|
22
|
+
FrameworkQuantizationCapabilities
|
22
23
|
|
23
24
|
|
24
|
-
def create_model_metadata(
|
25
|
+
def create_model_metadata(fqc: FrameworkQuantizationCapabilities,
|
25
26
|
scheduling_info: SchedulerInfo = None) -> Dict:
|
26
27
|
"""
|
27
28
|
Creates and returns a metadata dictionary for the model, including version information
|
28
29
|
and optional scheduling information.
|
29
30
|
|
30
31
|
Args:
|
31
|
-
|
32
|
+
fqc: An FQC object whose attached TPC supplies the version information.
|
32
33
|
scheduling_info: An object containing scheduling details and metadata. Default is None.
|
33
34
|
|
34
35
|
Returns:
|
35
36
|
Dict: A dictionary containing the model's version information and optional scheduling information.
|
36
37
|
"""
|
37
|
-
_metadata = get_versions_dict(
|
38
|
+
_metadata = get_versions_dict(fqc)
|
38
39
|
if scheduling_info:
|
39
40
|
scheduler_metadata = get_scheduler_metadata(scheduler_info=scheduling_info)
|
40
41
|
_metadata['scheduling_info'] = scheduler_metadata
|
41
42
|
return _metadata
|
42
43
|
|
43
44
|
|
44
|
-
def get_versions_dict(
|
45
|
+
def get_versions_dict(fqc) -> Dict:
|
45
46
|
"""
|
46
47
|
|
47
|
-
Returns: A dictionary with
|
48
|
+
Returns: A dictionary with the MCT version and the TPC minor version, patch version, platform type and schema version (read from fqc.tpc).
|
48
49
|
|
49
50
|
"""
|
50
51
|
# imported inside to avoid circular import error
|
@@ -53,10 +54,10 @@ def get_versions_dict(tpc) -> Dict:
|
|
53
54
|
@dataclass
|
54
55
|
class TPCVersions:
|
55
56
|
mct_version: str
|
56
|
-
tpc_minor_version: str = f'{tpc.
|
57
|
-
tpc_patch_version: str = f'{tpc.
|
58
|
-
tpc_platform_type: str = f'{tpc.
|
59
|
-
tpc_schema: str = f'{tpc.
|
57
|
+
tpc_minor_version: str = f'{fqc.tpc.tpc_minor_version}'
|
58
|
+
tpc_patch_version: str = f'{fqc.tpc.tpc_patch_version}'
|
59
|
+
tpc_platform_type: str = f'{fqc.tpc.tpc_platform_type}'
|
60
|
+
tpc_schema: str = f'{fqc.tpc.SCHEMA_VERSION}'
|
60
61
|
|
61
62
|
return asdict(TPCVersions(mct_version))
|
62
63
|
|
@@ -17,7 +17,7 @@ from typing import Callable, Tuple
|
|
17
17
|
|
18
18
|
from model_compression_toolkit import get_target_platform_capabilities
|
19
19
|
from model_compression_toolkit.constants import TENSORFLOW
|
20
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
21
21
|
from model_compression_toolkit.verify_packages import FOUND_TF
|
22
22
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
23
23
|
from model_compression_toolkit.core.common.pruning.pruner import Pruner
|
@@ -26,17 +26,16 @@ from model_compression_toolkit.core.common.pruning.pruning_info import PruningIn
|
|
26
26
|
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
|
27
27
|
from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
|
28
28
|
from model_compression_toolkit.logger import Logger
|
29
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
30
29
|
from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
|
31
30
|
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
32
31
|
|
33
32
|
if FOUND_TF:
|
33
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
|
34
|
+
AttachTpcToKeras
|
34
35
|
from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
|
35
36
|
from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
|
36
37
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
37
38
|
from tensorflow.keras.models import Model
|
38
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
|
39
|
-
AttachTpcToKeras
|
40
39
|
|
41
40
|
DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
|
42
41
|
|
@@ -44,7 +43,7 @@ if FOUND_TF:
|
|
44
43
|
target_resource_utilization: ResourceUtilization,
|
45
44
|
representative_data_gen: Callable,
|
46
45
|
pruning_config: PruningConfig = PruningConfig(),
|
47
|
-
target_platform_capabilities:
|
46
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
|
48
47
|
"""
|
49
48
|
Perform structured pruning on a Keras model to meet a specified target resource utilization.
|
50
49
|
This function prunes the provided model according to the target resource utilization by grouping and pruning
|
@@ -62,7 +61,7 @@ if FOUND_TF:
|
|
62
61
|
target_resource_utilization (ResourceUtilization): The target Key Performance Indicators to be achieved through pruning.
|
63
62
|
representative_data_gen (Callable): A function to generate representative data for pruning analysis.
|
64
63
|
pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
|
65
|
-
target_platform_capabilities (
|
64
|
+
target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
|
66
65
|
|
67
66
|
Returns:
|
68
67
|
Tuple[Model, PruningInfo]: A tuple containing the pruned Keras model and associated pruning information.
|
@@ -16,7 +16,7 @@
|
|
16
16
|
from typing import Callable, Tuple
|
17
17
|
from model_compression_toolkit import get_target_platform_capabilities
|
18
18
|
from model_compression_toolkit.constants import PYTORCH
|
19
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
19
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
20
20
|
from model_compression_toolkit.verify_packages import FOUND_TORCH
|
21
21
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
22
22
|
from model_compression_toolkit.core.common.pruning.pruner import Pruner
|
@@ -25,7 +25,6 @@ from model_compression_toolkit.core.common.pruning.pruning_info import PruningIn
|
|
25
25
|
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
|
26
26
|
from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
|
27
27
|
from model_compression_toolkit.logger import Logger
|
28
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
29
28
|
from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
|
30
29
|
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
31
30
|
|
@@ -38,7 +37,7 @@ if FOUND_TORCH:
|
|
38
37
|
PruningPytorchImplementation
|
39
38
|
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
40
39
|
from torch.nn import Module
|
41
|
-
from model_compression_toolkit.target_platform_capabilities.
|
40
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
|
42
41
|
AttachTpcToPytorch
|
43
42
|
|
44
43
|
# Set the default Target Platform Capabilities (TPC) for PyTorch.
|
@@ -48,7 +47,7 @@ if FOUND_TORCH:
|
|
48
47
|
target_resource_utilization: ResourceUtilization,
|
49
48
|
representative_data_gen: Callable,
|
50
49
|
pruning_config: PruningConfig = PruningConfig(),
|
51
|
-
target_platform_capabilities:
|
50
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
|
52
51
|
Tuple[Module, PruningInfo]:
|
53
52
|
"""
|
54
53
|
Perform structured pruning on a Pytorch model to meet a specified target resource utilization.
|
@@ -121,12 +120,12 @@ if FOUND_TORCH:
|
|
121
120
|
|
122
121
|
# Attach TPC to framework
|
123
122
|
attach2pytorch = AttachTpcToPytorch()
|
124
|
-
|
123
|
+
framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)
|
125
124
|
|
126
125
|
# Convert the original Pytorch model to an internal graph representation.
|
127
126
|
float_graph = read_model_to_graph(model,
|
128
127
|
representative_data_gen,
|
129
|
-
|
128
|
+
framework_platform_capabilities,
|
130
129
|
DEFAULT_PYTORCH_INFO,
|
131
130
|
fw_impl)
|
132
131
|
|
@@ -143,7 +142,7 @@ if FOUND_TORCH:
|
|
143
142
|
target_resource_utilization,
|
144
143
|
representative_data_gen,
|
145
144
|
pruning_config,
|
146
|
-
|
145
|
+
framework_platform_capabilities)
|
147
146
|
|
148
147
|
# Apply the pruning process.
|
149
148
|
pruned_graph = pruner.prune_graph()
|
@@ -22,17 +22,18 @@ from model_compression_toolkit.core.common.quantization.quantize_graph_weights i
|
|
22
22
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
23
23
|
from model_compression_toolkit.logger import Logger
|
24
24
|
from model_compression_toolkit.constants import TENSORFLOW
|
25
|
-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import
|
25
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
|
26
26
|
from model_compression_toolkit.verify_packages import FOUND_TF
|
27
27
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
28
28
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
29
29
|
MixedPrecisionQuantizationConfig
|
30
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
31
30
|
from model_compression_toolkit.core.runner import core_runner
|
32
31
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
33
32
|
from model_compression_toolkit.metadata import create_model_metadata
|
34
33
|
|
35
34
|
if FOUND_TF:
|
35
|
+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
|
36
|
+
AttachTpcToKeras
|
36
37
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
37
38
|
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
|
38
39
|
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
|
@@ -42,8 +43,6 @@ if FOUND_TF:
|
|
42
43
|
|
43
44
|
from model_compression_toolkit import get_target_platform_capabilities
|
44
45
|
from mct_quantizers.keras.metadata import add_metadata
|
45
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
|
46
|
-
AttachTpcToKeras
|
47
46
|
|
48
47
|
DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
|
49
48
|
|
@@ -52,7 +51,7 @@ if FOUND_TF:
|
|
52
51
|
representative_data_gen: Callable,
|
53
52
|
target_resource_utilization: ResourceUtilization = None,
|
54
53
|
core_config: CoreConfig = CoreConfig(),
|
55
|
-
target_platform_capabilities:
|
54
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
|
56
55
|
"""
|
57
56
|
Quantize a trained Keras model using post-training quantization. The model is quantized using a
|
58
57
|
symmetric constraint quantization thresholds (power of two).
|
@@ -139,7 +138,7 @@ if FOUND_TF:
|
|
139
138
|
fw_impl = KerasImplementation()
|
140
139
|
|
141
140
|
attach2keras = AttachTpcToKeras()
|
142
|
-
|
141
|
+
framework_platform_capabilities = attach2keras.attach(
|
143
142
|
target_platform_capabilities,
|
144
143
|
custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
|
145
144
|
|
@@ -149,7 +148,7 @@ if FOUND_TF:
|
|
149
148
|
core_config=core_config,
|
150
149
|
fw_info=fw_info,
|
151
150
|
fw_impl=fw_impl,
|
152
|
-
|
151
|
+
fqc=framework_platform_capabilities,
|
153
152
|
target_resource_utilization=target_resource_utilization,
|
154
153
|
tb_w=tb_w)
|
155
154
|
|
@@ -177,9 +176,9 @@ if FOUND_TF:
|
|
177
176
|
fw_info)
|
178
177
|
|
179
178
|
exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction)
|
180
|
-
if
|
179
|
+
if framework_platform_capabilities.tpc.add_metadata:
|
181
180
|
exportable_model = add_metadata(exportable_model,
|
182
|
-
create_model_metadata(
|
181
|
+
create_model_metadata(fqc=framework_platform_capabilities,
|
183
182
|
scheduling_info=scheduling_info))
|
184
183
|
return exportable_model, user_info
|
185
184
|
|