mct-nightly 2.1.0.20240807.445__py3-none-any.whl → 2.1.0.20240808.431__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/RECORD +33 -32
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/constants.py +14 -1
  5. model_compression_toolkit/core/common/fusion/graph_fuser.py +135 -0
  6. model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +4 -0
  7. model_compression_toolkit/core/common/quantization/debug_config.py +4 -1
  8. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +29 -1
  9. model_compression_toolkit/core/runner.py +21 -1
  10. model_compression_toolkit/gptq/keras/quantization_facade.py +13 -11
  11. model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -11
  12. model_compression_toolkit/metadata.py +61 -2
  13. model_compression_toolkit/ptq/keras/quantization_facade.py +12 -10
  14. model_compression_toolkit/ptq/pytorch/quantization_facade.py +12 -12
  15. model_compression_toolkit/qat/keras/quantization_facade.py +8 -8
  16. model_compression_toolkit/qat/pytorch/quantization_facade.py +8 -8
  17. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +10 -13
  18. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +68 -52
  19. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +35 -29
  20. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +35 -28
  21. model_compression_toolkit/xquant/common/constants.py +3 -0
  22. model_compression_toolkit/xquant/common/core_report_generator.py +9 -1
  23. model_compression_toolkit/xquant/common/framework_report_utils.py +5 -14
  24. model_compression_toolkit/xquant/common/tensorboard_utils.py +30 -5
  25. model_compression_toolkit/xquant/keras/facade_xquant_report.py +2 -0
  26. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -1
  27. model_compression_toolkit/xquant/keras/tensorboard_utils.py +101 -4
  28. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +2 -0
  29. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -2
  30. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +109 -3
  31. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/LICENSE.md +0 -0
  32. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/WHEEL +0 -0
  33. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,31 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Dict
17
- from model_compression_toolkit.constants import MCT_VERSION, TPC_VERSION
16
+ from typing import Dict, Any
17
+ from model_compression_toolkit.constants import MCT_VERSION, TPC_VERSION, OPERATORS_SCHEDULING, FUSED_NODES_MAPPING, \
18
+ CUTS, MAX_CUT, OP_ORDER, OP_RECORD, SHAPE, NODE_OUTPUT_INDEX, NODE_NAME, TOTAL_SIZE, MEM_ELEMENTS
19
+ from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import SchedulerInfo
20
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
21
+
22
+
23
def create_model_metadata(tpc: TargetPlatformCapabilities,
                          scheduling_info: SchedulerInfo = None) -> Dict:
    """
    Creates and returns a metadata dictionary for the model, including version information
    and optional scheduling information.

    Args:
        tpc: A TPC object; its name and version are recorded via get_versions_dict.
        scheduling_info: An object containing scheduling details and metadata. Default is None,
            in which case no 'scheduling_info' entry is added.

    Returns:
        Dict: A dictionary containing the model's version information and, when scheduling_info
        is provided, a 'scheduling_info' entry built by get_scheduler_metadata.
    """
    metadata = get_versions_dict(tpc)
    # Compare against None explicitly: the parameter defaults to None, and a provided
    # SchedulerInfo must not be skipped merely because the object evaluates as falsy.
    if scheduling_info is not None:
        metadata['scheduling_info'] = get_scheduler_metadata(scheduler_info=scheduling_info)
    return metadata
18
41
 
19
42
 
20
43
  def get_versions_dict(tpc) -> Dict:
@@ -27,3 +50,39 @@ def get_versions_dict(tpc) -> Dict:
27
50
  from model_compression_toolkit import __version__ as mct_version
28
51
  tpc_version = f'{tpc.name}.{tpc.version}'
29
52
  return {MCT_VERSION: mct_version, TPC_VERSION: tpc_version}
