mct-nightly 2.1.0.20240806.441__py3-none-any.whl → 2.1.0.20240808.431__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {mct_nightly-2.1.0.20240806.441.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/METADATA +2 -2
  2. {mct_nightly-2.1.0.20240806.441.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/RECORD +48 -47
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/constants.py +14 -1
  5. model_compression_toolkit/core/common/fusion/graph_fuser.py +135 -0
  6. model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +4 -0
  7. model_compression_toolkit/core/common/quantization/debug_config.py +4 -1
  8. model_compression_toolkit/core/common/quantization/node_quantization_config.py +1 -1
  9. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +3 -4
  10. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +29 -1
  11. model_compression_toolkit/core/runner.py +21 -1
  12. model_compression_toolkit/gptq/keras/quantization_facade.py +13 -11
  13. model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -11
  14. model_compression_toolkit/metadata.py +61 -2
  15. model_compression_toolkit/ptq/keras/quantization_facade.py +12 -10
  16. model_compression_toolkit/ptq/pytorch/quantization_facade.py +12 -12
  17. model_compression_toolkit/qat/keras/quantization_facade.py +8 -8
  18. model_compression_toolkit/qat/pytorch/quantization_facade.py +8 -8
  19. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +2 -1
  20. model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +18 -4
  21. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +10 -13
  22. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +68 -52
  23. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +5 -3
  24. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +5 -3
  25. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +5 -3
  26. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +5 -3
  27. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +5 -3
  28. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +5 -3
  29. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +5 -3
  30. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +6 -4
  31. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +3 -3
  32. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +35 -29
  33. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +5 -4
  34. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +35 -28
  35. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +5 -4
  36. model_compression_toolkit/xquant/common/constants.py +3 -0
  37. model_compression_toolkit/xquant/common/core_report_generator.py +9 -1
  38. model_compression_toolkit/xquant/common/framework_report_utils.py +5 -14
  39. model_compression_toolkit/xquant/common/tensorboard_utils.py +30 -5
  40. model_compression_toolkit/xquant/keras/facade_xquant_report.py +2 -0
  41. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -1
  42. model_compression_toolkit/xquant/keras/tensorboard_utils.py +101 -4
  43. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +2 -0
  44. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -2
  45. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +109 -3
  46. {mct_nightly-2.1.0.20240806.441.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/LICENSE.md +0 -0
  47. {mct_nightly-2.1.0.20240806.441.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/WHEEL +0 -0
  48. {mct_nightly-2.1.0.20240806.441.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/top_level.txt +0 -0
@@ -31,7 +31,7 @@ from model_compression_toolkit.core.runner import core_runner
31
31
  from model_compression_toolkit.gptq.runner import gptq_runner
32
32
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
33
33
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
34
- from model_compression_toolkit.metadata import get_versions_dict
34
+ from model_compression_toolkit.metadata import get_versions_dict, create_model_metadata
35
35
 
36
36
  LR_DEFAULT = 0.15
37
37
  LR_REST_DEFAULT = 1e-4
@@ -208,15 +208,15 @@ if FOUND_TF:
208
208
 
209
209
  fw_impl = GPTQKerasImplemantation()
210
210
 
211
- tg, bit_widths_config, hessian_info_service = core_runner(in_model=in_model,
212
- representative_data_gen=representative_data_gen,
213
- core_config=core_config,
214
- fw_info=DEFAULT_KERAS_INFO,
215
- fw_impl=fw_impl,
216
- tpc=target_platform_capabilities,
217
- target_resource_utilization=target_resource_utilization,
218
- tb_w=tb_w,
219
- running_gptq=True)
211
+ tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
212
+ representative_data_gen=representative_data_gen,
213
+ core_config=core_config,
214
+ fw_info=DEFAULT_KERAS_INFO,
215
+ fw_impl=fw_impl,
216
+ tpc=target_platform_capabilities,
217
+ target_resource_utilization=target_resource_utilization,
218
+ tb_w=tb_w,
219
+ running_gptq=True)
220
220
 
221
221
  float_graph = copy.deepcopy(tg)
222
222
 
@@ -242,7 +242,9 @@ if FOUND_TF:
242
242
 
243
243
  exportable_model, user_info = get_exportable_keras_model(tg_gptq)
244
244
  if target_platform_capabilities.tp_model.add_metadata:
245
- exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
245
+ exportable_model = add_metadata(exportable_model,
246
+ create_model_metadata(tpc=target_platform_capabilities,
247
+ scheduling_info=scheduling_info))
246
248
  return exportable_model, user_info
247
249
 
248
250
  else:
@@ -31,7 +31,7 @@ from model_compression_toolkit.core.analyzer import analyzer_model_quantization
31
31
  from model_compression_toolkit.core import CoreConfig
32
32
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
33
33
  MixedPrecisionQuantizationConfig
34
- from model_compression_toolkit.metadata import get_versions_dict
34
+ from model_compression_toolkit.metadata import get_versions_dict, create_model_metadata
35
35
 
36
36
  LR_DEFAULT = 1e-4
37
37
  LR_REST_DEFAULT = 1e-4
@@ -177,15 +177,15 @@ if FOUND_TORCH:
177
177
  # ---------------------- #
178
178
  # Core Runner
179
179
  # ---------------------- #
180
- graph, bit_widths_config, hessian_info_service = core_runner(in_model=model,
181
- representative_data_gen=representative_data_gen,
182
- core_config=core_config,
183
- fw_info=DEFAULT_PYTORCH_INFO,
184
- fw_impl=fw_impl,
185
- tpc=target_platform_capabilities,
186
- target_resource_utilization=target_resource_utilization,
187
- tb_w=tb_w,
188
- running_gptq=True)
180
+ graph, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=model,
181
+ representative_data_gen=representative_data_gen,
182
+ core_config=core_config,
183
+ fw_info=DEFAULT_PYTORCH_INFO,
184
+ fw_impl=fw_impl,
185
+ tpc=target_platform_capabilities,
186
+ target_resource_utilization=target_resource_utilization,
187
+ tb_w=tb_w,
188
+ running_gptq=True)
189
189
 
190
190
  float_graph = copy.deepcopy(graph)
191
191
 
@@ -212,7 +212,9 @@ if FOUND_TORCH:
212
212
 
213
213
  exportable_model, user_info = get_exportable_pytorch_model(graph_gptq)
214
214
  if target_platform_capabilities.tp_model.add_metadata:
215
- exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
215
+ exportable_model = add_metadata(exportable_model,
216
+ create_model_metadata(tpc=target_platform_capabilities,
217
+ scheduling_info=scheduling_info))
216
218
  return exportable_model, user_info
217
219
 
218
220
 
@@ -13,8 +13,31 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Dict
17
- from model_compression_toolkit.constants import MCT_VERSION, TPC_VERSION
16
+ from typing import Dict, Any
17
+ from model_compression_toolkit.constants import MCT_VERSION, TPC_VERSION, OPERATORS_SCHEDULING, FUSED_NODES_MAPPING, \
18
+ CUTS, MAX_CUT, OP_ORDER, OP_RECORD, SHAPE, NODE_OUTPUT_INDEX, NODE_NAME, TOTAL_SIZE, MEM_ELEMENTS
19
+ from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import SchedulerInfo
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
21
+
22
+
23
+ def create_model_metadata(tpc: TargetPlatformCapabilities,
24
+ scheduling_info: SchedulerInfo = None) -> Dict:
25
+ """
26
+ Creates and returns a metadata dictionary for the model, including version information
27
+ and optional scheduling information.
28
+
29
+ Args:
30
+ tpc: A TPC object to get the version.
31
+ scheduling_info: An object containing scheduling details and metadata. Default is None.
32
+
33
+ Returns:
34
+ Dict: A dictionary containing the model's version information and optional scheduling information.
35
+ """
36
+ _metadata = get_versions_dict(tpc)
37
+ if scheduling_info:
38
+ scheduler_metadata = get_scheduler_metadata(scheduler_info=scheduling_info)
39
+ _metadata['scheduling_info'] = scheduler_metadata
40
+ return _metadata
18
41
 
19
42
 
20
43
  def get_versions_dict(tpc) -> Dict:
@@ -27,3 +50,39 @@ def get_versions_dict(tpc) -> Dict:
27
50
  from model_compression_toolkit import __version__ as mct_version
28
51
  tpc_version = f'{tpc.name}.{tpc.version}'
29
52
  return {MCT_VERSION: mct_version, TPC_VERSION: tpc_version}
53
+
54
+
55
+ def get_scheduler_metadata(scheduler_info: SchedulerInfo) -> Dict[str, Any]:
56
+ """
57
+ Extracts and returns metadata from SchedulerInfo.
58
+
59
+ Args:
60
+ scheduler_info (SchedulerInfo): The scheduler information object containing scheduling details like cuts and
61
+ fusing mapping.
62
+
63
+ Returns:
64
+ Dict[str, Any]: A dictionary containing extracted metadata, including schedule, maximum cut,
65
+ cuts information, and fused nodes mapping.
66
+ """
67
+ scheduler_metadata = {
68
+ OPERATORS_SCHEDULING: [str(layer) for layer in scheduler_info.operators_scheduling],
69
+ MAX_CUT: scheduler_info.max_cut,
70
+ CUTS: [
71
+ {
72
+ OP_ORDER: [op.name for op in cut.op_order],
73
+ OP_RECORD: [op.name for op in cut.op_record],
74
+ MEM_ELEMENTS: [
75
+ {
76
+ SHAPE: list(tensor.shape),
77
+ NODE_NAME: tensor.node_name,
78
+ TOTAL_SIZE: float(tensor.total_size),
79
+ NODE_OUTPUT_INDEX: tensor.node_output_index
80
+ }
81
+ for tensor in cut.mem_elements.elements
82
+ ]
83
+ }
84
+ for cut in scheduler_info.cuts
85
+ ],
86
+ FUSED_NODES_MAPPING: scheduler_info.fused_nodes_mapping
87
+ }
88
+ return scheduler_metadata
@@ -28,7 +28,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quant
28
28
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
29
29
  from model_compression_toolkit.core.runner import core_runner
30
30
  from model_compression_toolkit.ptq.runner import ptq_runner
31
- from model_compression_toolkit.metadata import get_versions_dict
31
+ from model_compression_toolkit.metadata import create_model_metadata
32
32
 
33
33
  if FOUND_TF:
34
34
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
@@ -134,14 +134,14 @@ if FOUND_TF:
134
134
  fw_impl = KerasImplementation()
135
135
 
136
136
  # Ignore returned hessian service as PTQ does not use it
137
- tg, bit_widths_config, _ = core_runner(in_model=in_model,
138
- representative_data_gen=representative_data_gen,
139
- core_config=core_config,
140
- fw_info=fw_info,
141
- fw_impl=fw_impl,
142
- tpc=target_platform_capabilities,
143
- target_resource_utilization=target_resource_utilization,
144
- tb_w=tb_w)
137
+ tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_model,
138
+ representative_data_gen=representative_data_gen,
139
+ core_config=core_config,
140
+ fw_info=fw_info,
141
+ fw_impl=fw_impl,
142
+ tpc=target_platform_capabilities,
143
+ target_resource_utilization=target_resource_utilization,
144
+ tb_w=tb_w)
145
145
 
146
146
  # At this point, tg is a graph that went through substitutions (such as BN folding) and is
147
147
  # ready for quantization (namely, it holds quantization params, etc.) but the weights are
@@ -168,7 +168,9 @@ if FOUND_TF:
168
168
 
169
169
  exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction)
170
170
  if target_platform_capabilities.tp_model.add_metadata:
171
- exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
171
+ exportable_model = add_metadata(exportable_model,
172
+ create_model_metadata(tpc=target_platform_capabilities,
173
+ scheduling_info=scheduling_info))
172
174
  return exportable_model, user_info
173
175
 
174
176
 
@@ -16,7 +16,6 @@ import copy
16
16
 
17
17
  from typing import Callable
18
18
 
19
- from model_compression_toolkit.core import common
20
19
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
21
20
  from model_compression_toolkit.logger import Logger
22
21
  from model_compression_toolkit.constants import PYTORCH, FOUND_TORCH
@@ -29,8 +28,7 @@ from model_compression_toolkit.core.runner import core_runner
29
28
  from model_compression_toolkit.ptq.runner import ptq_runner
30
29
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
31
30
  from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
32
- from model_compression_toolkit.metadata import get_versions_dict
33
-
31
+ from model_compression_toolkit.metadata import create_model_metadata
34
32
 
35
33
  if FOUND_TORCH:
36
34
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
@@ -109,14 +107,14 @@ if FOUND_TORCH:
109
107
  fw_impl = PytorchImplementation()
110
108
 
111
109
  # Ignore hessian info service as it is not used here yet.
112
- tg, bit_widths_config, _ = core_runner(in_model=in_module,
113
- representative_data_gen=representative_data_gen,
114
- core_config=core_config,
115
- fw_info=fw_info,
116
- fw_impl=fw_impl,
117
- tpc=target_platform_capabilities,
118
- target_resource_utilization=target_resource_utilization,
119
- tb_w=tb_w)
110
+ tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_module,
111
+ representative_data_gen=representative_data_gen,
112
+ core_config=core_config,
113
+ fw_info=fw_info,
114
+ fw_impl=fw_impl,
115
+ tpc=target_platform_capabilities,
116
+ target_resource_utilization=target_resource_utilization,
117
+ tb_w=tb_w)
120
118
 
