mct-nightly 2.2.0.20250113.527__py3-none-any.whl → 2.2.0.20250114.84821__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/RECORD +103 -105
  3. model_compression_toolkit/__init__.py +2 -2
  4. model_compression_toolkit/core/common/framework_info.py +1 -3
  5. model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
  6. model_compression_toolkit/core/common/graph/base_graph.py +20 -21
  7. model_compression_toolkit/core/common/graph/base_node.py +44 -17
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
  9. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
  10. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +26 -135
  11. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
  12. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +667 -0
  13. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
  14. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +164 -470
  15. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +30 -7
  16. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
  17. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  18. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
  19. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
  20. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
  21. model_compression_toolkit/core/common/pruning/pruner.py +5 -3
  22. model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
  23. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
  24. model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
  25. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
  26. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  27. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
  28. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
  29. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
  30. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
  31. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
  32. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
  33. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
  34. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
  35. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
  36. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  37. model_compression_toolkit/core/graph_prep_runner.py +12 -11
  38. model_compression_toolkit/core/keras/data_util.py +24 -5
  39. model_compression_toolkit/core/keras/default_framework_info.py +1 -1
  40. model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
  41. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
  42. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  43. model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
  44. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
  45. model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
  46. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
  47. model_compression_toolkit/core/runner.py +33 -60
  48. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
  49. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
  50. model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
  51. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  52. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  53. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  54. model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
  55. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  56. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  57. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  58. model_compression_toolkit/metadata.py +11 -10
  59. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
  60. model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
  61. model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
  62. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
  63. model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
  64. model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
  65. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  66. model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
  67. model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
  68. model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
  69. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  70. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
  71. model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
  72. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  73. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
  74. model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
  75. model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
  76. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
  77. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
  78. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
  79. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
  80. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
  81. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
  82. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
  83. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
  84. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
  85. model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
  86. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
  87. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
  88. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
  89. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
  90. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
  91. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
  92. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
  93. model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
  94. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
  95. model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
  96. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
  97. model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
  98. model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
  99. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
  100. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
  101. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
  102. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
  103. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/LICENSE.md +0 -0
  104. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/WHEEL +0 -0
  105. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/top_level.txt +0 -0
  106. /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
@@ -18,33 +18,34 @@ from typing import Dict, Any
 from model_compression_toolkit.constants import OPERATORS_SCHEDULING, FUSED_NODES_MAPPING, CUTS, MAX_CUT, OP_ORDER, \
     OP_RECORD, SHAPE, NODE_OUTPUT_INDEX, NODE_NAME, TOTAL_SIZE, MEM_ELEMENTS
 from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import SchedulerInfo
-from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
+    FrameworkQuantizationCapabilities
 
 
-def create_model_metadata(tpc: TargetPlatformCapabilities,
+def create_model_metadata(fqc: FrameworkQuantizationCapabilities,
                           scheduling_info: SchedulerInfo = None) -> Dict:
     """
     Creates and returns a metadata dictionary for the model, including version information
     and optional scheduling information.
 
     Args:
-        tpc: A TPC object to get the version.
+        fqc: A FQC object to get the version.
         scheduling_info: An object containing scheduling details and metadata. Default is None.
 
     Returns:
         Dict: A dictionary containing the model's version information and optional scheduling information.
     """
-    _metadata = get_versions_dict(tpc)
+    _metadata = get_versions_dict(fqc)
     if scheduling_info:
         scheduler_metadata = get_scheduler_metadata(scheduler_info=scheduling_info)
         _metadata['scheduling_info'] = scheduler_metadata
     return _metadata
 
 
-def get_versions_dict(tpc) -> Dict:
+def get_versions_dict(fqc) -> Dict:
     """
 
-    Returns: A dictionary with TPC, MCT and TPC-Schema versions.
+    Returns: A dictionary with FQC, MCT and FQC-Schema versions.
 
     """
     # imported inside to avoid circular import error
@@ -53,10 +54,10 @@ def get_versions_dict(tpc) -> Dict:
     @dataclass
     class TPCVersions:
         mct_version: str
-        tpc_minor_version: str = f'{tpc.tp_model.tpc_minor_version}'
-        tpc_patch_version: str = f'{tpc.tp_model.tpc_patch_version}'
-        tpc_platform_type: str = f'{tpc.tp_model.tpc_platform_type}'
-        tpc_schema: str = f'{tpc.tp_model.SCHEMA_VERSION}'
+        tpc_minor_version: str = f'{fqc.tpc.tpc_minor_version}'
+        tpc_patch_version: str = f'{fqc.tpc.tpc_patch_version}'
+        tpc_platform_type: str = f'{fqc.tpc.tpc_platform_type}'
+        tpc_schema: str = f'{fqc.tpc.SCHEMA_VERSION}'
 
     return asdict(TPCVersions(mct_version))
 
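Note: the two hunks above belong to model_compression_toolkit/metadata.py (see the file list); they only rename the parameter from tpc to fqc and read the version fields through fqc.tpc. A minimal usage sketch under that new naming; `fqc` is assumed to be a FrameworkQuantizationCapabilities instance produced by one of the Attach* helpers that appear later in this diff:

    # Illustrative sketch only; `fqc` is assumed to come from e.g. AttachTpcToKeras().attach(tpc).
    from model_compression_toolkit.metadata import create_model_metadata

    metadata = create_model_metadata(fqc=fqc, scheduling_info=None)
    # Per get_versions_dict above, the dict carries mct_version, tpc_minor_version,
    # tpc_patch_version, tpc_platform_type and tpc_schema.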
@@ -17,7 +17,7 @@ from typing import Callable, Tuple
 
 from model_compression_toolkit import get_target_platform_capabilities
 from model_compression_toolkit.constants import TENSORFLOW
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
 from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.pruning.pruner import Pruner
@@ -26,17 +26,16 @@ from model_compression_toolkit.core.common.pruning.pruning_info import PruningIn
 from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
 from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
 
 if FOUND_TF:
+    from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
+        AttachTpcToKeras
     from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
     from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from tensorflow.keras.models import Model
-    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
-        AttachTpcToKeras
 
     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
@@ -44,7 +43,7 @@ if FOUND_TF:
                                    target_resource_utilization: ResourceUtilization,
                                    representative_data_gen: Callable,
                                    pruning_config: PruningConfig = PruningConfig(),
-                                   target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
+                                   target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
         """
         Perform structured pruning on a Keras model to meet a specified target resource utilization.
         This function prunes the provided model according to the target resource utilization by grouping and pruning
@@ -62,7 +61,7 @@ if FOUND_TF:
             target_resource_utilization (ResourceUtilization): The target Key Performance Indicators to be achieved through pruning.
             representative_data_gen (Callable): A function to generate representative data for pruning analysis.
             pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
-            target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
+            target_platform_capabilities (FrameworkQuantizationCapabilities): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
 
         Returns:
             Tuple[Model, PruningInfo]: A tuple containing the pruned Keras model and associated pruning information.
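Note: the Keras pruning facade hunks above change only type annotations; the public argument still defaults to DEFAULT_KERAS_TPC. A minimal call sketch; the entry point mct.pruning.keras_pruning_experimental and mct.core.ResourceUtilization follow the usual MCT API, while the model and the random calibration generator are placeholders:

    import numpy as np
    import model_compression_toolkit as mct
    from tensorflow.keras.applications import MobileNetV2  # placeholder model, not from this diff

    def representative_data_gen():
        # Hypothetical calibration data; replace with real samples.
        yield [np.random.rand(1, 224, 224, 3).astype('float32')]

    model = MobileNetV2()
    # Target roughly 50% of the dense weight memory (float32 = 4 bytes per parameter).
    target_ru = mct.core.ResourceUtilization(weights_memory=model.count_params() * 4 * 0.5)

    pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(
        model=model,
        target_resource_utilization=target_ru,
        representative_data_gen=representative_data_gen)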
@@ -16,7 +16,7 @@
 from typing import Callable, Tuple
 from model_compression_toolkit import get_target_platform_capabilities
 from model_compression_toolkit.constants import PYTORCH
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
 from model_compression_toolkit.verify_packages import FOUND_TORCH
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.pruning.pruner import Pruner
@@ -25,7 +25,6 @@ from model_compression_toolkit.core.common.pruning.pruning_info import PruningIn
 from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
 from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
 
@@ -38,7 +37,7 @@ if FOUND_TORCH:
         PruningPytorchImplementation
     from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from torch.nn import Module
-    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+    from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
         AttachTpcToPytorch
 
     # Set the default Target Platform Capabilities (TPC) for PyTorch.
@@ -48,7 +47,7 @@ if FOUND_TORCH:
                                      target_resource_utilization: ResourceUtilization,
                                      representative_data_gen: Callable,
                                      pruning_config: PruningConfig = PruningConfig(),
-                                     target_platform_capabilities: TargetPlatformModel = DEFAULT_PYOTRCH_TPC) -> \
+                                     target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
             Tuple[Module, PruningInfo]:
         """
         Perform structured pruning on a Pytorch model to meet a specified target resource utilization.
@@ -121,12 +120,12 @@ if FOUND_TORCH:
 
         # Attach TPC to framework
         attach2pytorch = AttachTpcToPytorch()
-        target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)
+        framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)
 
         # Convert the original Pytorch model to an internal graph representation.
         float_graph = read_model_to_graph(model,
                                           representative_data_gen,
-                                          target_platform_capabilities,
+                                          framework_platform_capabilities,
                                           DEFAULT_PYTORCH_INFO,
                                           fw_impl)
 
@@ -143,7 +142,7 @@ if FOUND_TORCH:
                         target_resource_utilization,
                         representative_data_gen,
                         pruning_config,
-                        target_platform_capabilities)
+                        framework_platform_capabilities)
 
         # Apply the pruning process.
         pruned_graph = pruner.prune_graph()
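Note: in the PyTorch pruning facade above, the user-supplied schema object keeps the name target_platform_capabilities, and the framework-attached result now lives in framework_platform_capabilities. A rough sketch of that internal flow, using only names shown in the hunks (the caller-provided object is assumed):

    from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
        AttachTpcToPytorch

    # `target_platform_capabilities` is the schema-level TargetPlatformCapabilities passed by the caller.
    attach2pytorch = AttachTpcToPytorch()
    framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)
    # Downstream steps (read_model_to_graph, Pruner, ...) receive framework_platform_capabilities
    # instead of the raw schema object.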
@@ -22,17 +22,18 @@ from model_compression_toolkit.core.common.quantization.quantize_graph_weights i
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
 from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfig
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 from model_compression_toolkit.metadata import create_model_metadata
 
 if FOUND_TF:
+    from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
+        AttachTpcToKeras
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
     from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
@@ -42,8 +43,6 @@ if FOUND_TF:
 
     from model_compression_toolkit import get_target_platform_capabilities
     from mct_quantizers.keras.metadata import add_metadata
-    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
-        AttachTpcToKeras
 
     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
@@ -52,7 +51,7 @@ if FOUND_TF:
                                           representative_data_gen: Callable,
                                           target_resource_utilization: ResourceUtilization = None,
                                           core_config: CoreConfig = CoreConfig(),
-                                          target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC):
+                                          target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
         """
         Quantize a trained Keras model using post-training quantization. The model is quantized using a
         symmetric constraint quantization thresholds (power of two).
@@ -139,7 +138,7 @@ if FOUND_TF:
         fw_impl = KerasImplementation()
 
         attach2keras = AttachTpcToKeras()
-        target_platform_capabilities = attach2keras.attach(
+        framework_platform_capabilities = attach2keras.attach(
             target_platform_capabilities,
             custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
 
@@ -149,7 +148,7 @@ if FOUND_TF:
                                                core_config=core_config,
                                                fw_info=fw_info,
                                                fw_impl=fw_impl,
-                                               tpc=target_platform_capabilities,
+                                               fqc=framework_platform_capabilities,
                                                target_resource_utilization=target_resource_utilization,
                                                tb_w=tb_w)
 
@@ -177,9 +176,9 @@ if FOUND_TF:
                                                              fw_info)
 
         exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction)
-        if target_platform_capabilities.tp_model.add_metadata:
+        if framework_platform_capabilities.tpc.add_metadata:
             exportable_model = add_metadata(exportable_model,
                                             create_model_metadata(fqc=framework_platform_capabilities,
                                                                   scheduling_info=scheduling_info))
         return exportable_model, user_info
 
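Note: the Keras PTQ facade change is internal; the attached object is renamed and handed to core_runner as fqc, while the public signature only swaps the annotation. A minimal call sketch of the unchanged public API, assuming mct.ptq.keras_post_training_quantization and a placeholder calibration generator:

    import numpy as np
    import model_compression_toolkit as mct
    from tensorflow.keras.applications import MobileNetV2  # placeholder model, not from this diff

    def representative_data_gen():
        for _ in range(10):
            yield [np.random.rand(1, 224, 224, 3).astype('float32')]

    # target_platform_capabilities defaults to DEFAULT_KERAS_TPC and still accepts the
    # schema-level TargetPlatformCapabilities object.
    quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
        MobileNetV2(),
        representative_data_gen)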
@@ -19,9 +19,8 @@ from typing import Callable
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import PYTORCH
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
 from model_compression_toolkit.verify_packages import FOUND_TORCH
-from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -40,7 +39,7 @@ if FOUND_TORCH:
     from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
     from model_compression_toolkit import get_target_platform_capabilities
     from mct_quantizers.pytorch.metadata import add_metadata
-    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+    from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
         AttachTpcToPytorch
 
     DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
@@ -49,11 +48,11 @@ if FOUND_TORCH:
                                             representative_data_gen: Callable,
                                             target_resource_utilization: ResourceUtilization = None,
                                             core_config: CoreConfig = CoreConfig(),
-                                            target_platform_capabilities: TargetPlatformModel = DEFAULT_PYTORCH_TPC):
+                                            target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
         """
         Quantize a trained Pytorch module using post-training quantization.
         By default, the module is quantized using a symmetric constraint quantization thresholds
-        (power of two) as defined in the default TargetPlatformCapabilities.
+        (power of two) as defined in the default FrameworkQuantizationCapabilities.
         The module is first optimized using several transformations (e.g. BatchNormalization folding to
         preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
         being collected for each layer's output (and input, depends on the quantization configuration).
@@ -112,7 +111,7 @@ if FOUND_TORCH:
 
         # Attach tpc model to framework
         attach2pytorch = AttachTpcToPytorch()
-        target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
+        framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
                                                               core_config.quantization_config.custom_tpc_opset_to_layer)
 
         # Ignore hessian info service as it is not used here yet.
@@ -121,7 +120,7 @@ if FOUND_TORCH:
                                                core_config=core_config,
                                                fw_info=fw_info,
                                                fw_impl=fw_impl,
-                                               tpc=target_platform_capabilities,
+                                               fqc=framework_platform_capabilities,
                                                target_resource_utilization=target_resource_utilization,
                                                tb_w=tb_w)
 
@@ -149,9 +148,9 @@ if FOUND_TORCH:
                                                              fw_info)
 
         exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction)
-        if target_platform_capabilities.tp_model.add_metadata:
+        if framework_platform_capabilities.tpc.add_metadata:
             exportable_model = add_metadata(exportable_model,
                                             create_model_metadata(fqc=framework_platform_capabilities,
                                                                   scheduling_info=scheduling_info))
         return exportable_model, user_info
 
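Note: the PyTorch PTQ facade mirrors the Keras change. For completeness, a minimal sketch of the unchanged public call, assuming mct.ptq.pytorch_post_training_quantization and a placeholder torchvision model:

    import torch
    import model_compression_toolkit as mct
    from torchvision.models import mobilenet_v2  # placeholder model, not from this diff

    def representative_data_gen():
        for _ in range(10):
            yield [torch.randn(1, 3, 224, 224)]

    quantized_model, quantization_info = mct.ptq.pytorch_post_training_quantization(
        mobilenet_v2(),
        representative_data_gen)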
@@ -19,13 +19,14 @@ from functools import partial
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
+    AttachTpcToKeras
 from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfig
 from mct_quantizers import KerasActivationQuantizationHolder
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 
@@ -55,8 +56,6 @@ if FOUND_TF:
     from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
         get_activation_quantizer_holder
     from model_compression_toolkit.qat.common.qat_config import QATConfig
-    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
-        AttachTpcToKeras
 
     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
@@ -93,7 +92,7 @@ if FOUND_TF:
                                                      target_resource_utilization: ResourceUtilization = None,
                                                      core_config: CoreConfig = CoreConfig(),
                                                      qat_config: QATConfig = QATConfig(),
-                                                     target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC):
+                                                     target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
         """
         Prepare a trained Keras model for quantization aware training. First the model quantization is optimized
         with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -200,7 +199,7 @@ if FOUND_TF:
                                                core_config=core_config,
                                                fw_info=DEFAULT_KERAS_INFO,
                                                fw_impl=fw_impl,
-                                               tpc=target_platform_capabilities,
+                                               fqc=target_platform_capabilities,
                                                target_resource_utilization=target_resource_utilization,
                                                tb_w=tb_w)
 
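Note: in the Keras QAT facade the attached object is now passed to core_runner as fqc while keeping the target_platform_capabilities variable name. A call sketch for orientation; the entry-point name mct.qat.keras_quantization_aware_training_init_experimental and its three return values are assumptions based on the MCT QAT API and are not visible in these hunks:

    import numpy as np
    import model_compression_toolkit as mct
    from tensorflow.keras.applications import MobileNetV2  # placeholder model, not from this diff

    def representative_data_gen():
        for _ in range(10):
            yield [np.random.rand(1, 224, 224, 3).astype('float32')]

    # Assumed API: returns a QAT-ready (wrapped) model, quantization info and custom objects.
    qat_model, quantization_info, custom_objects = \
        mct.qat.keras_quantization_aware_training_init_experimental(
            MobileNetV2(),
            representative_data_gen)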
@@ -21,7 +21,7 @@ from tensorflow.python.framework.tensor_shape import TensorShape
 
 from model_compression_toolkit.trainable_infrastructure import TrainingMethod
 
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
 from mct_quantizers import QuantizationTarget, mark_quantizer
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
@@ -22,7 +22,7 @@ from model_compression_toolkit.trainable_infrastructure.common.constants import
 
 from model_compression_toolkit.trainable_infrastructure import TrainingMethod
 
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
 from mct_quantizers import QuantizationTarget, mark_quantizer
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import copy
 from typing import Callable
 from functools import partial
 
 from model_compression_toolkit.constants import PYTORCH
-from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
     AttachTpcToPytorch
 from model_compression_toolkit.verify_packages import FOUND_TORCH
 
@@ -26,12 +25,9 @@ from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfig
-from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
-    TargetPlatformCapabilities
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 
@@ -82,7 +78,7 @@ if FOUND_TORCH:
                                                        target_resource_utilization: ResourceUtilization = None,
                                                        core_config: CoreConfig = CoreConfig(),
                                                        qat_config: QATConfig = QATConfig(),
-                                                       target_platform_capabilities: TargetPlatformModel = DEFAULT_PYTORCH_TPC):
+                                                       target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
         """
         Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized
         with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -159,7 +155,7 @@ if FOUND_TORCH:
 
         # Attach tpc model to framework
         attach2pytorch = AttachTpcToPytorch()
-        target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
+        framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
                                                               core_config.quantization_config.custom_tpc_opset_to_layer)
 
         # Ignore hessian scores service as we do not use it here
@@ -168,7 +164,7 @@ if FOUND_TORCH:
                                                core_config=core_config,
                                                fw_info=DEFAULT_PYTORCH_INFO,
                                                fw_impl=fw_impl,
-                                               tpc=target_platform_capabilities,
+                                               fqc=framework_platform_capabilities,
                                                target_resource_utilization=target_resource_utilization,
                                                tb_w=tb_w)
 
@@ -18,7 +18,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationMethod
 from mct_quantizers import PytorchQuantizationWrapper
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
 from model_compression_toolkit import constants as C
@@ -28,7 +28,7 @@ from model_compression_toolkit.trainable_infrastructure.pytorch.quantizer_utils
 from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
     TrainableQuantizerWeightsConfig
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
 from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer
@@ -18,7 +18,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationMethod
 from mct_quantizers import PytorchQuantizationWrapper
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
 from model_compression_toolkit import constants as C
@@ -20,7 +20,7 @@ from torch import Tensor
 from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
 from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
 
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from mct_quantizers import QuantizationMethod
 from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
 
 from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer
@@ -12,3 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import AttributeFilter
+from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import (
+    FrameworkQuantizationCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater,
+    LayerFilterParams, OperationsToLayers, get_current_tpc)
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities, OperatorsSet, \
+    OperatorSetGroup, Signedness, AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions, Fusing
+
+from mct_quantizers import QuantizationMethod
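Note: the new target_platform_capabilities/__init__.py above turns the package into a single public import surface for the renamed schema classes and the framework-attachment helpers. A short sketch of importing through it, relying only on the names re-exported in the added lines:

    from model_compression_toolkit.target_platform_capabilities import (
        TargetPlatformCapabilities,   # schema-level platform description (formerly TargetPlatformModel)
        OperatorsSet,
        OperatorSetGroup,             # formerly OperatorSetConcat
        QuantizationConfigOptions,
        Fusing,
        QuantizationMethod)           # re-exported from mct_quantizers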
@@ -21,7 +21,7 @@ LATEST = 'latest'
 
 
 # Supported TP models names:
-DEFAULT_TP_MODEL = 'default'
+DEFAULT_TP_MODEL= 'default'
 IMX500_TP_MODEL = 'imx500'
 TFLITE_TP_MODEL = 'tflite'
 QNNPACK_TP_MODEL = 'qnnpack'
@@ -7,6 +7,6 @@ OpQuantizationConfig = schema.OpQuantizationConfig
 QuantizationConfigOptions = schema.QuantizationConfigOptions
 OperatorsSetBase = schema.OperatorsSetBase
 OperatorsSet = schema.OperatorsSet
-OperatorSetConcat = schema.OperatorSetConcat
+OperatorSetGroup = schema.OperatorSetGroup
 Fusing = schema.Fusing
-TargetPlatformModel = schema.TargetPlatformModel
+TargetPlatformCapabilities = schema.TargetPlatformCapabilities
@@ -16,7 +16,7 @@ from logging import Logger
 from typing import Optional
 
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
-    TargetPlatformModel, QuantizationConfigOptions, OperatorsSetBase
+    TargetPlatformCapabilities, QuantizationConfigOptions, OperatorsSetBase
 
 
 def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) -> int:
@@ -32,31 +32,31 @@ def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) ->
     return max(op_quantization_config.supported_input_activation_n_bits)
 
 
-def get_config_options_by_operators_set(tp_model: TargetPlatformModel,
+def get_config_options_by_operators_set(tpc: TargetPlatformCapabilities,
                                         operators_set_name: str) -> QuantizationConfigOptions:
     """
     Get the QuantizationConfigOptions of an OperatorsSet by its name.
 
     Args:
-        tp_model (TargetPlatformModel): The target platform model containing the operator sets and their configurations.
+        tpc (TargetPlatformCapabilities): The target platform model containing the operator sets and their configurations.
         operators_set_name (str): The name of the OperatorsSet whose quantization configuration options are to be retrieved.
 
     Returns:
         QuantizationConfigOptions: The quantization configuration options associated with the specified OperatorsSet,
        or the default quantization configuration options if the OperatorsSet is not found.
     """
-    for op_set in tp_model.operator_set:
+    for op_set in tpc.operator_set:
         if operators_set_name == op_set.name:
             return op_set.qc_options
-    return tp_model.default_qco
+    return tpc.default_qco
 
 
-def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuantizationConfig:
+def get_default_op_quantization_config(tpc: TargetPlatformCapabilities) -> OpQuantizationConfig:
     """
-    Get the default OpQuantizationConfig of the TargetPlatformModel.
+    Get the default OpQuantizationConfig of the TargetPlatformCapabilities.
 
     Args:
-        tp_model (TargetPlatformModel): The target platform model containing the default quantization configuration.
+        tpc (TargetPlatformCapabilities): The target platform model containing the default quantization configuration.
 
     Returns:
         OpQuantizationConfig: The default quantization configuration.
@@ -64,32 +64,32 @@ def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuant
 
    Raises:
        AssertionError: If the default quantization configuration list contains more than one configuration option.
    """
-    assert len(tp_model.default_qco.quantization_configurations) == 1, \
+    assert len(tpc.default_qco.quantization_configurations) == 1, \
        f"Default quantization configuration options must contain only one option, " \
-       f"but found {len(tp_model.default_qco.quantization_configurations)} configurations." # pragma: no cover
-    return tp_model.default_qco.quantization_configurations[0]
+       f"but found {len(tpc.default_qco.quantization_configurations)} configurations." # pragma: no cover
+    return tpc.default_qco.quantization_configurations[0]
 
 
-def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool:
+def is_opset_in_model(tpc: TargetPlatformCapabilities, opset_name: str) -> bool:
    """
    Check whether an OperatorsSet is defined in the model.
 
    Args:
-       tp_model (TargetPlatformModel): The target platform model containing the list of operator sets.
+       tpc (TargetPlatformCapabilities): The target platform model containing the list of operator sets.
        opset_name (str): The name of the OperatorsSet to check for existence.
 
    Returns:
        bool: True if an OperatorsSet with the given name exists in the target platform model,
        otherwise False.
    """
-    return tp_model.operator_set is not None and opset_name in [x.name for x in tp_model.operator_set]
+    return tpc.operator_set is not None and opset_name in [x.name for x in tpc.operator_set]
 
-def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optional[OperatorsSetBase]:
+def get_opset_by_name(tpc: TargetPlatformCapabilities, opset_name: str) -> Optional[OperatorsSetBase]:
    """
    Get an OperatorsSet object from the model by its name.
 
    Args:
-       tp_model (TargetPlatformModel): The target platform model containing the list of operator sets.
+       tpc (TargetPlatformCapabilities): The target platform model containing the list of operator sets.
        opset_name (str): The name of the OperatorsSet to be retrieved.
 
    Returns:
@@ -99,7 +99,7 @@ def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optiona
    Raises:
        A critical log message if multiple operator sets with the same name are found.
    """
-    opset_list = [x for x in tp_model.operator_set if x.name == opset_name]
+    opset_list = [x for x in tpc.operator_set if x.name == opset_name]
    if len(opset_list) > 1:
-        Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.") # pragma: no cover
+        Logger.critical(f"Found more than one OperatorsSet in TargetPlatformCapabilities with the name {opset_name}.") # pragma: no cover
    return opset_list[0] if opset_list else None
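Note: the schema_functions helpers above keep their behavior; only the parameter name and annotation change (tp_model: TargetPlatformModel becomes tpc: TargetPlatformCapabilities). A short usage sketch; `tpc` is assumed to be a schema-level TargetPlatformCapabilities (e.g. the object returned by get_target_platform_capabilities), and 'Conv' is an example opset name:

    from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import (
        get_config_options_by_operators_set, get_default_op_quantization_config, is_opset_in_model)

    default_cfg = get_default_op_quantization_config(tpc)
    if is_opset_in_model(tpc, 'Conv'):
        # Falls back to tpc.default_qco when the opset is not found.
        qc_options = get_config_options_by_operators_set(tpc, 'Conv')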