mct-nightly 2.2.0.20250113.134913__py3-none-any.whl → 2.2.0.20250114.134534__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/RECORD +102 -104
  3. model_compression_toolkit/__init__.py +2 -2
  4. model_compression_toolkit/core/common/framework_info.py +1 -3
  5. model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
  6. model_compression_toolkit/core/common/graph/base_graph.py +20 -21
  7. model_compression_toolkit/core/common/graph/base_node.py +44 -17
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
  9. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +187 -0
  10. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
  11. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +35 -162
  12. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
  13. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +668 -0
  14. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
  15. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +74 -51
  16. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
  17. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  18. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
  19. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
  20. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
  21. model_compression_toolkit/core/common/pruning/pruner.py +5 -3
  22. model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
  23. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
  24. model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
  25. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
  26. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  27. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
  28. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
  29. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
  30. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
  31. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
  32. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
  33. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
  34. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
  35. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
  36. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  37. model_compression_toolkit/core/graph_prep_runner.py +12 -11
  38. model_compression_toolkit/core/keras/default_framework_info.py +1 -1
  39. model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
  40. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
  41. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  42. model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
  43. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
  44. model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
  45. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
  46. model_compression_toolkit/core/runner.py +33 -60
  47. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
  48. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
  49. model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
  50. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  51. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  52. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  53. model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
  54. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  55. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  56. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  57. model_compression_toolkit/metadata.py +11 -10
  58. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
  59. model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
  60. model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
  61. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
  62. model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
  63. model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
  64. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  65. model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
  66. model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
  67. model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
  68. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  69. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
  70. model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
  71. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  72. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
  73. model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
  74. model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
  75. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
  76. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
  77. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
  78. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
  79. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
  80. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
  81. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
  82. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
  83. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
  84. model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
  85. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
  86. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
  87. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
  88. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
  89. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
  90. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
  91. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
  92. model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
  93. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
  94. model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
  95. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
  96. model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
  97. model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
  98. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
  99. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
  100. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
  101. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +0 -528
  102. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
  103. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/LICENSE.md +0 -0
  104. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/WHEEL +0 -0
  105. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/top_level.txt +0 -0
  106. /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
@@ -21,38 +21,38 @@ from typing import List, Any, Dict, Tuple
21
21
  from model_compression_toolkit.logger import Logger
22
22
  from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \
23
23
  get_config_options_by_operators_set, get_default_op_quantization_config, get_opset_by_name
24
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.operations_to_layers import \
25
- OperationsToLayers, OperationsSetToLayers
26
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
27
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import LayerFilterParams
24
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.operations_to_layers import OperationsToLayers, \
25
+ OperationsSetToLayers
26
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities_component import \
27
+ FrameworkQuantizationCapabilitiesComponent
28
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.layer_filter_params import LayerFilterParams
28
29
  from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
29
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, OperatorsSetBase, \
30
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities, OperatorsSetBase, \
30
31
  OpQuantizationConfig, QuantizationConfigOptions
31
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
32
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc
32
33
 
33
-
34
- class TargetPlatformCapabilities(ImmutableClass):
34
+ class FrameworkQuantizationCapabilities(ImmutableClass):
35
35
  """
36
36
  Attach framework information to a modeled hardware.
37
37
  """
38
38
  def __init__(self,
39
- tp_model: TargetPlatformModel,
39
+ tpc: TargetPlatformCapabilities,
40
40
  name: str = "base"):
41
41
  """
42
42
 
43
43
  Args:
44
- tp_model (TargetPlatformModel): Modeled hardware to attach framework information to.
45
- name (str): Name of the TargetPlatformCapabilities.
44
+ tpc (TargetPlatformCapabilities): Modeled hardware to attach framework information to.
45
+ name (str): Name of the FrameworkQuantizationCapabilities.
46
46
  """
47
47
 
48
48
  super().__init__()
49
49
  self.name = name
50
- assert isinstance(tp_model, TargetPlatformModel), f'Target platform model that was passed to TargetPlatformCapabilities must be of type TargetPlatformModel, but has type of {type(tp_model)}'
51
- self.tp_model = tp_model
50
+ assert isinstance(tpc, TargetPlatformCapabilities), f'Target platform model that was passed to FrameworkQuantizationCapabilities must be of type TargetPlatformCapabilities, but has type of {type(tpc)}'
51
+ self.tpc = tpc
52
52
  self.op_sets_to_layers = OperationsToLayers() # Init an empty OperationsToLayers