121
119
  # At this point, tg is a graph that went through substitutions (such as BN folding) and is
122
120
  # ready for quantization (namely, it holds quantization params, etc.) but the weights are
@@ -143,7 +141,9 @@ if FOUND_TORCH:
143
141
 
144
142
  exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction)
145
143
  if target_platform_capabilities.tp_model.add_metadata:
146
- exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
144
+ exportable_model = add_metadata(exportable_model,
145
+ create_model_metadata(tpc=target_platform_capabilities,
146
+ scheduling_info=scheduling_info))
147
147
  return exportable_model, user_info
148
148
 
149
149
 
@@ -187,14 +187,14 @@ if FOUND_TF:
187
187
  fw_impl = KerasImplementation()
188
188
 
189
189
  # Ignore hessian service since it is not used in QAT at the moment
190
- tg, bit_widths_config, _ = core_runner(in_model=in_model,
191
- representative_data_gen=representative_data_gen,
192
- core_config=core_config,
193
- fw_info=DEFAULT_KERAS_INFO,
194
- fw_impl=fw_impl,
195
- tpc=target_platform_capabilities,
196
- target_resource_utilization=target_resource_utilization,
197
- tb_w=tb_w)
190
+ tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
191
+ representative_data_gen=representative_data_gen,
192
+ core_config=core_config,
193
+ fw_info=DEFAULT_KERAS_INFO,
194
+ fw_impl=fw_impl,
195
+ tpc=target_platform_capabilities,
196
+ target_resource_utilization=target_resource_utilization,
197
+ tb_w=tb_w)
198
198
 