53
+
54
+
55
def get_scheduler_metadata(scheduler_info: SchedulerInfo) -> Dict[str, Any]:
    """
    Build a plain-dictionary view of a SchedulerInfo object for embedding as model metadata.

    Args:
        scheduler_info (SchedulerInfo): Scheduling details read here: operators_scheduling,
            max_cut, cuts and fused_nodes_mapping.

    Returns:
        Dict[str, Any]: Keys OPERATORS_SCHEDULING (operators as strings), MAX_CUT, CUTS (one entry
        per cut) and FUSED_NODES_MAPPING (taken from scheduler_info as-is, not copied).
    """

    def _memory_element_entry(tensor) -> Dict[str, Any]:
        # Describe one memory element of a cut: shape, producing node, size and output index.
        return {
            SHAPE: list(tensor.shape),
            NODE_NAME: tensor.node_name,
            TOTAL_SIZE: float(tensor.total_size),
            NODE_OUTPUT_INDEX: tensor.node_output_index,
        }

    def _cut_entry(cut) -> Dict[str, Any]:
        # Describe one cut: operator names in op_order / op_record, plus its memory elements.
        order_names = [op.name for op in cut.op_order]
        record_names = [op.name for op in cut.op_record]
        elements = [_memory_element_entry(t) for t in cut.mem_elements.elements]
        return {OP_ORDER: order_names, OP_RECORD: record_names, MEM_ELEMENTS: elements}

    # Evaluated in the same order the keys appear in the result.
    schedule = [str(op) for op in scheduler_info.operators_scheduling]
    max_cut = scheduler_info.max_cut
    cuts = [_cut_entry(c) for c in scheduler_info.cuts]
    fused_mapping = scheduler_info.fused_nodes_mapping

    return {
        OPERATORS_SCHEDULING: schedule,
        MAX_CUT: max_cut,
        CUTS: cuts,
        FUSED_NODES_MAPPING: fused_mapping,
    }
@@ -28,7 +28,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quant
28
28
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
29
29
  from model_compression_toolkit.core.runner import core_runner
30
30
  from model_compression_toolkit.ptq.runner import ptq_runner
31
- from model_compression_toolkit.metadata import get_versions_dict
31
+ from model_compression_toolkit.metadata import create_model_metadata
32
32
 
33
33
  if FOUND_TF:
34
34
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
@@ -134,14 +134,14 @@ if FOUND_TF:
134
134
  fw_impl = KerasImplementation()
135
135
 
136
136
  # Ignore returned hessian service as PTQ does not use it
137
- tg, bit_widths_config, _ = core_runner(in_model=in_model,
138
- representative_data_gen=representative_data_gen,
139
- core_config=core_config,
140
- fw_info=fw_info,
141
- fw_impl=fw_impl,
142
- tpc=target_platform_capabilities,
143
- target_resource_utilization=target_resource_utilization,
144
- tb_w=tb_w)
137
+ tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_model,
138
+ representative_data_gen=representative_data_gen,
139
+ core_config=core_config,
140
+ fw_info=fw_info,
141
+ fw_impl=fw_impl,
142
+ tpc=target_platform_capabilities,
143
+ target_resource_utilization=target_resource_utilization,
144
+ tb_w=tb_w)
145
145
 
146
146
  # At this point, tg is a graph that went through substitutions (such as BN folding) and is
147
147
  # ready for quantization (namely, it holds quantization params, etc.) but the weights are
@@ -168,7 +168,9 @@ if FOUND_TF:
168
168
 
169
169
  exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction)
170
170
  if target_platform_capabilities.tp_model.add_metadata:
171
- exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
171
+ exportable_model = add_metadata(exportable_model,
172
+ create_model_metadata(tpc=target_platform_capabilities,
173
+ scheduling_info=scheduling_info))
172
174
  return exportable_model, user_info
173
175
 
174
176
 
@@ -16,7 +16,6 @@ import copy
16
16
 
17
17
  from typing import Callable
18
18
 
19
- from model_compression_toolkit.core import common
20
19
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
21
20
  from model_compression_toolkit.logger import Logger
22
21
  from model_compression_toolkit.constants import PYTORCH, FOUND_TORCH
