mct-nightly 2.2.0.20250113.134913__py3-none-any.whl → 2.2.0.20250114.134534__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/RECORD +102 -104
  3. model_compression_toolkit/__init__.py +2 -2
  4. model_compression_toolkit/core/common/framework_info.py +1 -3
  5. model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
  6. model_compression_toolkit/core/common/graph/base_graph.py +20 -21
  7. model_compression_toolkit/core/common/graph/base_node.py +44 -17
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
  9. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +187 -0
  10. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
  11. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +35 -162
  12. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
  13. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +668 -0
  14. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
  15. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +74 -51
  16. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
  17. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  18. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
  19. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
  20. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
  21. model_compression_toolkit/core/common/pruning/pruner.py +5 -3
  22. model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
  23. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
  24. model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
  25. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
  26. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  27. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
  28. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
  29. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
  30. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
  31. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
  32. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
  33. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
  34. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
  35. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
  36. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  37. model_compression_toolkit/core/graph_prep_runner.py +12 -11
  38. model_compression_toolkit/core/keras/default_framework_info.py +1 -1
  39. model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
  40. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
  41. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  42. model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
  43. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
  44. model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
  45. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
  46. model_compression_toolkit/core/runner.py +33 -60
  47. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
  48. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
  49. model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
  50. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  51. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  52. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  53. model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
  54. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  55. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  56. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  57. model_compression_toolkit/metadata.py +11 -10
  58. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
  59. model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
  60. model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
  61. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
  62. model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
  63. model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
  64. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  65. model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
  66. model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
  67. model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
  68. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  69. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
  70. model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
  71. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  72. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
  73. model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
  74. model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
  75. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
  76. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
  77. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
  78. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
  79. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
  80. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
  81. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
  82. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
  83. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
  84. model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
  85. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
  86. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
  87. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
  88. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
  89. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
  90. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
  91. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
  92. model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
  93. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
  94. model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
  95. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
  96. model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
  97. model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
  98. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
  99. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
  100. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
  101. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +0 -528
  102. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
  103. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/LICENSE.md +0 -0
  104. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/WHEEL +0 -0
  105. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/top_level.txt +0 -0
  106. /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
@@ -12,44 +12,37 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from collections import namedtuple
16
15
 
17
16
  import copy
18
-
19
- from typing import Callable, Tuple, Any, List, Dict
20
-
21
- import numpy as np
17
+ from typing import Callable, Any, List
22
18
 
23
19
  from model_compression_toolkit.core.common import FrameworkInfo
20
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
24
21
  from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser
25
-
22
+ from model_compression_toolkit.core.common.graph.base_graph import Graph
26
23
  from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import compute_graph_max_cut, \
27
24
  SchedulerInfo
28
25
  from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph
29
26
  from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
27
+ from model_compression_toolkit.core.common.mixed_precision.bit_width_setter import set_bit_widths
30
28
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_candidates_filter import \
31
29
  filter_candidates_for_mixed_precision
30
+ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width
31
+ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
32
+ ResourceUtilization
33
+ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
34
+ ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode
32
35
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import \
33
36
  requires_mixed_precision
34
- from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
35
- from model_compression_toolkit.core.quantization_prep_runner import quantization_preparation_runner
36
- from model_compression_toolkit.logger import Logger
37
- from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
38
- from model_compression_toolkit.core.common.graph.base_graph import Graph
39
- from model_compression_toolkit.core.common.mixed_precision.bit_width_setter import set_bit_widths
40
- from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
41
- from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_aggregation_methods import MpRuAggregation
42
- from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import ru_functions_mapping
43
- from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric
44
- from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width
45
37
  from model_compression_toolkit.core.common.network_editors.edit_network import edit_network_graph
46
38
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
47
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
48
- from model_compression_toolkit.core.common.visualization.final_config_visualizer import \
49
- WeightsFinalBitwidthConfigVisualizer, \
50
- ActivationFinalBitwidthConfigVisualizer
51
39
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter, \
52
40
  finalize_bitwidth_in_tb