199
199
  tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_KERAS_INFO, fw_impl, tb_w)
200
200
 
@@ -154,14 +154,14 @@ if FOUND_TORCH:
154
154
  fw_impl = PytorchImplementation()
155
155
 
156
156
  # Ignore hessian scores service as we do not use it here
157
- tg, bit_widths_config, _ = core_runner(in_model=in_model,
158
- representative_data_gen=representative_data_gen,
159
- core_config=core_config,
160
- fw_info=DEFAULT_PYTORCH_INFO,
161
- fw_impl=fw_impl,
162
- tpc=target_platform_capabilities,
163
- target_resource_utilization=target_resource_utilization,
164
- tb_w=tb_w)
157
+ tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
158
+ representative_data_gen=representative_data_gen,
159
+ core_config=core_config,
160
+ fw_info=DEFAULT_PYTORCH_INFO,
161
+ fw_impl=fw_impl,
162
+ tpc=target_platform_capabilities,
163
+ target_resource_utilization=target_resource_utilization,
164
+ tb_w=tb_w)
165
165
 
166
166
  tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)
167
167
 
@@ -17,7 +17,8 @@ from model_compression_toolkit.target_platform_capabilities.target_platform.fusi
17
17
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
18
18
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater, LayerFilterParams, OperationsToLayers, get_current_tpc
19
19
  from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import get_default_quantization_config_options, TargetPlatformModel