@@ -29,8 +28,7 @@ from model_compression_toolkit.core.runner import core_runner
29
28
  from model_compression_toolkit.ptq.runner import ptq_runner
30
29
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
31
30
  from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
32
- from model_compression_toolkit.metadata import get_versions_dict
33
-
31
+ from model_compression_toolkit.metadata import create_model_metadata
34
32
 
35
33
  if FOUND_TORCH:
36
34
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
@@ -109,14 +107,14 @@ if FOUND_TORCH:
109
107
  fw_impl = PytorchImplementation()
110
108
 
111
109
  # Ignore hessian info service as it is not used here yet.
112
- tg, bit_widths_config, _ = core_runner(in_model=in_module,
113
- representative_data_gen=representative_data_gen,
114
- core_config=core_config,
115
- fw_info=fw_info,
116
- fw_impl=fw_impl,
117
- tpc=target_platform_capabilities,
118
- target_resource_utilization=target_resource_utilization,
119
- tb_w=tb_w)
110
+ tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_module,
111
+ representative_data_gen=representative_data_gen,
112
+ core_config=core_config,
113
+ fw_info=fw_info,
114
+ fw_impl=fw_impl,
115
+ tpc=target_platform_capabilities,
116
+ target_resource_utilization=target_resource_utilization,
117
+ tb_w=tb_w)
120
118
 
121
119
  # At this point, tg is a graph that went through substitutions (such as BN folding) and is
122
120
  # ready for quantization (namely, it holds quantization params, etc.) but the weights are
@@ -143,7 +141,9 @@ if FOUND_TORCH:
143
141
 
144
142
  exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction)
145
143
  if target_platform_capabilities.tp_model.add_metadata:
146
- exportable_model = add_metadata(exportable_model, get_versions_dict(target_platform_capabilities))
144
+ exportable_model = add_metadata(exportable_model,
145
+ create_model_metadata(tpc=target_platform_capabilities,
146
+ scheduling_info=scheduling_info))
147
147
  return exportable_model, user_info
148
148
 
149
149
 
@@ -187,14 +187,14 @@ if FOUND_TF:
187
187
  fw_impl = KerasImplementation()
188
188
 
189
189
  # Ignore hessian service since is not used in QAT at the moment
190
- tg, bit_widths_config, _ = core_runner(in_model=in_model,
191
- representative_data_gen=representative_data_gen,
192
- core_config=core_config,
193
- fw_info=DEFAULT_KERAS_INFO,
194
- fw_impl=fw_impl,
195
- tpc=target_platform_capabilities,
196
- target_resource_utilization=target_resource_utilization,
197
- tb_w=tb_w)
190
+ tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
191
+ representative_data_gen=representative_data_gen,
192
+ core_config=core_config,
193
+ fw_info=DEFAULT_KERAS_INFO,
194
+ fw_impl=fw_impl,
195
+ tpc=target_platform_capabilities,
196
+ target_resource_utilization=target_resource_utilization,
197
+ tb_w=tb_w)
198
198
 
199
199
  tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_KERAS_INFO, fw_impl, tb_w)
200
200
 
@@ -154,14 +154,14 @@ if FOUND_TORCH:
154
154
  fw_impl = PytorchImplementation()
155
155
 
156
156
  # Ignore hessian scores service as we do not use it here
157
- tg, bit_widths_config, _ = core_runner(in_model=in_model,
158
- representative_data_gen=representative_data_gen,
159
- core_config=core_config,
160
- fw_info=DEFAULT_PYTORCH_INFO,
161
- fw_impl=fw_impl,
162
- tpc=target_platform_capabilities,
163
- target_resource_utilization=target_resource_utilization,
164
- tb_w=tb_w)
157
+ tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
158
+ representative_data_gen=representative_data_gen,
159
+ core_config=core_config,
160
+ fw_info=DEFAULT_PYTORCH_INFO,
161
+ fw_impl=fw_impl,
162
+ tpc=target_platform_capabilities,
163
+ target_resource_utilization=target_resource_utilization,
164
+ tb_w=tb_w)
165
165
 