53
53
  self.layer2qco, self.filterlayer2qco = {}, {} # Init empty mappings from layers/LayerFilterParams to QC options
54
54
  # Track the unused opsets for warning purposes.
55
- self.__tp_model_opsets_not_used = [s.name for s in tp_model.operator_set]
55
+ self.__tpc_opsets_not_used = [s.name for s in tpc.operator_set]
56
56
  self.remove_fusing_names_from_not_used_list()
57
57
 
58
58
  def get_layers_by_opset_name(self, opset_name: str) -> List[Any]:
@@ -66,9 +66,9 @@ class TargetPlatformCapabilities(ImmutableClass):
66
66
  Returns:
67
67
  List of layers/LayerFilterParams that are attached to the opset name.
68
68
  """
69
- opset = get_opset_by_name(self.tp_model, opset_name)
69
+ opset = get_opset_by_name(self.tpc, opset_name)
70
70
  if opset is None:
71
- Logger.warning(f'{opset_name} was not found in TargetPlatformCapabilities.')
71
+ Logger.warning(f'{opset_name} was not found in FrameworkQuantizationCapabilities.')
72
72
  return None
73
73
  return self.get_layers_by_opset(opset)
74
74
 
@@ -100,9 +100,9 @@ class TargetPlatformCapabilities(ImmutableClass):
100
100
 
101
101
  """
102
102
  res = []
103
- if self.tp_model.fusing_patterns is None:
103
+ if self.tpc.fusing_patterns is None:
104
104
  return res
105
- for p in self.tp_model.fusing_patterns:
105
+ for p in self.tpc.fusing_patterns:
106
106
  ops = [self.get_layers_by_opset(x) for x in p.operator_groups]
107
107
  res.extend(itertools.product(*ops))
108
108
  return [list(x) for x in res]
@@ -111,47 +111,47 @@ class TargetPlatformCapabilities(ImmutableClass):
111
111
  def get_info(self) -> Dict[str, Any]:
112
112
  """
113
113
 
114
- Returns: Summarization of information in the TargetPlatformCapabilities.
114
+ Returns: Summarization of information in the FrameworkQuantizationCapabilities.
115
115
 
116
116
  """
117
117
  return {"Target Platform Capabilities": self.name,
118
- "Minor version": self.tp_model.tpc_minor_version,
119
- "Patch version": self.tp_model.tpc_patch_version,
120
- "Platform type": self.tp_model.tpc_platform_type,
121
- "Target Platform Model": self.tp_model.get_info(),
118
+ "Minor version": self.tpc.tpc_minor_version,
119
+ "Patch version": self.tpc.tpc_patch_version,
120
+ "Platform type": self.tpc.tpc_platform_type,
121
+ "Target Platform Model": self.tpc.get_info(),
122
122
  "Operations to layers": {op2layer.name:[l.__name__ for l in op2layer.layers] for op2layer in self.op_sets_to_layers.op_sets_to_layers}}
123
123
 
124
124
  def show(self):
125
125
  """
126
126
 
127
- Display the TargetPlatformCapabilities.
127
+ Display the FrameworkQuantizationCapabilities.
128
128
 
129
129
  """
130
130
  pprint.pprint(self.get_info(), sort_dicts=False, width=110)
131
131
 
132
- def append_component(self, tpc_component: TargetPlatformCapabilitiesComponent):
132
+ def append_component(self, tpc_component: FrameworkQuantizationCapabilitiesComponent):
133
133
  """
134
- Append a Component (like OperationsSetToLayers) to the TargetPlatformCapabilities.
134
+ Append a Component (like OperationsSetToLayers) to the FrameworkQuantizationCapabilities.
135
135
 
136
136
  Args:
137
- tpc_component: Component to append to TargetPlatformCapabilities.
137
+ tpc_component: Component to append to FrameworkQuantizationCapabilities.
138
138
 
139
139
  """
140
140
  if isinstance(tpc_component, OperationsSetToLayers):
141
141
  self.op_sets_to_layers += tpc_component
142
142
  else:
143
- Logger.critical(f"Attempt to append an unrecognized 'TargetPlatformCapabilitiesComponent' of type: '{type(tpc_component)}'. Ensure the component is compatible.") # pragma: no cover
143
+ Logger.critical(f"Attempt to append an unrecognized 'FrameworkQuantizationCapabilitiesComponent' of type: '{type(tpc_component)}'. Ensure the component is compatible.") # pragma: no cover
144
144
 
145
145
  def __enter__(self):