20
- from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, QuantizationConfigOptions, AttributeQuantizationConfig
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import \
21
+ OpQuantizationConfig, QuantizationConfigOptions, AttributeQuantizationConfig, Signedness
21
22
  from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSet, OperatorSetConcat
22
23
 
23
24
  from mct_quantizers import QuantizationMethod
@@ -15,12 +15,26 @@
15
15
 
16
16
  import copy
17
17
  from typing import List, Dict, Union, Any, Tuple
18
+ from enum import Enum
18
19
 
19
20
  from mct_quantizers import QuantizationMethod
20
21
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
21
22
  from model_compression_toolkit.logger import Logger
22
23
 
23
24
 
25
+ class Signedness(Enum):
26
+ """
27
+ An enum for choosing the signedness of the quantization method:
28
+
29
+ AUTO - Signedness decided automatically by quantization.
30
+ SIGNED - Force signed quantization.
31
+ UNSIGNED - Force unsigned quantization.
32
+ """
33
+ AUTO = 0
34
+ SIGNED = 1
35
+ UNSIGNED = 2
36
+
37
+
24
38
  def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any:
25
39
  """
26
40
  Clones the given object and edit some of its parameters.
@@ -120,7 +134,7 @@ class OpQuantizationConfig:
120
134
  fixed_scale: float,
121
135
  fixed_zero_point: int,
122
136
  simd_size: int,
123
- is_signed: bool = None
137
+ signedness: Signedness
124
138
  ):
125
139
  """