166
166
  tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)
167
167
 
@@ -16,17 +16,17 @@
16
16
  from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
17
17
 
18
18
  from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.target_platform_capabilities import \
19
- tpc_dict as imx500_tpc_dict
19
+ get_tpc_dict_by_fw as get_imx500_tpc
20
20
  from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.target_platform_capabilities import \
21
- tpc_dict as tflite_tpc_dict
21
+ get_tpc_dict_by_fw as get_tflite_tpc
22
22
  from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.target_platform_capabilities import \
23
- tpc_dict as qnnpack_tpc_dict
23
+ get_tpc_dict_by_fw as get_qnnpack_tpc
24
24
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, TFLITE_TP_MODEL, QNNPACK_TP_MODEL, LATEST
25
25
 
26
- tpc_dict = {DEFAULT_TP_MODEL: imx500_tpc_dict,
27
- IMX500_TP_MODEL: imx500_tpc_dict,
28
- TFLITE_TP_MODEL: tflite_tpc_dict,
29
- QNNPACK_TP_MODEL: qnnpack_tpc_dict}
26
+ tpc_dict = {DEFAULT_TP_MODEL: get_imx500_tpc,
27
+ IMX500_TP_MODEL: get_imx500_tpc,
28
+ TFLITE_TP_MODEL: get_tflite_tpc,
29
+ QNNPACK_TP_MODEL: get_qnnpack_tpc}
30
30
 
31
31
 