146
146
  """
147
- Init a TargetPlatformCapabilities object.
147
+ Init a FrameworkQuantizationCapabilities object.
148
148
  """
149
149
  _current_tpc.set(self)
150
150
  return self
151
151
 
152
152
  def __exit__(self, exc_type, exc_value, tb):
153
153
  """
154
- Finalize a TargetPlatformCapabilities object.
154
+ Finalize a FrameworkQuantizationCapabilities object.
155
155
  """
156
156
  if exc_value is not None:
157
157
  print(exc_value, exc_value.args)
@@ -164,11 +164,11 @@ class TargetPlatformCapabilities(ImmutableClass):
164
164
  def get_default_op_qc(self) -> OpQuantizationConfig:
165
165
  """
166
166
 
167
- Returns: The default OpQuantizationConfig of the TargetPlatformModel that is attached
168
- to the TargetPlatformCapabilities.
167
+ Returns: The default OpQuantizationConfig of the TargetPlatformCapabilities that is attached
168
+ to the FrameworkQuantizationCapabilities.
169
169
 
170
170
  """
171
- return get_default_op_quantization_config(self.tp_model)
171
+ return get_default_op_quantization_config(self.tpc)
172
172
 
173
173
 
174
174
  def _get_config_options_mapping(self) -> Tuple[Dict[Any, QuantizationConfigOptions],
@@ -184,9 +184,9 @@ class TargetPlatformCapabilities(ImmutableClass):
184
184
  filterlayer2qco = {}
185
185
  for op2layers in self.op_sets_to_layers.op_sets_to_layers:
186
186
  for l in op2layers.layers:
187
- qco = get_config_options_by_operators_set(self.tp_model, op2layers.name)
187
+ qco = get_config_options_by_operators_set(self.tpc, op2layers.name)
188
188
  if qco is None:
189
- qco = self.tp_model.default_qco
189
+ qco = self.tpc.default_qco
190
190
 
191
191
  # here, we need to take care of mapping a general attribute name into a framework and
192
192
  # layer type specific attribute name.
@@ -208,8 +208,8 @@ class TargetPlatformCapabilities(ImmutableClass):
208
208
  Remove OperatorSets names from the list of the unused sets (so a warning
209
209
  will not be displayed).
210
210
  """
211
- if self.tp_model.fusing_patterns is not None:
212
- for f in self.tp_model.fusing_patterns:
211
+ if self.tpc.fusing_patterns is not None:
212
+ for f in self.tpc.fusing_patterns:
213
213
  for s in f.operator_groups:
214
214
  self.remove_opset_from_not_used_list(s.name)
215
215
 
@@ -222,8 +222,8 @@ class TargetPlatformCapabilities(ImmutableClass):
222
222
  opset_to_remove: OperatorsSet name to remove.
223
223
 
224
224
  """
225
- if opset_to_remove in self.__tp_model_opsets_not_used:
226
- self.__tp_model_opsets_not_used.remove(opset_to_remove)
225
+ if opset_to_remove in self.__tpc_opsets_not_used:
226
+ self.__tpc_opsets_not_used.remove(opset_to_remove)
227
227
 
228
228
  @property
229
229
  def is_simd_padding(self) -> bool:
@@ -232,4 +232,4 @@ class TargetPlatformCapabilities(ImmutableClass):
232
232
  Returns: Check if the TP model defines that padding due to SIMD constraints occurs.
233
233
 
234
234
  """
235
- return self.tp_model.is_simd_padding
235
+ return self.tpc.is_simd_padding
@@ -13,10 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
16
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc
17
17
 
18
18
 
19
- class TargetPlatformCapabilitiesComponent:
19
+ class FrameworkQuantizationCapabilitiesComponent:
20
20
  def __init__(self, name: str):
21
21
  self.name = name
22
22
  _current_tpc.get().append_component(self)
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from typing import Any
17
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
18
17
 
19
18
 
20
19
  class LayerFilterParams:
@@ -18,13 +18,13 @@ from typing import List, Any, Dict
18
18
  from model_compression_toolkit.logger import Logger
19
19
  from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \
20
20
  get_config_options_by_operators_set, is_opset_in_model
21
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
22
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
23
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorsSetBase, OperatorSetConcat
21
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc
22
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities_component import FrameworkQuantizationCapabilitiesComponent
23
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorsSetBase, OperatorSetGroup
24
24
  from model_compression_toolkit import DefaultDict
25
25
 
26
26
 
27
- class OperationsSetToLayers(TargetPlatformCapabilitiesComponent):
27
+ class OperationsSetToLayers(FrameworkQuantizationCapabilitiesComponent):
28
28
  """
29
29
  Associate an OperatorsSet to a list of framework's layers.
30
30
  """
@@ -57,7 +57,7 @@ class OperationsSetToLayers(TargetPlatformCapabilitiesComponent):
57
57
 
58
58
  class OperationsToLayers:
59
59
  """
60
- Gather multiple OperationsSetToLayers to represent mapping of framework's layers to TargetPlatformModel OperatorsSet.
60
+ Gather multiple OperationsSetToLayers to represent mapping of framework's layers to TargetPlatformCapabilities OperatorsSet.
61
61
  """
62
62
  def __init__(self,
63
63
  op_sets_to_layers: List[OperationsSetToLayers]=None):
@@ -88,7 +88,7 @@ class OperationsToLayers:
88
88
  for o in self.op_sets_to_layers:
89
89
  if op.name == o.name:
90
90
  return o.layers
91
- if isinstance(op, OperatorSetConcat): # If its a concat - return all layers from all OperatorsSets that in the OperatorSetConcat
91
+ if isinstance(op, OperatorSetGroup): # If it's a group - return all layers from all OperatorsSets that are in the OperatorSetGroup
92
92
  layers = []
93
93
  for o in op.operators_set:
94
94
  layers.extend(self.get_layers_by_op(o))
@@ -142,9 +142,9 @@ class OperationsToLayers:
142
142
  assert ops2layers.name not in existing_opset_names, f'OperationsSetToLayers names should be unique, but {ops2layers.name} appears to violate it.'
143
143
  existing_opset_names.append(ops2layers.name)
144
144
 
145
- # Assert that a layer does not appear in more than a single OperatorsSet in the TargetPlatformModel.
145
+ # Assert that a layer does not appear in more than a single OperatorsSet in the TargetPlatformCapabilities.
146
146
  for layer in ops2layers.layers:
147
- qco_by_opset_name = get_config_options_by_operators_set(_current_tpc.get().tp_model, ops2layers.name)
147
+ qco_by_opset_name = get_config_options_by_operators_set(_current_tpc.get().tpc, ops2layers.name)
148
148
  if layer in existing_layers:
149
149
  Logger.critical(f'Found layer {layer.__name__} in more than one '
150
150
  f'OperatorsSet') # pragma: no cover
@@ -16,34 +16,34 @@ from pathlib import Path
16
16
  from typing import Union
17
17
 
18
18
  from model_compression_toolkit.logger import Logger
19
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
19
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
20
20
  import json
21
21
 
22
22
 
23
- def load_target_platform_model(tp_model_or_path: Union[TargetPlatformModel, str]) -> TargetPlatformModel:
23
+ def load_target_platform_model(tpc_obj_or_path: Union[TargetPlatformCapabilities, str]) -> TargetPlatformCapabilities:
24
24
  """
25
- Parses the tp_model input, which can be either a TargetPlatformModel object
25
+ Parses the tpc input, which can be either a TargetPlatformCapabilities object
26
26
  or a string path to a JSON file.
27
27
 
28
28
  Parameters:
29
- tp_model_or_path (Union[TargetPlatformModel, str]): Input target platform model or path to .JSON file.
29
+ tpc_obj_or_path (Union[TargetPlatformCapabilities, str]): Input target platform model or path to .JSON file.
30
30
 
31
31
  Returns:
32
- TargetPlatformModel: The parsed TargetPlatformModel.
32
+ TargetPlatformCapabilities: The parsed TargetPlatformCapabilities.
33
33
 
34
34
  Raises:
35
35
  FileNotFoundError: If the JSON file does not exist.
36
- ValueError: If the JSON content is invalid or cannot initialize the TargetPlatformModel.
37
- TypeError: If the input is neither a TargetPlatformModel nor a valid JSON file path.
36
+ ValueError: If the JSON content is invalid or cannot initialize the TargetPlatformCapabilities.
37
+ TypeError: If the input is neither a TargetPlatformCapabilities nor a valid JSON file path.
38
38
  """
39
- if isinstance(tp_model_or_path, TargetPlatformModel):
40
- return tp_model_or_path
39
+ if isinstance(tpc_obj_or_path, TargetPlatformCapabilities):
40
+ return tpc_obj_or_path
41
41
 
42
- if isinstance(tp_model_or_path, str):
43
- path = Path(tp_model_or_path)
42
+ if isinstance(tpc_obj_or_path, str):
43
+ path = Path(tpc_obj_or_path)
44
44
 
45
45
  if not path.exists() or not path.is_file():
46
- raise FileNotFoundError(f"The path '{tp_model_or_path}' is not a valid file.")
46
+ raise FileNotFoundError(f"The path '{tpc_obj_or_path}' is not a valid file.")
47
47
  # Verify that the file has a .json extension
48
48
  if path.suffix.lower() != '.json':
49
49
  raise ValueError(f"The file '{path}' does not have a '.json' extension.")
@@ -51,35 +51,35 @@ def load_target_platform_model(tp_model_or_path: Union[TargetPlatformModel, str]
51
51
  with path.open('r', encoding='utf-8') as file:
52
52
  data = file.read()
53
53
  except OSError as e:
54
- raise ValueError(f"Error reading the file '{tp_model_or_path}': {e.strerror}.") from e
54
+ raise ValueError(f"Error reading the file '{tpc_obj_or_path}': {e.strerror}.") from e
55
55
 
56
56
  try:
57
- return TargetPlatformModel.parse_raw(data)
57
+ return TargetPlatformCapabilities.parse_raw(data)
58
58
  except ValueError as e:
59
- raise ValueError(f"Invalid JSON for loading TargetPlatformModel in '{tp_model_or_path}': {e}.") from e
59
+ raise ValueError(f"Invalid JSON for loading TargetPlatformCapabilities in '{tpc_obj_or_path}': {e}.") from e
60
60
  except Exception as e:
61
- raise ValueError(f"Unexpected error while initializing TargetPlatformModel: {e}.") from e
61
+ raise ValueError(f"Unexpected error while initializing TargetPlatformCapabilities: {e}.") from e
62
62
 
63
63
  raise TypeError(
64
- f"tp_model_or_path must be either a TargetPlatformModel instance or a string path to a JSON file, "
65
- f"but received type '{type(tp_model_or_path).__name__}'."
64
+ f"tpc_obj_or_path must be either a TargetPlatformCapabilities instance or a string path to a JSON file, "
65
+ f"but received type '{type(tpc_obj_or_path).__name__}'."
66
66
  )
67
67
 
68
68
 
69
- def export_target_platform_model(model: TargetPlatformModel, export_path: Union[str, Path]) -> None:
69
+ def export_target_platform_model(model: TargetPlatformCapabilities, export_path: Union[str, Path]) -> None:
70
70
  """
71
- Exports a TargetPlatformModel instance to a JSON file.
71
+ Exports a TargetPlatformCapabilities instance to a JSON file.
72
72
 
73
73
  Parameters:
74
- model (TargetPlatformModel): The TargetPlatformModel instance to export.
74
+ model (TargetPlatformCapabilities): The TargetPlatformCapabilities instance to export.
75
75
  export_path (Union[str, Path]): The file path to export the model to.
76
76
 
77
77
  Raises:
78
- ValueError: If the model is not an instance of TargetPlatformModel.
78
+ ValueError: If the model is not an instance of TargetPlatformCapabilities.
79
79
  OSError: If there is an issue writing to the file.
80
80
  """
81
- if not isinstance(model, TargetPlatformModel):
82
- raise ValueError("The provided model is not a valid TargetPlatformModel instance.")
81
+ if not isinstance(model, TargetPlatformCapabilities):
82
+ raise ValueError("The provided model is not a valid TargetPlatformCapabilities instance.")
83
83
 
84
84
  path = Path(export_path)
85
85
  try:
@@ -15,60 +15,60 @@
15
15
  from model_compression_toolkit.constants import TENSORFLOW, PYTORCH
16
16
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \
17
17
  TFLITE_TP_MODEL, QNNPACK_TP_MODEL
18
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
18
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
19
19
 
20
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_tp_model_imx500_v1
21
- from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import get_tp_model as get_tp_model_tflite_v1
22
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model as get_tp_model_qnnpack_v1
20
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc import get_tpc as get_tpc_imx500_v1
21
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc import get_tpc as get_tpc_tflite_v1
22
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc import get_tpc as get_tpc_qnnpack_v1
23
23
 
24
24
 
25
25
  # TODO: These methods need to be replaced once modifying the TPC API.
26
26
 
27
27
  def get_target_platform_capabilities(fw_name: str,
28
28
  target_platform_name: str,
29
- target_platform_version: str = None) -> TargetPlatformModel:
29
+ target_platform_version: str = None) -> TargetPlatformCapabilities:
30
30
  """
31
- This is a degenerated function that only returns the MCT default TargetPlatformModel object, to comply with the
31
+ This is a degenerated function that only returns the MCT default TargetPlatformCapabilities object, to comply with the
32
32
  existing TPC API.
33
33
 
34
34
  Args:
35
- fw_name: Framework name of the TargetPlatformCapabilities.
35
+ fw_name: Framework name of the FrameworkQuantizationCapabilities.
36
36
  target_platform_name: Target platform model name the model will use for inference.
37
37
  target_platform_version: Target platform capabilities version.
38
38
 
39
39
  Returns:
40
- A default TargetPlatformModel object.
40
+ A default TargetPlatformCapabilities object.
41
41
  """
42
42
 
43
43
  assert fw_name in [TENSORFLOW, PYTORCH], f"Unsupported framework {fw_name}."
44
44
 
45
45
  if target_platform_name == DEFAULT_TP_MODEL:
46
- return get_tp_model_imx500_v1()
46
+ return get_tpc_imx500_v1()
47
47
 
48
48
  assert target_platform_version == 'v1' or target_platform_version is None, \
49
49
  "The usage of get_target_platform_capabilities API is supported only with the default TPC ('v1')."
50
50
 
51
51
  if target_platform_name == IMX500_TP_MODEL:
52
- return get_tp_model_imx500_v1()
52
+ return get_tpc_imx500_v1()
53
53
  elif target_platform_name == TFLITE_TP_MODEL:
54
- return get_tp_model_tflite_v1()
54
+ return get_tpc_tflite_v1()
55
55
  elif target_platform_name == QNNPACK_TP_MODEL:
56
- return get_tp_model_qnnpack_v1()
56
+ return get_tpc_qnnpack_v1()
57
57
 
58
58
  raise ValueError(f"Unsupported target platform name {target_platform_name}.")
59
59
 
60
60
 
61
- def get_tpc_model(name: str, tp_model: TargetPlatformModel):
61
+ def get_tpc_model(name: str, tpc: TargetPlatformCapabilities):
62
62
  """
63
- This is a utility method that just returns the TargetPlatformModel that it receives, to support existing TPC API.
63
+ This is a utility method that just returns the TargetPlatformCapabilities that it receives, to support existing TPC API.
64
64
 
65
65
  Args:
66
- name: the name of the TargetPlatformModel (not used in this function).
67
- tp_model: a TargetPlatformModel to return.
66
+ name: the name of the TargetPlatformCapabilities (not used in this function).
67
+ tpc: a TargetPlatformCapabilities to return.
68
68
 
69
69
  Returns:
70
- The given TargetPlatformModel object.
70
+ The given TargetPlatformCapabilities object.
71
71
 
72
72
  """
73
73
 
74
- return tp_model
74
+ return tpc
@@ -13,13 +13,13 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from model_compression_toolkit.verify_packages import FOUND_TORCH, FOUND_TF
16
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model, generate_tp_model, \
16
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc import get_tpc, generate_tpc, \
17
17
  get_op_quantization_configs
18
18
  if FOUND_TF:
19
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_keras_tpc_latest
19
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc import get_tpc as get_keras_tpc_latest
20
20
  from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
21
21
  get_tpc_model as generate_keras_tpc
22
22
  if FOUND_TORCH:
23
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_pytorch_tpc_latest
23
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc import get_tpc as get_pytorch_tpc_latest
24
24
  from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
25
25
  get_tpc_model as generate_pytorch_tpc
@@ -16,36 +16,35 @@ from typing import List, Tuple
16
16
 
17
17
  import model_compression_toolkit as mct
18
18
  import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
19
+ from mct_quantizers import QuantizationMethod
19
20
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
20
21
  from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS, \
21
22
  IMX500_TP_MODEL
22
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, Signedness, \
23
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities, Signedness, \
23
24
  AttributeQuantizationConfig, OpQuantizationConfig
24
25
 
25
- tp = mct.target_platform
26
26
 
27
-
28
- def get_tp_model() -> TargetPlatformModel:
27
+ def get_tpc() -> TargetPlatformCapabilities:
29
28
  """
30
29
  A method that generates a default target platform model, with base 8-bit quantization configuration and 8, 4, 2
31
30
  bits configuration list for mixed-precision quantization.
32
31
  NOTE: in order to generate a target platform model with different configurations but with the same Operators Sets
33
32
  (for tests, experiments, etc.), use this method implementation as a test-case, i.e., override the
34
- 'get_op_quantization_configs' method and use its output to call 'generate_tp_model' with your configurations.
33
+ 'get_op_quantization_configs' method and use its output to call 'generate_tpc' with your configurations.
35
34
 
36
- Returns: A TargetPlatformModel object.
35
+ Returns: A TargetPlatformCapabilities object.
37
36
 
38
37
  """
39
38
  base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs()
40
- return generate_tp_model(default_config=default_config,
41
- base_config=base_config,
42
- mixed_precision_cfg_list=mixed_precision_cfg_list,
43
- name='imx500_tp_model')
39
+ return generate_tpc(default_config=default_config,
40
+ base_config=base_config,
41
+ mixed_precision_cfg_list=mixed_precision_cfg_list,
42
+ name='imx500_tpc')
44
43
 
45
44
 
46
45
  def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig], OpQuantizationConfig]:
47
46
  """
48
- Creates a default configuration object for 8-bit quantization, to be used to set a default TargetPlatformModel.
47
+ Creates a default configuration object for 8-bit quantization, to be used to set a default TargetPlatformCapabilities.
49
48
  In addition, creates a default configuration objects list (with 8, 4 and 2 bit quantization) to be used as
50
49
  default configuration for mixed-precision quantization.
51
50
 
@@ -60,7 +59,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
60
59
 
61
60
  # define a default quantization config for all non-specified weights attributes.
62
61
  default_weight_attr_config = AttributeQuantizationConfig(
63
- weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
62
+ weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
64
63
  weights_n_bits=8,
65
64
  weights_per_channel_threshold=False,
66
65
  enable_weights_quantization=False,
@@ -69,7 +68,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
69
68
 
70
69
  # define a quantization config to quantize the kernel (for layers where there is a kernel attribute).
71
70
  kernel_base_config = AttributeQuantizationConfig(
72
- weights_quantization_method=tp.QuantizationMethod.SYMMETRIC,
71
+ weights_quantization_method=QuantizationMethod.SYMMETRIC,
73
72
  weights_n_bits=8,
74
73
  weights_per_channel_threshold=True,
75
74
  enable_weights_quantization=True,
@@ -77,7 +76,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
77
76
 
78
77
  # define a quantization config to quantize the bias (for layers where there is a bias attribute).
79
78
  bias_config = AttributeQuantizationConfig(
80
- weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
79
+ weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
81
80
  weights_n_bits=FLOAT_BITWIDTH,
82
81
  weights_per_channel_threshold=False,
83
82
  enable_weights_quantization=False,
@@ -92,7 +91,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
92
91
  eight_bits_default = schema.OpQuantizationConfig(
93
92
  default_weight_attr_config=default_weight_attr_config,
94
93
  attr_weights_configs_mapping={},
95
- activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
94
+ activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
96
95
  activation_n_bits=8,
97
96
  supported_input_activation_n_bits=8,
98
97
  enable_activation_quantization=True,
@@ -106,7 +105,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
106
105
  linear_eight_bits = schema.OpQuantizationConfig(
107
106
  default_weight_attr_config=default_weight_attr_config,
108
107
  attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config},
109
- activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
108
+ activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
110
109
  activation_n_bits=8,
111
110
  supported_input_activation_n_bits=8,
112
111
  enable_activation_quantization=True,
@@ -131,22 +130,22 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
131
130
  return linear_eight_bits, mixed_precision_cfg_list, eight_bits_default
132
131
 
133
132
 
134
- def generate_tp_model(default_config: OpQuantizationConfig,
135
- base_config: OpQuantizationConfig,
136
- mixed_precision_cfg_list: List[OpQuantizationConfig],
137
- name: str) -> TargetPlatformModel:
133
+ def generate_tpc(default_config: OpQuantizationConfig,
134
+ base_config: OpQuantizationConfig,
135
+ mixed_precision_cfg_list: List[OpQuantizationConfig],
136
+ name: str) -> TargetPlatformCapabilities:
138
137
  """
139
- Generates TargetPlatformModel with default defined Operators Sets, based on the given base configuration and
138
+ Generates TargetPlatformCapabilities with default defined Operators Sets, based on the given base configuration and
140
139
  mixed-precision configurations options list.
141
140
 
142
141
  Args
143
142
  default_config: A default OpQuantizationConfig to set as the TP model default configuration.
144
- base_config: An OpQuantizationConfig to set as the TargetPlatformModel base configuration for mixed-precision purposes only.
143
+ base_config: An OpQuantizationConfig to set as the TargetPlatformCapabilities base configuration for mixed-precision purposes only.
145
144
  mixed_precision_cfg_list: A list of OpQuantizationConfig to be used as the TP model mixed-precision
146
145
  quantization configuration options.
147
- name: The name of the TargetPlatformModel.
146
+ name: The name of the TargetPlatformCapabilities.
148
147
 
149
- Returns: A TargetPlatformModel object.
148
+ Returns: A TargetPlatformCapabilities object.
150
149
 
151
150
  """
152
151
  # Create a QuantizationConfigOptions, which defines a set
@@ -218,14 +217,14 @@ def generate_tp_model(default_config: OpQuantizationConfig,
218
217
 
219
218
  operator_set.extend([conv, conv_transpose, depthwise_conv, fc, relu, relu6, leaky_relu, add, sub, mul, div, prelu, swish,
220
219
  hard_swish, sigmoid, tanh, hard_tanh])
221
- any_relu = schema.OperatorSetConcat(operators_set=[relu, relu6, leaky_relu, hard_tanh])
220
+ any_relu = schema.OperatorSetGroup(operators_set=[relu, relu6, leaky_relu, hard_tanh])
222
221
  # Combine multiple operators into a single operator to avoid quantization between
223
222
  # them. To do this we define fusing patterns using the OperatorsSets that were created.
224
- # To group multiple sets with regard to fusing, an OperatorSetConcat can be created
225
- activations_after_conv_to_fuse = schema.OperatorSetConcat(operators_set=[relu, relu6, leaky_relu, hard_tanh, swish, hard_swish, prelu, sigmoid, tanh])
226
- conv_types = schema.OperatorSetConcat(operators_set=[conv, conv_transpose, depthwise_conv])
227
- activations_after_fc_to_fuse = schema.OperatorSetConcat(operators_set=[relu, relu6, leaky_relu, hard_tanh, swish, hard_swish, sigmoid])
228
- any_binary = schema.OperatorSetConcat(operators_set=[add, sub, mul, div])
223
+ # To group multiple sets with regard to fusing, an OperatorSetGroup can be created
224
+ activations_after_conv_to_fuse = schema.OperatorSetGroup(operators_set=[relu, relu6, leaky_relu, hard_tanh, swish, hard_swish, prelu, sigmoid, tanh])
225
+ conv_types = schema.OperatorSetGroup(operators_set=[conv, conv_transpose, depthwise_conv])
226
+ activations_after_fc_to_fuse = schema.OperatorSetGroup(operators_set=[relu, relu6, leaky_relu, hard_tanh, swish, hard_swish, sigmoid])
227
+ any_binary = schema.OperatorSetGroup(operators_set=[add, sub, mul, div])
229
228
 
230
229
  # ------------------- #
231
230
  # Fusions
@@ -234,10 +233,10 @@ def generate_tp_model(default_config: OpQuantizationConfig,
234
233
  fusing_patterns.append(schema.Fusing(operator_groups=(fc, activations_after_fc_to_fuse)))
235
234
  fusing_patterns.append(schema.Fusing(operator_groups=(any_binary, any_relu)))
236
235
 
237
- # Create a TargetPlatformModel and set its default quantization config.
236
+ # Create a TargetPlatformCapabilities and set its default quantization config.
238
237
  # This default configuration will be used for all operations
239
238
  # unless specified otherwise (see OperatorsSet, for example):
240
- generated_tpc = schema.TargetPlatformModel(
239
+ generated_tpc = schema.TargetPlatformCapabilities(
241
240
  default_qco=default_configuration_options,
242
241
  tpc_minor_version=1,
243
242
  tpc_patch_version=0,
@@ -13,14 +13,14 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from model_compression_toolkit.verify_packages import FOUND_TORCH, FOUND_TF
16
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model, generate_tp_model, get_op_quantization_configs
16
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc import get_tpc, generate_tpc, get_op_quantization_configs
17
17
  if FOUND_TF:
18
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model as \
18
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc import get_tpc as \
19
19
  get_keras_tpc_latest
20
20
  from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
21
21
  get_tpc_model as generate_keras_tpc, get_tpc_model as generate_keras_tpc
22
22
  if FOUND_TORCH:
23
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model as \
23
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc import get_tpc as \
24
24
  get_pytorch_tpc_latest
25
25
  from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
26
26
  get_tpc_model as generate_pytorch_tpc, get_tpc_model as generate_pytorch_tpc