126
140
 
@@ -134,8 +148,8 @@ class OpQuantizationConfig:
134
148
  quantization_preserving (bool): Whether quantization parameters should be the same for an operator's input and output.
135
149
  fixed_scale (float): Scale to use for an operator quantization parameters.
136
150
  fixed_zero_point (int): Zero-point to use for an operator quantization parameters.
137
- is_signed (bool): Force activation quantization signedness (None means don't force).
138
151
  simd_size (int): Per op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction.
152
+ signedness (Signedness): Set activation quantization signedness.
139
153
 
140
154
  """
141
155
 
@@ -154,7 +168,7 @@ class OpQuantizationConfig:
154
168
  self.quantization_preserving = quantization_preserving
155
169
  self.fixed_scale = fixed_scale
156
170
  self.fixed_zero_point = fixed_zero_point
157
- self.is_signed = is_signed
171
+ self.signedness = signedness
158
172
  self.simd_size = simd_size
159
173
 
160
174
  def get_info(self):
@@ -206,7 +220,7 @@ class OpQuantizationConfig:
206
220
  self.activation_n_bits == other.activation_n_bits and \
207
221
  self.supported_input_activation_n_bits == other.supported_input_activation_n_bits and \
208
222
  self.enable_activation_quantization == other.enable_activation_quantization and \
209
- self.is_signed == other.is_signed and \
223
+ self.signedness == other.signedness and \
210
224
  self.simd_size == other.simd_size
211
225
 
212
226
  @property
@@ -16,17 +16,17 @@
16
16
  from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
17
17
 
18
18
  from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.target_platform_capabilities import \
19
- tpc_dict as imx500_tpc_dict
19
+ get_tpc_dict_by_fw as get_imx500_tpc
20
20
  from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.target_platform_capabilities import \
21
- tpc_dict as tflite_tpc_dict
21
+ get_tpc_dict_by_fw as get_tflite_tpc
22
22
  from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.target_platform_capabilities import \
23
- tpc_dict as qnnpack_tpc_dict
23
+ get_tpc_dict_by_fw as get_qnnpack_tpc
24
24
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, TFLITE_TP_MODEL, QNNPACK_TP_MODEL, LATEST
25
25
 
26
- tpc_dict = {DEFAULT_TP_MODEL: imx500_tpc_dict,
27
- IMX500_TP_MODEL: imx500_tpc_dict,
28
- TFLITE_TP_MODEL: tflite_tpc_dict,
29
- QNNPACK_TP_MODEL: qnnpack_tpc_dict}
26
+ tpc_dict = {DEFAULT_TP_MODEL: get_imx500_tpc,
27
+ IMX500_TP_MODEL: get_imx500_tpc,
28
+ TFLITE_TP_MODEL: get_tflite_tpc,
29
+ QNNPACK_TP_MODEL: get_qnnpack_tpc}
30
30
 
31
31
 
32
32
  def get_target_platform_capabilities(fw_name: str,
@@ -47,13 +47,10 @@ def get_target_platform_capabilities(fw_name: str,
47
47
  """
48
48
  assert target_platform_name in tpc_dict, f'Target platform {target_platform_name} is not defined!'
49
49
  fw_tpc = tpc_dict.get(target_platform_name)
50
- assert fw_name in fw_tpc, f'Framework {fw_name} is not supported in {target_platform_name}. Please make sure the relevant ' \
51
- f'packages are installed when using MCT for optimizing a {fw_name} model. ' \
52
- f'For Tensorflow, please install tensorflow. ' \
53
- f'For PyTorch, please install torch.'
54
- tpc_versions = fw_tpc.get(fw_name)
50
+ tpc_versions = fw_tpc(fw_name)
55
51
  if target_platform_version is None:
56
52
  target_platform_version = LATEST
57
53
  else:
58
- assert target_platform_version in tpc_versions, f'TPC version {target_platform_version} is not supported for framework {fw_name}.'
54
+ assert target_platform_version in tpc_versions, (f'TPC version {target_platform_version} is not supported for '
55
+ f'framework {fw_name}.')
59
56
  return tpc_versions[target_platform_version]()
@@ -12,61 +12,74 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from model_compression_toolkit.logger import Logger
15
16
 
16
17
  from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH
17
18
  from model_compression_toolkit.target_platform_capabilities.constants import LATEST
18
19
 
19
- ###############################
20
- # Build Tensorflow TPC models
21
- ###############################
22
- keras_tpc_models_dict = None
23
- if FOUND_TF:
24
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_keras_tpc_latest
25
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
26
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v1_lut
27
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_keras import get_keras_tpc as get_keras_tpc_v1_pot
28
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import get_keras_tpc as get_keras_tpc_v2
29
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v2_lut
30
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_keras import get_keras_tpc as get_keras_tpc_v3
31
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v3_lut
32
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import get_keras_tpc as get_keras_tpc_v4
33
20
 
34
- # Keras: TPC versioning
35
- keras_tpc_models_dict = {'v1': get_keras_tpc_v1,
36
- 'v1_lut': get_keras_tpc_v1_lut,
37
- 'v1_pot': get_keras_tpc_v1_pot,
38
- 'v2': get_keras_tpc_v2,
39
- 'v2_lut': get_keras_tpc_v2_lut,
40
- 'v3': get_keras_tpc_v3,
41
- 'v3_lut': get_keras_tpc_v3_lut,
42
- 'v4': get_keras_tpc_v4,
43
- LATEST: get_keras_tpc_latest}
21
+ def get_tpc_dict_by_fw(fw_name):
22
+ tpc_models_dict = None
23
+ if fw_name == TENSORFLOW:
24
+ ###############################
25
+ # Build Tensorflow TPC models
26
+ ###############################
27
+ if FOUND_TF:
28
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \
29
+ get_keras_tpc_latest
30
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import \
31
+ get_keras_tpc as get_keras_tpc_v1
32
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_keras import \
33
+ get_keras_tpc as get_keras_tpc_v1_lut
34
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_keras import \
35
+ get_keras_tpc as get_keras_tpc_v1_pot
36
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import \
37
+ get_keras_tpc as get_keras_tpc_v2
38
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_keras import \
39
+ get_keras_tpc as get_keras_tpc_v2_lut
40
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_keras import \
41
+ get_keras_tpc as get_keras_tpc_v3
42
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_keras import \
43
+ get_keras_tpc as get_keras_tpc_v3_lut
44
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import \
45
+ get_keras_tpc as get_keras_tpc_v4
44
46
 
45
- ###############################
46
- # Build Pytorch TPC models
47
- ###############################
48
- pytorch_tpc_models_dict = None
49
- if FOUND_TORCH:
50
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_pytorch_tpc_latest
51
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import \
52
- get_pytorch_tpc as get_pytorch_tpc_v1
53
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_pytorch import \
54
- get_pytorch_tpc as get_pytorch_tpc_v1_pot
55
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_pytorch import \
56
- get_pytorch_tpc as get_pytorch_tpc_v1_lut
57
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import \
58
- get_pytorch_tpc as get_pytorch_tpc_v2
59
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_pytorch import \
60
- get_pytorch_tpc as get_pytorch_tpc_v2_lut
61
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_pytorch import \
62
- get_pytorch_tpc as get_pytorch_tpc_v3
63
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_pytorch import \
64
- get_pytorch_tpc as get_pytorch_tpc_v3_lut
65
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_pytorch import \
66
- get_pytorch_tpc as get_pytorch_tpc_v4
47
+ # Keras: TPC versioning
48
+ tpc_models_dict = {'v1': get_keras_tpc_v1,
49
+ 'v1_lut': get_keras_tpc_v1_lut,
50
+ 'v1_pot': get_keras_tpc_v1_pot,
51
+ 'v2': get_keras_tpc_v2,
52
+ 'v2_lut': get_keras_tpc_v2_lut,
53
+ 'v3': get_keras_tpc_v3,
54
+ 'v3_lut': get_keras_tpc_v3_lut,
55
+ 'v4': get_keras_tpc_v4,
56
+ LATEST: get_keras_tpc_latest}
57
+ elif fw_name == PYTORCH:
58
+ ###############################
59
+ # Build Pytorch TPC models
60
+ ###############################
61
+ if FOUND_TORCH:
62
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \
63
+ get_pytorch_tpc_latest
64
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import \
65
+ get_pytorch_tpc as get_pytorch_tpc_v1
66
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_pytorch import \
67
+ get_pytorch_tpc as get_pytorch_tpc_v1_pot
68
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_pytorch import \
69
+ get_pytorch_tpc as get_pytorch_tpc_v1_lut
70
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import \
71
+ get_pytorch_tpc as get_pytorch_tpc_v2
72
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_pytorch import \
73
+ get_pytorch_tpc as get_pytorch_tpc_v2_lut
74
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_pytorch import \
75
+ get_pytorch_tpc as get_pytorch_tpc_v3
76
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_pytorch import \
77
+ get_pytorch_tpc as get_pytorch_tpc_v3_lut
78
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_pytorch import \
79
+ get_pytorch_tpc as get_pytorch_tpc_v4
67
80
 
68
- # Pytorch: TPC versioning
69
- pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1,
81
+ # Pytorch: TPC versioning
82
+ tpc_models_dict = {'v1': get_pytorch_tpc_v1,
70
83
  'v1_lut': get_pytorch_tpc_v1_lut,
71
84
  'v1_pot': get_pytorch_tpc_v1_pot,
72
85
  'v2': get_pytorch_tpc_v2,
@@ -75,7 +88,10 @@ if FOUND_TORCH:
75
88
  'v3_lut': get_pytorch_tpc_v3_lut,
76
89
  'v4': get_pytorch_tpc_v4,
77
90
  LATEST: get_pytorch_tpc_latest}
78
-
79
- tpc_dict = {TENSORFLOW: keras_tpc_models_dict,
80
- PYTORCH: pytorch_tpc_models_dict}
81
-
91
+ if tpc_models_dict is not None:
92
+ return tpc_models_dict
93
+ else:
94
+ Logger.critical(f'Framework {fw_name} is not supported in imx500 or the relevant packages are not '
95
+ f'installed. Please make sure the relevant packages are installed when using MCT for optimizing'
96
+ f' a {fw_name} model. For Tensorflow, please install tensorflow. For PyTorch, please install '
97
+ f'torch.') # pragma: no cover
@@ -18,7 +18,7 @@ import model_compression_toolkit as mct
18
18
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
19
19
  from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS
20
20
  from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
21
- TargetPlatformModel
21
+ TargetPlatformModel, Signedness
22
22
  from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import \
23
23
  AttributeQuantizationConfig
24
24
 
@@ -98,7 +98,8 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
98
98
  quantization_preserving=False,
99
99
  fixed_scale=None,
100
100
  fixed_zero_point=None,
101
- simd_size=32)
101
+ simd_size=32,
102
+ signedness=Signedness.AUTO)
102
103
 