41
+ from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
42
+ from model_compression_toolkit.core.quantization_prep_runner import quantization_preparation_runner
43
+ from model_compression_toolkit.logger import Logger
44
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
45
+ FrameworkQuantizationCapabilities
53
46
 
54
47
 
55
48
  def core_runner(in_model: Any,
@@ -57,7 +50,7 @@ def core_runner(in_model: Any,
57
50
  core_config: CoreConfig,
58
51
  fw_info: FrameworkInfo,
59
52
  fw_impl: FrameworkImplementation,
60
- tpc: TargetPlatformCapabilities,
53
+ fqc: FrameworkQuantizationCapabilities,
61
54
  target_resource_utilization: ResourceUtilization = None,
62
55
  running_gptq: bool = False,
63
56
  tb_w: TensorboardWriter = None):
@@ -77,7 +70,7 @@ def core_runner(in_model: Any,
77
70
  fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
78
71
  groups of layers by how they should be quantized, etc.).
79
72
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
80
- tpc: TargetPlatformCapabilities object that models the inference target platform and
73
+ fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
81
74
  the attached framework operator's information.
82
75
  target_resource_utilization: ResourceUtilization to constrain the search of the mixed-precision configuration for the model.
83
76
  tb_w: TensorboardWriter object for logging
@@ -88,7 +81,7 @@ def core_runner(in_model: Any,
88
81
  """
89
82
 
90
83
  # Warn if representative dataset has batch-size == 1
91
- batch_data = iter(representative_data_gen()).__next__()
84
+ batch_data = next(iter(representative_data_gen()))
92
85
  if isinstance(batch_data, list):
93
86
  batch_data = batch_data[0]
94
87
  if batch_data.shape[0] == 1:
@@ -96,7 +89,7 @@ def core_runner(in_model: Any,
96
89
  ' consider increasing the batch size')
97
90
 
98
91
  # Checking whether to run mixed precision quantization
99
- if target_resource_utilization is not None:
92
+ if target_resource_utilization is not None and target_resource_utilization.is_any_restricted():
100
93
  if core_config.mixed_precision_config is None:
101
94
  Logger.critical("Provided an initialized target_resource_utilization, that means that mixed precision quantization is "
102
95
  "enabled, but the provided MixedPrecisionQuantizationConfig is None.")
@@ -105,7 +98,7 @@ def core_runner(in_model: Any,
105
98
  target_resource_utilization,
106
99
  representative_data_gen,
107
100
  core_config,
108
- tpc,
101
+ fqc,
109
102
  fw_info,
110
103
  fw_impl):
111
104
  core_config.mixed_precision_config.set_mixed_precision_enable()
@@ -116,7 +109,7 @@ def core_runner(in_model: Any,
116
109
  core_config.quantization_config,
117
110
  fw_info,
118
111
  fw_impl,
119
- tpc,
112
+ fqc,
120
113
  core_config.bit_width_config,
121
114
  tb_w,
122
115
  mixed_precision_enable=core_config.is_mixed_precision_enabled,
@@ -138,7 +131,7 @@ def core_runner(in_model: Any,
138
131
  if core_config.is_mixed_precision_enabled:
139
132
  if core_config.mixed_precision_config.configuration_overwrite is None:
140
133
 
141
- filter_candidates_for_mixed_precision(graph, target_resource_utilization, fw_info, tpc)
134
+ filter_candidates_for_mixed_precision(graph, target_resource_utilization, fw_info, fqc)
142
135
  bit_widths_config = search_bit_width(tg,
143
136
  fw_info,
144
137
  fw_impl,
@@ -177,7 +170,6 @@ def core_runner(in_model: Any,
177
170
 
178
171
  _set_final_resource_utilization(graph=tg,
179
172
  final_bit_widths_config=bit_widths_config,
180
- ru_functions_dict=ru_functions_mapping,
181
173
  fw_info=fw_info,
182
174
  fw_impl=fw_impl)
183
175
 
@@ -215,7 +207,6 @@ def core_runner(in_model: Any,
215
207
 
216
208
  def _set_final_resource_utilization(graph: Graph,
217
209
  final_bit_widths_config: List[int],
218
- ru_functions_dict: Dict[RUTarget, Tuple[MpRuMetric, MpRuAggregation]],
219
210
  fw_info: FrameworkInfo,
220
211
  fw_impl: FrameworkImplementation):
221
212
  """
@@ -225,39 +216,21 @@ def _set_final_resource_utilization(graph: Graph,
225
216
  Args:
226
217
  graph: Graph to compute the resource utilization for.
227
218
  final_bit_widths_config: The final bit-width configuration to quantize the model accordingly.
228
- ru_functions_dict: A mapping between a RUTarget and a pair of resource utilization method and resource utilization aggregation functions.
229
219
  fw_info: A FrameworkInfo object.
230
220
  fw_impl: FrameworkImplementation object with specific framework methods implementation.
231
221
 
232
222
  """
233
-
234
- final_ru_dict = {}
235
- for ru_target, ru_funcs in ru_functions_dict.items():
236
- ru_method, ru_aggr = ru_funcs
237
- if ru_target == RUTarget.BOPS:
238
- final_ru_dict[ru_target] = \
239
- ru_aggr(ru_method(final_bit_widths_config, graph, fw_info, fw_impl, False), False)[0]
240
- else:
241
- non_conf_ru = ru_method([], graph, fw_info, fw_impl)
242
- conf_ru = ru_method(final_bit_widths_config, graph, fw_info, fw_impl)
243
- if len(final_bit_widths_config) > 0 and len(non_conf_ru) > 0:
244
- final_ru_dict[ru_target] = ru_aggr(np.concatenate([conf_ru, non_conf_ru]), False)[0]
245
- elif len(final_bit_widths_config) > 0 and len(non_conf_ru) == 0:
246
- final_ru_dict[ru_target] = ru_aggr(conf_ru, False)[0]
247
- elif len(final_bit_widths_config) == 0 and len(non_conf_ru) > 0:
248
- # final_bit_widths_config == 0 ==> no configurable nodes,
249
- # thus, ru can be computed from non_conf_ru alone
250
- final_ru_dict[ru_target] = ru_aggr(non_conf_ru, False)[0]
251
- else:
252
- # No relevant nodes have been quantized with affect on the given target - since we only consider
253
- # in the model's final size the quantized layers size, this means that the final size for this target
254
- # is zero.
255
- Logger.warning(f"No relevant quantized layers for the ru target {ru_target} were found, the recorded "
256
- f"final ru for this target would be 0.")
257
- final_ru_dict[ru_target] = 0
258
-
259
- final_ru = ResourceUtilization()
260
- final_ru.set_resource_utilization_by_target(final_ru_dict)
261
- print(final_ru)
223
+ w_qcs = {n: n.final_weights_quantization_cfg for n in graph.nodes}
224
+ a_qcs = {n: n.final_activation_quantization_cfg for n in graph.nodes}
225
+ ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
226
+ final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QCustom,
227
+ act_qcs=a_qcs, w_qcs=w_qcs)
228
+
229
+ for ru_target, ru in final_ru.get_resource_utilization_dict().items():
230
+ if ru == 0:
231
+ Logger.warning(f"No relevant quantized layers for the resource utilization target {ru_target} were found, "
232
+ f"the recorded final ru for this target would be 0.")
233
+
234
+ Logger.info(f'Resource utilization (of quantized targets):\n {str(final_ru)}.')
262
235
  graph.user_info.final_resource_utilization = final_ru
263
236
  graph.user_info.mixed_precision_cfg = final_bit_widths_config
@@ -20,7 +20,7 @@ from model_compression_toolkit.core.common.quantization.node_quantization_config
20
20
  NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
21
21
 
22
22
  from model_compression_toolkit.logger import Logger
23
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
23
+ from mct_quantizers import QuantizationMethod
24
24
  from mct_quantizers import QuantizationTarget
25
25
  from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
26
26
  from mct_quantizers.keras.quantizers import BaseKerasInferableQuantizer
@@ -21,7 +21,7 @@ from model_compression_toolkit.constants import THRESHOLD, SIGNED, RANGE_MIN, RA
21
21
  from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
22
22
  NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
23
23
  from model_compression_toolkit.logger import Logger
24
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
24
+ from mct_quantizers import QuantizationMethod
25
25
  from mct_quantizers import QuantizationTarget
26
26
  from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
27
27
  from mct_quantizers import \
@@ -22,7 +22,9 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR
22
22
  LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA
23
23
  from model_compression_toolkit.logger import Logger
24
24
  from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
25
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
25
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
26
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
27
+ AttachTpcToKeras
26
28
  from model_compression_toolkit.verify_packages import FOUND_TF
27
29
  from model_compression_toolkit.core.common.user_info import UserInformation
28
30
  from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig, \
@@ -33,7 +35,6 @@ from model_compression_toolkit.core import CoreConfig
33
35
  from model_compression_toolkit.core.runner import core_runner
34
36
  from model_compression_toolkit.gptq.runner import gptq_runner
35
37
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
36
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
37
38
  from model_compression_toolkit.metadata import create_model_metadata
38
39
 
39
40
 
@@ -48,8 +49,6 @@ if FOUND_TF:
48
49
  from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
49
50
  from model_compression_toolkit import get_target_platform_capabilities
50
51
  from mct_quantizers.keras.metadata import add_metadata
51
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
52
- AttachTpcToKeras
53
52
 
54
53
  # As from TF2.9 optimizers package is changed
55
54
  if version.parse(tf.__version__) < version.parse("2.9"):
@@ -157,7 +156,7 @@ if FOUND_TF:
157
156
  gptq_representative_data_gen: Callable = None,
158
157
  target_resource_utilization: ResourceUtilization = None,
159
158
  core_config: CoreConfig = CoreConfig(),
160
- target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
159
+ target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
161
160
  """
162
161
  Quantize a trained Keras model using post-training quantization. The model is quantized using a
163
162
  symmetric constraint quantization thresholds (power of two).
@@ -244,7 +243,7 @@ if FOUND_TF:
244
243
 
245
244
  # Attach tpc model to framework
246
245
  attach2keras = AttachTpcToKeras()
247
- target_platform_capabilities = attach2keras.attach(
246
+ framework_platform_capabilities = attach2keras.attach(
248
247
  target_platform_capabilities,
249
248
  custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
250
249
 
@@ -253,7 +252,7 @@ if FOUND_TF:
253
252
  core_config=core_config,
254
253
  fw_info=DEFAULT_KERAS_INFO,
255
254
  fw_impl=fw_impl,
256
- tpc=target_platform_capabilities,
255
+ fqc=framework_platform_capabilities,
257
256
  target_resource_utilization=target_resource_utilization,
258
257
  tb_w=tb_w,
259
258
  running_gptq=True)
@@ -281,9 +280,9 @@ if FOUND_TF:
281
280
  DEFAULT_KERAS_INFO)
282
281
 
283
282
  exportable_model, user_info = get_exportable_keras_model(tg_gptq)
284
- if target_platform_capabilities.tp_model.add_metadata:
283
+ if framework_platform_capabilities.tpc.add_metadata:
285
284
  exportable_model = add_metadata(exportable_model,
286
- create_model_metadata(tpc=target_platform_capabilities,
285
+ create_model_metadata(fqc=framework_platform_capabilities,
287
286
  scheduling_info=scheduling_info))
288
287
  return exportable_model, user_info
289
288
 
@@ -18,7 +18,7 @@ import numpy as np
18
18
 
19
19
  from model_compression_toolkit.gptq import RoundingType
20
20
  from model_compression_toolkit.core.common import max_power_of_two
21
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
+ from mct_quantizers import QuantizationMethod
22
22
  from mct_quantizers import QuantizationTarget
23
23
  from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
24
24
  SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
@@ -18,7 +18,7 @@ import numpy as np
18
18
 
19
19
  from model_compression_toolkit.gptq import RoundingType
20
20
  from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
21
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
+ from mct_quantizers import QuantizationMethod
22
22
  from mct_quantizers import QuantizationTarget
23
23
  from model_compression_toolkit.gptq.common.gptq_constants import \
24
24
  SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
@@ -19,7 +19,7 @@ import numpy as np
19
19
  import tensorflow as tf
20
20
 
21
21
  from model_compression_toolkit.gptq import RoundingType
22
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
22
+ from mct_quantizers import QuantizationMethod
23
23
  from mct_quantizers import QuantizationTarget
24
24
  from model_compression_toolkit.gptq.common.gptq_constants import AUXVAR, PTQ_THRESHOLD
25
25
  from model_compression_toolkit.gptq.keras.quantizer import quant_utils as qutils
@@ -31,8 +31,7 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR
31
31
  from model_compression_toolkit.gptq.runner import gptq_runner
32
32
  from model_compression_toolkit.logger import Logger
33
33
  from model_compression_toolkit.metadata import create_model_metadata
34
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
35
- from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
34
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
36
35
  from model_compression_toolkit.verify_packages import FOUND_TORCH
37
36
 
38
37
 
@@ -48,7 +47,7 @@ if FOUND_TORCH:
48
47
  from torch.optim import Adam, Optimizer
49
48
  from model_compression_toolkit import get_target_platform_capabilities
50
49
  from mct_quantizers.pytorch.metadata import add_metadata
51
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
50
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
52
51
  AttachTpcToPytorch
53
52
 
54
53
  DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
@@ -146,11 +145,11 @@ if FOUND_TORCH:
146
145
  core_config: CoreConfig = CoreConfig(),
147
146
  gptq_config: GradientPTQConfig = None,
148
147
  gptq_representative_data_gen: Callable = None,
149
- target_platform_capabilities: TargetPlatformModel = DEFAULT_PYTORCH_TPC):
148
+ target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
150
149
  """
151
150
  Quantize a trained Pytorch module using post-training quantization.
152
151
  By default, the module is quantized using a symmetric constraint quantization thresholds
153
- (power of two) as defined in the default TargetPlatformCapabilities.
152
+ (power of two) as defined in the default FrameworkQuantizationCapabilities.
154
153
  The module is first optimized using several transformations (e.g. BatchNormalization folding to
155
154
  preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
156
155
  being collected for each layer's output (and input, depending on the quantization configuration).
@@ -217,7 +216,7 @@ if FOUND_TORCH:
217
216
 
218
217
  # Attach tpc model to framework
219
218
  attach2pytorch = AttachTpcToPytorch()
220
- target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
219
+ framework_quantization_capabilities = attach2pytorch.attach(target_platform_capabilities,
221
220
  core_config.quantization_config.custom_tpc_opset_to_layer)
222
221
 
223
222
  # ---------------------- #
@@ -228,7 +227,7 @@ if FOUND_TORCH:
228
227
  core_config=core_config,
229
228
  fw_info=DEFAULT_PYTORCH_INFO,
230
229
  fw_impl=fw_impl,
231
- tpc=target_platform_capabilities,
230
+ fqc=framework_quantization_capabilities,
232
231
  target_resource_utilization=target_resource_utilization,
233
232
  tb_w=tb_w,
234
233
  running_gptq=True)
@@ -257,9 +256,9 @@ if FOUND_TORCH:
257
256
  DEFAULT_PYTORCH_INFO)
258
257
 
259
258
  exportable_model, user_info = get_exportable_pytorch_model(graph_gptq)
260
- if target_platform_capabilities.tp_model.add_metadata:
259
+ if framework_quantization_capabilities.tpc.add_metadata:
261
260
  exportable_model = add_metadata(exportable_model,
262
- create_model_metadata(tpc=target_platform_capabilities,
261
+ create_model_metadata(fqc=framework_quantization_capabilities,
263
262
  scheduling_info=scheduling_info))
264
263
  return exportable_model, user_info
265
264
 
@@ -18,7 +18,7 @@ from typing import Dict
18
18
  import numpy as np
19
19
 
20
20
  from model_compression_toolkit.core.common import max_power_of_two
21
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
+ from mct_quantizers import QuantizationMethod
22
22
  from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
23
23
  from model_compression_toolkit.gptq.common.gptq_config import RoundingType
24
24
  from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
@@ -18,7 +18,7 @@ from typing import Dict
18
18
  import numpy as np
19
19
 
20
20
  from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
21
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
+ from mct_quantizers import QuantizationMethod
22
22
  from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
23
23
  from model_compression_toolkit.gptq.common.gptq_config import RoundingType
24
24
  from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
@@ -18,7 +18,7 @@ from typing import Dict
18
18
  import numpy as np
19
19
  from model_compression_toolkit.defaultdict import DefaultDict
20
20
 
21
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
21
+ from mct_quantizers import QuantizationMethod
22
22
  from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
23
23
  from model_compression_toolkit.gptq.common.gptq_config import RoundingType
24
24
  from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
@@ -18,33 +18,34 @@ from typing import Dict, Any
18
18
  from model_compression_toolkit.constants import OPERATORS_SCHEDULING, FUSED_NODES_MAPPING, CUTS, MAX_CUT, OP_ORDER, \
19
19
  OP_RECORD, SHAPE, NODE_OUTPUT_INDEX, NODE_NAME, TOTAL_SIZE, MEM_ELEMENTS
20
20
  from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import SchedulerInfo
21
- from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
21
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
22
+ FrameworkQuantizationCapabilities
22
23
 
23
24
 
24
- def create_model_metadata(tpc: TargetPlatformCapabilities,
25
+ def create_model_metadata(fqc: FrameworkQuantizationCapabilities,
25
26
  scheduling_info: SchedulerInfo = None) -> Dict:
26
27
  """
27
28
  Creates and returns a metadata dictionary for the model, including version information
28
29
  and optional scheduling information.
29
30
 
30
31
  Args:
31
- tpc: A TPC object to get the version.
32
+ fqc: A FQC object to get the version.
32
33
  scheduling_info: An object containing scheduling details and metadata. Default is None.
33
34
 
34
35
  Returns:
35
36
  Dict: A dictionary containing the model's version information and optional scheduling information.
36
37
  """
37
- _metadata = get_versions_dict(tpc)
38
+ _metadata = get_versions_dict(fqc)
38
39
  if scheduling_info:
39
40
  scheduler_metadata = get_scheduler_metadata(scheduler_info=scheduling_info)
40
41
  _metadata['scheduling_info'] = scheduler_metadata
41
42
  return _metadata
42
43
 
43
44
 
44
- def get_versions_dict(tpc) -> Dict:
45
+ def get_versions_dict(fqc) -> Dict:
45
46
  """
46
47
 
47
- Returns: A dictionary with TPC, MCT and TPC-Schema versions.
48
+ Returns: A dictionary with FQC, MCT and FQC-Schema versions.
48
49
 
49
50
  """
50
51
  # imported inside to avoid circular import error
@@ -53,10 +54,10 @@ def get_versions_dict(tpc) -> Dict:
53
54
  @dataclass
54
55
  class TPCVersions:
55
56
  mct_version: str
56
- tpc_minor_version: str = f'{tpc.tp_model.tpc_minor_version}'
57
- tpc_patch_version: str = f'{tpc.tp_model.tpc_patch_version}'
58
- tpc_platform_type: str = f'{tpc.tp_model.tpc_platform_type}'
59
- tpc_schema: str = f'{tpc.tp_model.SCHEMA_VERSION}'
57
+ tpc_minor_version: str = f'{fqc.tpc.tpc_minor_version}'
58
+ tpc_patch_version: str = f'{fqc.tpc.tpc_patch_version}'
59
+ tpc_platform_type: str = f'{fqc.tpc.tpc_platform_type}'
60
+ tpc_schema: str = f'{fqc.tpc.SCHEMA_VERSION}'
60
61
 
61
62
  return asdict(TPCVersions(mct_version))
62
63
 
@@ -17,7 +17,7 @@ from typing import Callable, Tuple
17
17
 
18
18
  from model_compression_toolkit import get_target_platform_capabilities
19
19
  from model_compression_toolkit.constants import TENSORFLOW
20
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
20
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
21
21
  from model_compression_toolkit.verify_packages import FOUND_TF
22
22
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
23
23
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
@@ -26,17 +26,16 @@ from model_compression_toolkit.core.common.pruning.pruning_info import PruningIn
26
26
  from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
27
27
  from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
28
28
  from model_compression_toolkit.logger import Logger
29
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
30
29
  from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
31
30
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
32
31
 
33
32
  if FOUND_TF:
33
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
34
+ AttachTpcToKeras
34
35
  from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
35
36
  from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
36
37
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
37
38
  from tensorflow.keras.models import Model
38
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
39
- AttachTpcToKeras
40
39
 
41
40
  DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
42
41
 
@@ -44,7 +43,7 @@ if FOUND_TF:
44
43
  target_resource_utilization: ResourceUtilization,
45
44
  representative_data_gen: Callable,
46
45
  pruning_config: PruningConfig = PruningConfig(),
47
- target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
46
+ target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
48
47
  """
49
48
  Perform structured pruning on a Keras model to meet a specified target resource utilization.
50
49
  This function prunes the provided model according to the target resource utilization by grouping and pruning
@@ -62,7 +61,7 @@ if FOUND_TF:
62
61
  target_resource_utilization (ResourceUtilization): The target Key Performance Indicators to be achieved through pruning.
63
62
  representative_data_gen (Callable): A function to generate representative data for pruning analysis.
64
63
  pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
65
- target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
64
+ target_platform_capabilities (FrameworkQuantizationCapabilities): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
66
65
 
67
66
  Returns:
68
67
  Tuple[Model, PruningInfo]: A tuple containing the pruned Keras model and associated pruning information.
@@ -16,7 +16,7 @@
16
16
  from typing import Callable, Tuple
17
17
  from model_compression_toolkit import get_target_platform_capabilities
18
18
  from model_compression_toolkit.constants import PYTORCH
19
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
19
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
20
20
  from model_compression_toolkit.verify_packages import FOUND_TORCH
21
21
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
22
22
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
@@ -25,7 +25,6 @@ from model_compression_toolkit.core.common.pruning.pruning_info import PruningIn
25
25
  from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
26
26
  from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
27
27
  from model_compression_toolkit.logger import Logger
28
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
29
28
  from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
30
29
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
31
30
 
@@ -38,7 +37,7 @@ if FOUND_TORCH:
38
37
  PruningPytorchImplementation
39
38
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
40
39
  from torch.nn import Module
41
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
40
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
42
41
  AttachTpcToPytorch
43
42
 
44
43
  # Set the default Target Platform Capabilities (TPC) for PyTorch.
@@ -48,7 +47,7 @@ if FOUND_TORCH:
48
47
  target_resource_utilization: ResourceUtilization,
49
48
  representative_data_gen: Callable,
50
49
  pruning_config: PruningConfig = PruningConfig(),
51
- target_platform_capabilities: TargetPlatformModel = DEFAULT_PYOTRCH_TPC) -> \
50
+ target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
52
51
  Tuple[Module, PruningInfo]:
53
52
  """
54
53
  Perform structured pruning on a Pytorch model to meet a specified target resource utilization.
@@ -121,12 +120,12 @@ if FOUND_TORCH:
121
120
 
122
121
  # Attach TPC to framework
123
122
  attach2pytorch = AttachTpcToPytorch()
124
- target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)
123
+ framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)
125
124
 
126
125
  # Convert the original Pytorch model to an internal graph representation.
127
126
  float_graph = read_model_to_graph(model,
128
127
  representative_data_gen,
129
- target_platform_capabilities,
128
+ framework_platform_capabilities,
130
129
  DEFAULT_PYTORCH_INFO,
131
130
  fw_impl)
132
131
 
@@ -143,7 +142,7 @@ if FOUND_TORCH:
143
142
  target_resource_utilization,
144
143
  representative_data_gen,
145
144
  pruning_config,
146
- target_platform_capabilities)
145
+ framework_platform_capabilities)
147
146
 
148
147
  # Apply the pruning process.
149
148
  pruned_graph = pruner.prune_graph()
@@ -22,17 +22,18 @@ from model_compression_toolkit.core.common.quantization.quantize_graph_weights i
22
22
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
23
23
  from model_compression_toolkit.logger import Logger
24
24
  from model_compression_toolkit.constants import TENSORFLOW
25
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
25
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
26
26
  from model_compression_toolkit.verify_packages import FOUND_TF
27
27
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
28
28
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
29
29
  MixedPrecisionQuantizationConfig
30
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
31
30
  from model_compression_toolkit.core.runner import core_runner
32
31
  from model_compression_toolkit.ptq.runner import ptq_runner
33
32
  from model_compression_toolkit.metadata import create_model_metadata
34
33
 
35
34
  if FOUND_TF:
35
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
36
+ AttachTpcToKeras
36
37
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
37
38
  from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
38
39
  from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
@@ -42,8 +43,6 @@ if FOUND_TF:
42
43
 
43
44
  from model_compression_toolkit import get_target_platform_capabilities
44
45
  from mct_quantizers.keras.metadata import add_metadata
45
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
46
- AttachTpcToKeras
47
46
 
48
47
  DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
49
48
 
@@ -52,7 +51,7 @@ if FOUND_TF:
52
51
  representative_data_gen: Callable,
53
52
  target_resource_utilization: ResourceUtilization = None,
54
53
  core_config: CoreConfig = CoreConfig(),
55
- target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC):
54
+ target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
56
55
  """
57
56
  Quantize a trained Keras model using post-training quantization. The model is quantized using a
58
57
  symmetric constraint quantization thresholds (power of two).
@@ -139,7 +138,7 @@ if FOUND_TF:
139
138
  fw_impl = KerasImplementation()
140
139
 
141
140
  attach2keras = AttachTpcToKeras()
142
- target_platform_capabilities = attach2keras.attach(
141
+ framework_platform_capabilities = attach2keras.attach(
143
142
  target_platform_capabilities,
144
143
  custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
145
144
 
@@ -149,7 +148,7 @@ if FOUND_TF:
149
148
  core_config=core_config,
150
149
  fw_info=fw_info,
151
150
  fw_impl=fw_impl,
152
- tpc=target_platform_capabilities,
151
+ fqc=framework_platform_capabilities,
153
152
  target_resource_utilization=target_resource_utilization,
154
153
  tb_w=tb_w)
155
154
 
@@ -177,9 +176,9 @@ if FOUND_TF:
177
176
  fw_info)
178
177
 
179
178
  exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction)
180
- if target_platform_capabilities.tp_model.add_metadata:
179
+ if framework_platform_capabilities.tpc.add_metadata:
181
180
  exportable_model = add_metadata(exportable_model,
182
- create_model_metadata(tpc=target_platform_capabilities,
181
+ create_model_metadata(fqc=framework_platform_capabilities,
183
182
  scheduling_info=scheduling_info))
184
183
  return exportable_model, user_info
185
184