32
32
  def get_target_platform_capabilities(fw_name: str,
@@ -47,13 +47,10 @@ def get_target_platform_capabilities(fw_name: str,
47
47
  """
48
48
  assert target_platform_name in tpc_dict, f'Target platform {target_platform_name} is not defined!'
49
49
  fw_tpc = tpc_dict.get(target_platform_name)
50
- assert fw_name in fw_tpc, f'Framework {fw_name} is not supported in {target_platform_name}. Please make sure the relevant ' \
51
- f'packages are installed when using MCT for optimizing a {fw_name} model. ' \
52
- f'For Tensorflow, please install tensorflow. ' \
53
- f'For PyTorch, please install torch.'
54
- tpc_versions = fw_tpc.get(fw_name)
50
+ tpc_versions = fw_tpc(fw_name)
55
51
  if target_platform_version is None:
56
52
  target_platform_version = LATEST
57
53
  else:
58
- assert target_platform_version in tpc_versions, f'TPC version {target_platform_version} is not supported for framework {fw_name}.'
54
+ assert target_platform_version in tpc_versions, (f'TPC version {target_platform_version} is not supported for '
55
+ f'framework {fw_name}.')
59
56
  return tpc_versions[target_platform_version]()
@@ -12,61 +12,74 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from model_compression_toolkit.logger import Logger
15
16
 
16
17
  from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH
17
18
  from model_compression_toolkit.target_platform_capabilities.constants import LATEST
18
19
 
19
- ###############################
20
- # Build Tensorflow TPC models
21
- ###############################
22
- keras_tpc_models_dict = None
23
- if FOUND_TF:
24
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_keras_tpc_latest
25
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
26
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v1_lut
27
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_keras import get_keras_tpc as get_keras_tpc_v1_pot
28
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import get_keras_tpc as get_keras_tpc_v2
29
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v2_lut
30
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_keras import get_keras_tpc as get_keras_tpc_v3
31
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v3_lut
32
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import get_keras_tpc as get_keras_tpc_v4
33
20
 
34
- # Keras: TPC versioning
35
- keras_tpc_models_dict = {'v1': get_keras_tpc_v1,
36
- 'v1_lut': get_keras_tpc_v1_lut,
37
- 'v1_pot': get_keras_tpc_v1_pot,
38
- 'v2': get_keras_tpc_v2,
39
- 'v2_lut': get_keras_tpc_v2_lut,
40
- 'v3': get_keras_tpc_v3,
41
- 'v3_lut': get_keras_tpc_v3_lut,
42
- 'v4': get_keras_tpc_v4,
43
- LATEST: get_keras_tpc_latest}
21
+ def get_tpc_dict_by_fw(fw_name):
22
+ tpc_models_dict = None
23
+ if fw_name == TENSORFLOW:
24
+ ###############################
25
+ # Build Tensorflow TPC models
26
+ ###############################
27
+ if FOUND_TF:
28
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \
29
+ get_keras_tpc_latest
30
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import \
31
+ get_keras_tpc as get_keras_tpc_v1
32
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_keras import \
33
+ get_keras_tpc as get_keras_tpc_v1_lut
34
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_keras import \
35
+ get_keras_tpc as get_keras_tpc_v1_pot
36
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import \
37
+ get_keras_tpc as get_keras_tpc_v2
38
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_keras import \
39
+ get_keras_tpc as get_keras_tpc_v2_lut
40
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_keras import \
41
+ get_keras_tpc as get_keras_tpc_v3
42
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_keras import \
43
+ get_keras_tpc as get_keras_tpc_v3_lut
44
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import \
45
+ get_keras_tpc as get_keras_tpc_v4
44
46
 
45
- ###############################
46
- # Build Pytorch TPC models
47
- ###############################
48
- pytorch_tpc_models_dict = None
49
- if FOUND_TORCH:
50
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_pytorch_tpc_latest
51
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import \
52
- get_pytorch_tpc as get_pytorch_tpc_v1
53
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_pytorch import \
54
- get_pytorch_tpc as get_pytorch_tpc_v1_pot
55
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_pytorch import \
56
- get_pytorch_tpc as get_pytorch_tpc_v1_lut
57
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import \
58
- get_pytorch_tpc as get_pytorch_tpc_v2
59
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_pytorch import \
60
- get_pytorch_tpc as get_pytorch_tpc_v2_lut
61
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_pytorch import \
62
- get_pytorch_tpc as get_pytorch_tpc_v3
63
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_pytorch import \
64
- get_pytorch_tpc as get_pytorch_tpc_v3_lut
65
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_pytorch import \
66
- get_pytorch_tpc as get_pytorch_tpc_v4
47
+ # Keras: TPC versioning
48
+ tpc_models_dict = {'v1': get_keras_tpc_v1,
49
+ 'v1_lut': get_keras_tpc_v1_lut,
50
+ 'v1_pot': get_keras_tpc_v1_pot,
51
+ 'v2': get_keras_tpc_v2,
52
+ 'v2_lut': get_keras_tpc_v2_lut,
53
+ 'v3': get_keras_tpc_v3,
54
+ 'v3_lut': get_keras_tpc_v3_lut,
55
+ 'v4': get_keras_tpc_v4,
56
+ LATEST: get_keras_tpc_latest}
57
+ elif fw_name == PYTORCH:
58
+ ###############################
59
+ # Build Pytorch TPC models
60
+ ###############################
61
+ if FOUND_TORCH:
62
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \
63
+ get_pytorch_tpc_latest
64
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import \
65
+ get_pytorch_tpc as get_pytorch_tpc_v1
66
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_pot.tpc_pytorch import \
67
+ get_pytorch_tpc as get_pytorch_tpc_v1_pot
68
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1_lut.tpc_pytorch import \
69
+ get_pytorch_tpc as get_pytorch_tpc_v1_lut
70
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import \
71
+ get_pytorch_tpc as get_pytorch_tpc_v2
72
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_pytorch import \
73
+ get_pytorch_tpc as get_pytorch_tpc_v2_lut
74
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_pytorch import \
75
+ get_pytorch_tpc as get_pytorch_tpc_v3
76
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_pytorch import \
77
+ get_pytorch_tpc as get_pytorch_tpc_v3_lut
78
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_pytorch import \
79
+ get_pytorch_tpc as get_pytorch_tpc_v4
67
80
 
68
- # Pytorch: TPC versioning
69
- pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1,
81
+ # Pytorch: TPC versioning
82
+ tpc_models_dict = {'v1': get_pytorch_tpc_v1,
70
83
  'v1_lut': get_pytorch_tpc_v1_lut,
71
84
  'v1_pot': get_pytorch_tpc_v1_pot,
72
85
  'v2': get_pytorch_tpc_v2,
@@ -75,7 +88,10 @@ if FOUND_TORCH:
75
88
  'v3_lut': get_pytorch_tpc_v3_lut,
76
89
  'v4': get_pytorch_tpc_v4,
77
90
  LATEST: get_pytorch_tpc_latest}
78
-
79
- tpc_dict = {TENSORFLOW: keras_tpc_models_dict,
80
- PYTORCH: pytorch_tpc_models_dict}
81
-
91
+ if tpc_models_dict is not None:
92
+ return tpc_models_dict
93
+ else:
94
+ Logger.critical(f'Framework {fw_name} is not supported in imx500 or the relevant packages are not '
95
+ f'installed. Please make sure the relevant packages are installed when using MCT for optimizing'
96
+ f' a {fw_name} model. For Tensorflow, please install tensorflow. For PyTorch, please install '
97
+ f'torch.') # pragma: no cover
@@ -14,35 +14,41 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH
17
+ from model_compression_toolkit.logger import Logger
17
18
  from model_compression_toolkit.target_platform_capabilities.constants import LATEST
18
19
 
19
-
20
- ###############################
21
- # Build Tensorflow TPC models
22
- ###############################
23
- keras_tpc_models_dict = None
24
- if FOUND_TF:
25
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
26
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import get_keras_tpc_latest
27
-
28
- # Keras: TPC versioning
29
- keras_tpc_models_dict = {'v1': get_keras_tpc_v1,
30
- LATEST: get_keras_tpc_latest}
31
-
32
- ###############################
33
- # Build Pytorch TPC models
34
- ###############################
35
- pytorch_tpc_models_dict = None
36
- if FOUND_TORCH:
37
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_pytorch import \
38
- get_pytorch_tpc as get_pytorch_tpc_v1
39
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import get_pytorch_tpc_latest
40
-
41
- # Pytorch: TPC versioning
42
- pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1,
def get_tpc_dict_by_fw(fw_name):
    """
    Return the QNNPACK TPC versions dictionary for the requested framework.

    Framework-specific TPC modules are imported lazily inside this function, so only the
    requested framework's package needs to be importable.

    Args:
        fw_name: Framework name, compared against TENSORFLOW and PYTORCH.

    Returns:
        Dict mapping a TPC version key ('v1' or LATEST) to the function that builds that TPC.
        If fw_name matches neither framework, or the matching framework was not found
        (FOUND_TF / FOUND_TORCH is False), Logger.critical is invoked instead.
    """
    tpc_models_dict = None
    if fw_name == TENSORFLOW:
        ###############################
        # Build Tensorflow TPC models
        ###############################
        if FOUND_TF:
            from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_keras import \
                get_keras_tpc as get_keras_tpc_v1
            from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import \
                get_keras_tpc_latest

            # Keras: TPC versioning
            tpc_models_dict = {'v1': get_keras_tpc_v1,
                               LATEST: get_keras_tpc_latest}
    elif fw_name == PYTORCH:
        ###############################
        # Build Pytorch TPC models
        ###############################
        if FOUND_TORCH:
            from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_pytorch import \
                get_pytorch_tpc as get_pytorch_tpc_v1
            from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.latest import \
                get_pytorch_tpc_latest

            # Pytorch: TPC versioning
            tpc_models_dict = {'v1': get_pytorch_tpc_v1,
                               LATEST: get_pytorch_tpc_latest}

    if tpc_models_dict is not None:
        return tpc_models_dict
    else:
        # The message names this module's platform (qnnpack); it previously said "imx500",
        # a copy-paste leftover that misreported which TPC failed.
        Logger.critical(f'Framework {fw_name} is not supported in qnnpack or the relevant packages are not '
                        f'installed. Please make sure the relevant packages are installed when using MCT for optimizing'
                        f' a {fw_name} model. For Tensorflow, please install tensorflow. For PyTorch, please install '
                        f'torch.')  # pragma: no cover
@@ -14,34 +14,41 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH, TENSORFLOW, PYTORCH
17
+ from model_compression_toolkit.logger import Logger
17
18
  from model_compression_toolkit.target_platform_capabilities.constants import LATEST
18
19
 
19
-
20
- ###############################
21
- # Build Tensorflow TPC models
22
- ###############################
23
- keras_tpc_models_dict = None
24
- if FOUND_TF:
25
- from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_v1
26
- from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import get_keras_tpc_latest
27
-
28
- # Keras: TPC versioning
29
- keras_tpc_models_dict = {'v1': get_keras_tpc_v1,
30
- LATEST: get_keras_tpc_latest}
31
-
32
- ###############################
33
- # Build Pytorch TPC models
34
- ###############################
35
- pytorch_tpc_models_dict = None
36
- if FOUND_TORCH:
37
- from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_pytorch import \
38
- get_pytorch_tpc as get_pytorch_tpc_v1
39
- from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import get_pytorch_tpc_latest
40
-
41
- # Pytorch: TPC versioning
42
- pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1,
def get_tpc_dict_by_fw(fw_name):
    """
    Return the TFLite TPC versions dictionary for the requested framework.

    Framework-specific TPC modules are imported lazily inside this function, so only the
    requested framework's package needs to be importable.

    Args:
        fw_name: Framework name, compared against TENSORFLOW and PYTORCH.

    Returns:
        Dict mapping a TPC version key ('v1' or LATEST) to the function that builds that TPC.
        If fw_name matches neither framework, or the matching framework was not found
        (FOUND_TF / FOUND_TORCH is False), Logger.critical is invoked instead.
    """
    tpc_models_dict = None
    if fw_name == TENSORFLOW:
        ###############################
        # Build Tensorflow TPC models
        ###############################
        if FOUND_TF:
            from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_keras import \
                get_keras_tpc as get_keras_tpc_v1
            from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import \
                get_keras_tpc_latest

            # Keras: TPC versioning
            tpc_models_dict = {'v1': get_keras_tpc_v1,
                               LATEST: get_keras_tpc_latest}
    elif fw_name == PYTORCH:
        ###############################
        # Build Pytorch TPC models
        ###############################
        if FOUND_TORCH:
            from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_pytorch import \
                get_pytorch_tpc as get_pytorch_tpc_v1
            from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.latest import \
                get_pytorch_tpc_latest

            # Pytorch: TPC versioning
            tpc_models_dict = {'v1': get_pytorch_tpc_v1,
                               LATEST: get_pytorch_tpc_latest}

    if tpc_models_dict is not None:
        return tpc_models_dict
    else:
        # The message names this module's platform (tflite); it previously said "imx500",
        # a copy-paste leftover that misreported which TPC failed.
        Logger.critical(f'Framework {fw_name} is not supported in tflite or the relevant packages are not '
                        f'installed. Please make sure the relevant packages are installed when using MCT for optimizing'
                        f' a {fw_name} model. For Tensorflow, please install tensorflow. For PyTorch, please install '
                        f'torch.')  # pragma: no cover
@@ -27,6 +27,8 @@ INTERMEDIATE_SIMILARITY_METRICS_VAL = 'intermediate_similarity_metrics_val'
27
27
  # Graph attribute names:
28
28
  XQUANT_REPR = 'xquant_repr'
29
29
  XQUANT_VAL = 'xquant_val'
30
+ CUT_MEMORY_ELEMENTS = 'cut_memory_elements'
31
+ CUT_TOTAL_SIZE = 'cut_total_size'
30
32
 
31
33
  # Report file name:
32
34
  REPORT_FILENAME = 'quant_report.json'
@@ -36,3 +38,4 @@ TENSORBOARD_DEFAULT_TAG = 'xquant'
36
38
 
37
39
  # When extracting the activations of a model we hold the output using a dedicated key:
38
40
  MODEL_OUTPUT_KEY = 'model_output_key'
41
+
@@ -45,6 +45,9 @@ def core_report_generator(float_model: Any,
45
45
  Returns:
46
46
  Dict[str, Any]: A dictionary containing the collected similarity metrics and report data.
47
47
  """
48
+ # Get metadata from the quantized model
49
+ quantized_model_metadata = fw_report_utils.get_metadata_fn(quantized_model)
50
+
48
51
  # Collect histograms on the float model.
49
52
  float_graph = fw_report_utils.model_folding_utils.create_float_folded_graph(float_model, repr_dataset)
50
53
  mi = ModelCollector(float_graph, fw_report_utils.fw_impl, fw_report_utils.fw_info)
@@ -74,7 +77,12 @@ def core_report_generator(float_model: Any,
74
77
  # Add a graph of the quantized model with the similarity metrics to TensorBoard for visualization.
75
78
  fw_report_utils.tb_utils.add_graph_to_tensorboard(quantized_model,
76
79
  similarity_metrics,
77
- repr_dataset)
80
+ repr_dataset,
81
+ quantized_model_metadata)
82
+
83
+ # Adds text information (like max cut and output similarity metrics) to the tensorboard writer.
84
+ fw_report_utils.tb_utils.add_text_information(similarity_metrics,
85
+ quantized_model_metadata)
78
86
 
79
87
  # Save data to a json file.
80
88
  fw_report_utils.dump_report_to_json(report_dir=xquant_config.report_dir,
@@ -18,7 +18,7 @@ import os
18
18
 
19
19
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
20
20
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
21
- from typing import Any, Dict
21
+ from typing import Any, Dict, Callable
22
22
 
23
23
  from model_compression_toolkit.xquant.common.constants import REPORT_FILENAME
24
24
  from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
@@ -39,7 +39,8 @@ class FrameworkReportUtils:
39
39
  similarity_calculator: SimilarityCalculator,
40
40
  dataset_utils: DatasetUtils,
41
41
  model_folding_utils: ModelFoldingUtils,
42
- tb_utils: TensorboardUtils):
42
+ tb_utils: TensorboardUtils,
43
+ get_metadata_fn: Callable):
43
44
  """
44
45
  Initializes the FrameworkReportUtils class with various utility components required for generating the report.
45
46
 
@@ -50,6 +51,7 @@ class FrameworkReportUtils:
50
51
  dataset_utils (DatasetUtils): Utilities for handling datasets.
51
52
  model_folding_utils (ModelFoldingUtils): Utilities for model folding operations.
52
53
  tb_utils (TensorboardUtils): Utilities for TensorBoard operations.
54
+ get_metadata_fn (Callable): Function to retrieve the metadata from the quantized model.
53
55
  """
54
56
  self.fw_info = fw_info
55
57
  self.fw_impl = fw_impl
@@ -57,18 +59,7 @@ class FrameworkReportUtils:
57
59
  self.dataset_utils = dataset_utils
58
60
  self.model_folding_utils = model_folding_utils
59
61
  self.tb_utils = tb_utils
60
-
61
- def create_report_directory(self, dir_path: str):
62
- """
63
- Create a directory for saving reports.
64
-
65
- Args:
66
- dir_path (str): The path to the directory to create.
67
-
68
- """
69
- if not os.path.exists(dir_path):
70
- os.makedirs(dir_path, exist_ok=True)
71
- Logger.info(f"Directory created at: {dir_path}")
62
+ self.get_metadata_fn = get_metadata_fn
72
63
 
73
64
  def dump_report_to_json(self,
74
65
  report_dir: str,