103
104
  # We define an 8-bit config for linear operations quantization, that include a kernel and bias attributes.
104
105
  linear_eight_bits = tp.OpQuantizationConfig(
@@ -111,7 +112,8 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
111
112
  quantization_preserving=False,
112
113
  fixed_scale=None,
113
114
  fixed_zero_point=None,
114
- simd_size=32)
115
+ simd_size=32,
116
+ signedness=Signedness.AUTO)
115
117
 
116
118
  # To quantize a model using mixed-precision, create
117
119
  # a list with more than one OpQuantizationConfig.
@@ -19,7 +19,7 @@ from model_compression_toolkit.constants import FLOAT_BITWIDTH
19
19
  from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS, \
20
20
  WEIGHTS_QUANTIZATION_METHOD
21
21
  from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
22
- TargetPlatformModel
22
+ TargetPlatformModel, Signedness
23
23
  from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import \
24
24
  AttributeQuantizationConfig
25
25
 
@@ -94,7 +94,8 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
94
94
  quantization_preserving=False,
95
95
  fixed_scale=None,
96
96
  fixed_zero_point=None,
97
- simd_size=32)
97
+ simd_size=32,
98
+ signedness=Signedness.AUTO)
98
99
 
99
100
  # We define an 8-bit config for linear operations quantization, that include a kernel and bias attributes.
100
101
  linear_eight_bits = tp.OpQuantizationConfig(
@@ -107,7 +108,8 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
107
108
  quantization_preserving=False,
108
109
  fixed_scale=None,
109
110
  fixed_zero_point=None,
110
- simd_size=32)
111
+ simd_size=32,
112
+ signedness=Signedness.AUTO)
111
113
 
112
114
  # To quantize a model using mixed-precision, create
113
115
  # a list with more than one OpQuantizationConfig.