mct-nightly 2.2.0.20250113.134913__py3-none-any.whl → 2.2.0.20250114.134534__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/RECORD +102 -104
  3. model_compression_toolkit/__init__.py +2 -2
  4. model_compression_toolkit/core/common/framework_info.py +1 -3
  5. model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
  6. model_compression_toolkit/core/common/graph/base_graph.py +20 -21
  7. model_compression_toolkit/core/common/graph/base_node.py +44 -17
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
  9. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +187 -0
  10. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
  11. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +35 -162
  12. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
  13. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +668 -0
  14. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
  15. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +74 -51
  16. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
  17. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  18. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
  19. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
  20. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
  21. model_compression_toolkit/core/common/pruning/pruner.py +5 -3
  22. model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
  23. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
  24. model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
  25. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
  26. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  27. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
  28. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
  29. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
  30. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
  31. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
  32. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
  33. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
  34. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
  35. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
  36. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  37. model_compression_toolkit/core/graph_prep_runner.py +12 -11
  38. model_compression_toolkit/core/keras/default_framework_info.py +1 -1
  39. model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
  40. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
  41. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  42. model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
  43. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
  44. model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
  45. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
  46. model_compression_toolkit/core/runner.py +33 -60
  47. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
  48. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
  49. model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
  50. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  51. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  52. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  53. model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
  54. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  55. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  56. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  57. model_compression_toolkit/metadata.py +11 -10
  58. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
  59. model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
  60. model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
  61. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
  62. model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
  63. model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
  64. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  65. model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
  66. model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
  67. model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
  68. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  69. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
  70. model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
  71. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  72. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
  73. model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
  74. model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
  75. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
  76. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
  77. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
  78. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
  79. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
  80. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
  81. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
  82. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
  83. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
  84. model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
  85. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
  86. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
  87. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
  88. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
  89. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
  90. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
  91. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
  92. model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
  93. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
  94. model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
  95. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
  96. model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
  97. model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
  98. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
  99. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
  100. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
  101. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +0 -528
  102. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
  103. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/LICENSE.md +0 -0
  104. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/WHEEL +0 -0
  105. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/top_level.txt +0 -0
  106. /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
@@ -19,9 +19,8 @@ from typing import Callable
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.constants import PYTORCH
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
  from model_compression_toolkit.verify_packages import FOUND_TORCH
- from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
  from model_compression_toolkit.core import CoreConfig
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -40,7 +39,7 @@ if FOUND_TORCH:
  from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
  from model_compression_toolkit import get_target_platform_capabilities
  from mct_quantizers.pytorch.metadata import add_metadata
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
  AttachTpcToPytorch
 
  DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
@@ -49,11 +48,11 @@ if FOUND_TORCH:
  representative_data_gen: Callable,
  target_resource_utilization: ResourceUtilization = None,
  core_config: CoreConfig = CoreConfig(),
- target_platform_capabilities: TargetPlatformModel = DEFAULT_PYTORCH_TPC):
+ target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
  """
  Quantize a trained Pytorch module using post-training quantization.
  By default, the module is quantized using a symmetric constraint quantization thresholds
- (power of two) as defined in the default TargetPlatformCapabilities.
+ (power of two) as defined in the default FrameworkQuantizationCapabilities.
  The module is first optimized using several transformations (e.g. BatchNormalization folding to
  preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram, etc.) are
  being collected for each layer's output (and input, depends on the quantization configuration).
@@ -112,7 +111,7 @@ if FOUND_TORCH:
 
  # Attach tpc model to framework
  attach2pytorch = AttachTpcToPytorch()
- target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
+ framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
  core_config.quantization_config.custom_tpc_opset_to_layer)
 
  # Ignore hessian info service as it is not used here yet.
@@ -121,7 +120,7 @@
  core_config=core_config,
  fw_info=fw_info,
  fw_impl=fw_impl,
- tpc=target_platform_capabilities,
+ fqc=framework_platform_capabilities,
  target_resource_utilization=target_resource_utilization,
  tb_w=tb_w)
 
@@ -149,9 +148,9 @@ if FOUND_TORCH:
  fw_info)
 
  exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction)
- if target_platform_capabilities.tp_model.add_metadata:
+ if framework_platform_capabilities.tpc.add_metadata:
  exportable_model = add_metadata(exportable_model,
- create_model_metadata(tpc=target_platform_capabilities,
+ create_model_metadata(fqc=framework_platform_capabilities,
  scheduling_info=scheduling_info))
  return exportable_model, user_info
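
The facade's call pattern is unchanged by the annotation rename; only the type of the default TPC object differs, and it is attached to PyTorch operators internally. A minimal usage sketch under that assumption (toy model and random calibration data are placeholders, and the facade name follows MCT's public PTQ API):

```python
import torch
import model_compression_toolkit as mct

# Toy float model and random calibration batches -- placeholders for a real trained model/dataset.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

def representative_data_gen():
    for _ in range(10):
        yield [torch.randn(1, 3, 224, 224)]

# DEFAULT_PYTORCH_TPC equivalent: now a schema-level TargetPlatformCapabilities object,
# converted to framework capabilities internally via AttachTpcToPytorch.
tpc = mct.get_target_platform_capabilities("pytorch", "default")

quantized_model, quantization_info = mct.ptq.pytorch_post_training_quantization(
    model, representative_data_gen, target_platform_capabilities=tpc)
```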
 
@@ -19,13 +19,14 @@ from functools import partial
  from model_compression_toolkit.core import CoreConfig
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
  from model_compression_toolkit.logger import Logger
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
+ AttachTpcToKeras
  from model_compression_toolkit.verify_packages import FOUND_TF
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
  MixedPrecisionQuantizationConfig
  from mct_quantizers import KerasActivationQuantizationHolder
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
  from model_compression_toolkit.core.runner import core_runner
  from model_compression_toolkit.ptq.runner import ptq_runner
 
@@ -55,8 +56,6 @@ if FOUND_TF:
  from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
  get_activation_quantizer_holder
  from model_compression_toolkit.qat.common.qat_config import QATConfig
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
- AttachTpcToKeras
 
  DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
@@ -93,7 +92,7 @@ if FOUND_TF:
  target_resource_utilization: ResourceUtilization = None,
  core_config: CoreConfig = CoreConfig(),
  qat_config: QATConfig = QATConfig(),
- target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC):
+ target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
  """
  Prepare a trained Keras model for quantization aware training. First the model quantization is optimized
  with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -200,7 +199,7 @@ if FOUND_TF:
  core_config=core_config,
  fw_info=DEFAULT_KERAS_INFO,
  fw_impl=fw_impl,
- tpc=target_platform_capabilities,
+ fqc=target_platform_capabilities,
  target_resource_utilization=target_resource_utilization,
  tb_w=tb_w)
 
@@ -21,7 +21,7 @@ from tensorflow.python.framework.tensor_shape import TensorShape
 
  from model_compression_toolkit.trainable_infrastructure import TrainingMethod
 
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+ from mct_quantizers import QuantizationMethod
  from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
  from mct_quantizers import QuantizationTarget, mark_quantizer
  from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
@@ -22,7 +22,7 @@ from model_compression_toolkit.trainable_infrastructure.common.constants import
 
  from model_compression_toolkit.trainable_infrastructure import TrainingMethod
 
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+ from mct_quantizers import QuantizationMethod
  from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
  from mct_quantizers import QuantizationTarget, mark_quantizer
  from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
@@ -12,13 +12,12 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- import copy
  from typing import Callable
  from functools import partial
 
  from model_compression_toolkit.constants import PYTORCH
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
  AttachTpcToPytorch
  from model_compression_toolkit.verify_packages import FOUND_TORCH
 
@@ -26,12 +25,9 @@ from model_compression_toolkit.core import CoreConfig
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
  from model_compression_toolkit.logger import Logger
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
  MixedPrecisionQuantizationConfig
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
- TargetPlatformCapabilities
  from model_compression_toolkit.core.runner import core_runner
  from model_compression_toolkit.ptq.runner import ptq_runner
 
@@ -82,7 +78,7 @@ if FOUND_TORCH:
  target_resource_utilization: ResourceUtilization = None,
  core_config: CoreConfig = CoreConfig(),
  qat_config: QATConfig = QATConfig(),
- target_platform_capabilities: TargetPlatformModel = DEFAULT_PYTORCH_TPC):
+ target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
  """
  Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized
  with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -159,7 +155,7 @@ if FOUND_TORCH:
 
  # Attach tpc model to framework
  attach2pytorch = AttachTpcToPytorch()
- target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
+ framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
  core_config.quantization_config.custom_tpc_opset_to_layer)
 
  # Ignore hessian scores service as we do not use it here
@@ -168,7 +164,7 @@
  core_config=core_config,
  fw_info=DEFAULT_PYTORCH_INFO,
  fw_impl=fw_impl,
- tpc=target_platform_capabilities,
+ fqc=framework_platform_capabilities,
  target_resource_utilization=target_resource_utilization,
  tb_w=tb_w)
 
@@ -18,7 +18,7 @@ import numpy as np
  import torch
  import torch.nn as nn
 
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+ from mct_quantizers import QuantizationMethod
  from mct_quantizers import PytorchQuantizationWrapper
  from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
  from model_compression_toolkit import constants as C
@@ -28,7 +28,7 @@ from model_compression_toolkit.trainable_infrastructure.pytorch.quantizer_utils
  from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
  from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
  TrainableQuantizerWeightsConfig
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+ from mct_quantizers import QuantizationMethod
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
  from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer
@@ -18,7 +18,7 @@ import numpy as np
  import torch
  import torch.nn as nn
 
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+ from mct_quantizers import QuantizationMethod
  from mct_quantizers import PytorchQuantizationWrapper
  from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
  from model_compression_toolkit import constants as C
@@ -20,7 +20,7 @@ from torch import Tensor
  from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
  from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
 
- from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+ from mct_quantizers import QuantizationMethod
  from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper
 
  from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer
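
Every quantizer touched above now sources QuantizationMethod from mct_quantizers rather than the removed target_platform package. A one-line migration sketch for downstream code (POWER_OF_TWO is one of the enum's members):

```python
# Before this release:
# from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod

# After: the enum is taken straight from mct_quantizers, as in the updated quantizers above.
from mct_quantizers import QuantizationMethod

print(QuantizationMethod.POWER_OF_TWO)  # one of the supported quantization schemes
```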
@@ -12,3 +12,12 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import AttributeFilter
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import (
+ FrameworkQuantizationCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater,
+ LayerFilterParams, OperationsToLayers, get_current_tpc)
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities, OperatorsSet, \
+ OperatorSetGroup, Signedness, AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions, Fusing
+
+ from mct_quantizers import QuantizationMethod
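
With the expanded __init__, the schema classes and filter helpers become importable from the target_platform_capabilities package root. A sketch of the new import surface, restricted to names added above:

```python
from model_compression_toolkit.target_platform_capabilities import (
    TargetPlatformCapabilities,       # schema-level hardware description (was TargetPlatformModel)
    OperatorsSet, OperatorSetGroup,   # OperatorSetGroup replaces OperatorSetConcat
    OpQuantizationConfig, QuantizationConfigOptions, AttributeQuantizationConfig,
    Fusing, Signedness,
    LayerFilterParams, Eq, NotEq, GreaterEq, SmallerEq,  # layer filters from targetplatform2framework
    QuantizationMethod,               # re-exported from mct_quantizers
)
```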
@@ -21,7 +21,7 @@ LATEST = 'latest'
 
 
  # Supported TP models names:
- DEFAULT_TP_MODEL = 'default'
+ DEFAULT_TP_MODEL= 'default'
  IMX500_TP_MODEL = 'imx500'
  TFLITE_TP_MODEL = 'tflite'
  QNNPACK_TP_MODEL = 'qnnpack'
@@ -7,6 +7,6 @@ OpQuantizationConfig = schema.OpQuantizationConfig
  QuantizationConfigOptions = schema.QuantizationConfigOptions
  OperatorsSetBase = schema.OperatorsSetBase
  OperatorsSet = schema.OperatorsSet
- OperatorSetConcat = schema.OperatorSetConcat
+ OperatorSetGroup = schema.OperatorSetGroup
  Fusing = schema.Fusing
- TargetPlatformModel = schema.TargetPlatformModel
+ TargetPlatformCapabilities = schema.TargetPlatformCapabilities
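
For code that used the previous schema aliases, the rename is mechanical. A hedged migration sketch (the commented names no longer exist in mct_current_schema):

```python
# Before (2.2.0.20250113 and earlier):
# from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import (
#     TargetPlatformModel, OperatorSetConcat)

# After this release:
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import (
    TargetPlatformCapabilities,  # replaces TargetPlatformModel
    OperatorSetGroup,            # replaces OperatorSetConcat
    OperatorsSet, Fusing, QuantizationConfigOptions, OpQuantizationConfig)
```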
@@ -16,7 +16,7 @@ from logging import Logger
  from typing import Optional
 
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
- TargetPlatformModel, QuantizationConfigOptions, OperatorsSetBase
+ TargetPlatformCapabilities, QuantizationConfigOptions, OperatorsSetBase
 
 
  def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) -> int:
@@ -32,31 +32,31 @@ def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) ->
  return max(op_quantization_config.supported_input_activation_n_bits)
 
 
- def get_config_options_by_operators_set(tp_model: TargetPlatformModel,
+ def get_config_options_by_operators_set(tpc: TargetPlatformCapabilities,
  operators_set_name: str) -> QuantizationConfigOptions:
  """
  Get the QuantizationConfigOptions of an OperatorsSet by its name.
 
  Args:
- tp_model (TargetPlatformModel): The target platform model containing the operator sets and their configurations.
+ tpc (TargetPlatformCapabilities): The target platform model containing the operator sets and their configurations.
  operators_set_name (str): The name of the OperatorsSet whose quantization configuration options are to be retrieved.
 
  Returns:
  QuantizationConfigOptions: The quantization configuration options associated with the specified OperatorsSet,
  or the default quantization configuration options if the OperatorsSet is not found.
  """
- for op_set in tp_model.operator_set:
+ for op_set in tpc.operator_set:
  if operators_set_name == op_set.name:
  return op_set.qc_options
- return tp_model.default_qco
+ return tpc.default_qco
 
 
- def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuantizationConfig:
+ def get_default_op_quantization_config(tpc: TargetPlatformCapabilities) -> OpQuantizationConfig:
  """
- Get the default OpQuantizationConfig of the TargetPlatformModel.
+ Get the default OpQuantizationConfig of the TargetPlatformCapabilities.
 
  Args:
- tp_model (TargetPlatformModel): The target platform model containing the default quantization configuration.
+ tpc (TargetPlatformCapabilities): The target platform model containing the default quantization configuration.
 
  Returns:
  OpQuantizationConfig: The default quantization configuration.
@@ -64,32 +64,32 @@ def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuant
  Raises:
  AssertionError: If the default quantization configuration list contains more than one configuration option.
  """
- assert len(tp_model.default_qco.quantization_configurations) == 1, \
+ assert len(tpc.default_qco.quantization_configurations) == 1, \
  f"Default quantization configuration options must contain only one option, " \
- f"but found {len(tp_model.default_qco.quantization_configurations)} configurations." # pragma: no cover
- return tp_model.default_qco.quantization_configurations[0]
+ f"but found {len(tpc.default_qco.quantization_configurations)} configurations." # pragma: no cover
+ return tpc.default_qco.quantization_configurations[0]
 
 
- def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool:
+ def is_opset_in_model(tpc: TargetPlatformCapabilities, opset_name: str) -> bool:
  """
  Check whether an OperatorsSet is defined in the model.
 
  Args:
- tp_model (TargetPlatformModel): The target platform model containing the list of operator sets.
+ tpc (TargetPlatformCapabilities): The target platform model containing the list of operator sets.
  opset_name (str): The name of the OperatorsSet to check for existence.
 
  Returns:
  bool: True if an OperatorsSet with the given name exists in the target platform model,
  otherwise False.
  """
- return tp_model.operator_set is not None and opset_name in [x.name for x in tp_model.operator_set]
+ return tpc.operator_set is not None and opset_name in [x.name for x in tpc.operator_set]
 
- def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optional[OperatorsSetBase]:
+ def get_opset_by_name(tpc: TargetPlatformCapabilities, opset_name: str) -> Optional[OperatorsSetBase]:
  """
  Get an OperatorsSet object from the model by its name.
 
  Args:
- tp_model (TargetPlatformModel): The target platform model containing the list of operator sets.
+ tpc (TargetPlatformCapabilities): The target platform model containing the list of operator sets.
  opset_name (str): The name of the OperatorsSet to be retrieved.
 
  Returns:
@@ -99,7 +99,7 @@ def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optiona
  Raises:
  A critical log message if multiple operator sets with the same name are found.
  """
- opset_list = [x for x in tp_model.operator_set if x.name == opset_name]
+ opset_list = [x for x in tpc.operator_set if x.name == opset_name]
  if len(opset_list) > 1:
- Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.") # pragma: no cover
+ Logger.critical(f"Found more than one OperatorsSet in TargetPlatformCapabilities with the name {opset_name}.") # pragma: no cover
  return opset_list[0] if opset_list else None
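
The helpers keep their behavior; only the parameter name and annotation move from tp_model to tpc. A usage sketch, assuming the object comes from mct.get_target_platform_capabilities and that "Conv" is a defined operator-set name in the default TPC (illustrative assumption):

```python
import model_compression_toolkit as mct
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import (
    get_config_options_by_operators_set, get_default_op_quantization_config, is_opset_in_model)

tpc = mct.get_target_platform_capabilities("pytorch", "default")

default_cfg = get_default_op_quantization_config(tpc)            # the single default OpQuantizationConfig
conv_options = get_config_options_by_operators_set(tpc, "Conv")  # falls back to tpc.default_qco if absent
print(is_opset_in_model(tpc, "Conv"), default_cfg.activation_n_bits)
```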
@@ -414,7 +414,7 @@ class QuantizationConfigOptions(BaseModel):
 
  class TargetPlatformModelComponent(BaseModel):
  """
- Component of TargetPlatformModel (Fusing, OperatorsSet, etc.).
+ Component of TargetPlatformCapabilities (Fusing, OperatorsSet, etc.).
  """
  class Config:
  frozen = True
@@ -433,7 +433,7 @@ class OperatorsSet(OperatorsSetBase):
  Set of operators that are represented by a unique label.
 
  Attributes:
- name (Union[str, OperatorSetNames]): The set's label (must be unique within a TargetPlatformModel).
+ name (Union[str, OperatorSetNames]): The set's label (must be unique within a TargetPlatformCapabilities).
  qc_options (Optional[QuantizationConfigOptions]): Configuration options to use for this set of operations.
  If None, it represents a fusing set.
  type (Literal["OperatorsSet"]): Fixed type identifier.
@@ -457,7 +457,7 @@ class OperatorsSet(OperatorsSetBase):
  return {"name": self.name}
 
 
- class OperatorSetConcat(OperatorsSetBase):
+ class OperatorSetGroup(OperatorsSetBase):
  """
  Concatenate a tuple of operator sets to treat them similarly in different places (like fusing).
 
@@ -469,7 +469,7 @@ class OperatorSetConcat(OperatorsSetBase):
  name: Optional[str] = None # Will be set in the validator if not given
 
  # Define a private attribute _type
- type: Literal["OperatorSetConcat"] = "OperatorSetConcat"
+ type: Literal["OperatorSetGroup"] = "OperatorSetGroup"
 
  class Config:
  frozen = True
@@ -518,11 +518,11 @@ class Fusing(TargetPlatformModelComponent):
  hence no quantization is applied between them.
 
  Attributes:
- operator_groups (Tuple[Union[OperatorsSet, OperatorSetConcat], ...]): A tuple of operator groups,
- each being either an OperatorSetConcat or an OperatorsSet.
+ operator_groups (Tuple[Union[OperatorsSet, OperatorSetGroup], ...]): A tuple of operator groups,
+ each being either an OperatorSetGroup or an OperatorsSet.
  name (Optional[str]): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
  """
- operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetConcat], Field(discriminator='type')], ...]
+ operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetGroup], Field(discriminator='type')], ...]
  name: Optional[str] = None # Will be set in the validator if not given.
 
  class Config:
@@ -591,7 +591,7 @@ class Fusing(TargetPlatformModelComponent):
  for i in range(len(self.operator_groups) - len(other.operator_groups) + 1):
  for j in range(len(other.operator_groups)):
  if self.operator_groups[i + j] != other.operator_groups[j] and not (
- isinstance(self.operator_groups[i + j], OperatorSetConcat) and (
+ isinstance(self.operator_groups[i + j], OperatorSetGroup) and (
  other.operator_groups[j] in self.operator_groups[i + j].operators_set)):
  break
  else:
@@ -621,7 +621,7 @@ class Fusing(TargetPlatformModelComponent):
  for x in self.operator_groups
  ])
 
- class TargetPlatformModel(BaseModel):
+ class TargetPlatformCapabilities(BaseModel):
  """
  Represents the hardware configuration used for quantized model inference.
 
@@ -644,7 +644,7 @@ class TargetPlatformModel(BaseModel):
  tpc_patch_version: Optional[int]
  tpc_platform_type: Optional[str]
  add_metadata: bool = True
- name: Optional[str] = "default_tp_model"
+ name: Optional[str] = "default_tpc"
  is_simd_padding: bool = False
 
  SCHEMA_VERSION: int = 1
@@ -682,10 +682,10 @@ class TargetPlatformModel(BaseModel):
 
  def get_info(self) -> Dict[str, Any]:
  """
- Get a dictionary summarizing the TargetPlatformModel properties.
+ Get a dictionary summarizing the TargetPlatformCapabilities properties.
 
  Returns:
- Dict[str, Any]: Summary of the TargetPlatformModel properties.
+ Dict[str, Any]: Summary of the TargetPlatformCapabilities properties.
  """
  return {
  "Model name": self.name,
@@ -695,6 +695,6 @@
 
  def show(self):
  """
- Display the TargetPlatformModel.
+ Display the TargetPlatformCapabilities.
  """
  pprint.pprint(self.get_info(), sort_dicts=False)
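
To make the renamed components concrete, a minimal construction sketch of a fusing pattern under the new names (operator-set names and grouping are illustrative, not taken from the shipped TPCs):

```python
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import (
    OperatorsSet, OperatorSetGroup, Fusing)

# Two labeled operator sets, plus a group that lets several activations be treated alike.
conv = OperatorsSet(name="Conv")
any_act = OperatorSetGroup(operators_set=(OperatorsSet(name="ReLU"), OperatorsSet(name="Swish")))

# Conv followed by any activation in the group forms one fusing pattern (no quantization in between).
conv_act = Fusing(operator_groups=(conv, any_act))
print(conv_act.name)  # generated from the operator groups' names when not supplied
```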
@@ -13,13 +13,13 @@
  # limitations under the License.
  # ==============================================================================
 
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import get_current_tpc
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities import TargetPlatformCapabilities
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import \
- Eq, GreaterEq, NotEq, SmallerEq, Greater, Smaller
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import \
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import get_current_tpc
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import FrameworkQuantizationCapabilities
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.layer_filter_params import \
  LayerFilterParams
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.operations_to_layers import \
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import \
+ Eq, GreaterEq, NotEq, SmallerEq, Greater, Smaller
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.operations_to_layers import \
  OperationsToLayers, OperationsSetToLayers
 
 
@@ -1,12 +1,12 @@
  from typing import Dict, Optional
 
  from model_compression_toolkit.logger import Logger
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, \
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities, \
  OperatorsSet
- from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, \
- OperationsSetToLayers
 
  from model_compression_toolkit.core.common.quantization.quantization_config import CustomOpsetLayers
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import \
+ FrameworkQuantizationCapabilities, OperationsSetToLayers
 
 
  class AttachTpcToFramework:
@@ -19,25 +19,25 @@ class AttachTpcToFramework:
  # in the operation set are provided in the mapping, a DefaultDict should be supplied to handle missing entries.
  self._opset2attr_mapping = None # Mapping of operation sets to their corresponding framework-specific layers
 
- def attach(self, tpc_model: TargetPlatformModel,
+ def attach(self, tpc_model: TargetPlatformCapabilities,
  custom_opset2layer: Optional[Dict[str, 'CustomOpsetLayers']] = None
- ) -> TargetPlatformCapabilities:
+ ) -> FrameworkQuantizationCapabilities:
  """
- Attaching a TargetPlatformModel which includes a platform capabilities description to specific
+ Attaching a TargetPlatformCapabilities which includes a platform capabilities description to specific
  framework's operators.
 
  Args:
- tpc_model: a TargetPlatformModel object.
+ tpc_model: a TargetPlatformCapabilities object.
  custom_opset2layer: optional set of custom operator sets which allows to add/override the built-in set
  of framework operator, to define a specific behavior for those operators. This dictionary should map
  an operator set unique name to a pair of: a list of framework operators and an optional
  operator's attributes names mapping.
 
- Returns: a TargetPlatformCapabilities object.
+ Returns: a FrameworkQuantizationCapabilities object.
 
  """
 
- tpc = TargetPlatformCapabilities(tpc_model)
+ tpc = FrameworkQuantizationCapabilities(tpc_model)
  custom_opset2layer = custom_opset2layer if custom_opset2layer is not None else {}
 
  with tpc:
@@ -59,7 +59,7 @@ class AttachTpcToFramework:
  attr_mapping = self._opset2attr_mapping.get(opset.name)
  OperationsSetToLayers(opset.name, layers, attr_mapping=attr_mapping)
  else:
- Logger.critical(f'{opset.name} is defined in TargetPlatformModel, '
+ Logger.critical(f'{opset.name} is defined in TargetPlatformCapabilities, '
  f'but is not defined in the framework set of operators or in the provided '
  f'custom operator sets mapping.')
 
@@ -16,6 +16,9 @@
  import tensorflow as tf
  from packaging import version
 
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
+ AttachTpcToFramework
  from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS
 
  if FOUND_SONY_CUSTOM_LAYERS:
@@ -34,9 +37,6 @@ from model_compression_toolkit import DefaultDict
  from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS, \
  BIAS_ATTR, KERAS_KERNEL, KERAS_DEPTHWISE_KERNEL
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
- from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
- AttachTpcToFramework
 
 
  class AttachTpcToKeras(AttachTpcToFramework):
@@ -28,9 +28,10 @@ from model_compression_toolkit import DefaultDict
  from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS, \
  BIAS_ATTR
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
- from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams, Eq
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
  AttachTpcToFramework
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import Eq
 
 
  class AttachTpcToPytorch(AttachTpcToFramework):
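
Together with the facade changes above, attach() is the step that turns the schema-level TargetPlatformCapabilities into the FrameworkQuantizationCapabilities (fqc) consumed by the core pipeline. A sketch of invoking it directly, mirroring what the facades do internally:

```python
import model_compression_toolkit as mct
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import (
    AttachTpcToPytorch)

# Schema-level description of the target hardware...
tpc = mct.get_target_platform_capabilities("pytorch", "default")

# ...attached to concrete PyTorch operators, producing a FrameworkQuantizationCapabilities.
fqc = AttachTpcToPytorch().attach(tpc)
print(type(fqc).__name__)
```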
@@ -18,7 +18,7 @@ from model_compression_toolkit.logger import Logger
  def get_current_tpc():
  """
 
- Returns: The current TargetPlatformCapabilities that is being used and accessed.
+ Returns: The current FrameworkQuantizationCapabilities that is being used and accessed.
 
  """
  return _current_tpc.get()
@@ -26,7 +26,7 @@ def get_current_tpc():
 
  class _CurrentTPC(object):
  """
- Wrapper of the current TargetPlatformCapabilities object that is being accessed and defined.
+ Wrapper of the current FrameworkQuantizationCapabilities object that is being accessed and defined.
  """
  def __init__(self):
  super(_CurrentTPC, self).__init__()
@@ -35,28 +35,28 @@ class _CurrentTPC(object):
  def get(self):
  """
 
- Returns: The current TargetPlatformCapabilities that is being defined.
+ Returns: The current FrameworkQuantizationCapabilities that is being defined.
 
  """
  if self.tpc is None:
- Logger.critical("'TargetPlatformCapabilities' (TPC) instance is not initialized.")
+ Logger.critical("'FrameworkQuantizationCapabilities' (TPC) instance is not initialized.")
  return self.tpc
 
  def reset(self):
  """
 
- Reset the current TargetPlatformCapabilities so a new TargetPlatformCapabilities can be wrapped and
- used as the current TargetPlatformCapabilities object.
+ Reset the current FrameworkQuantizationCapabilities so a new FrameworkQuantizationCapabilities can be wrapped and
+ used as the current FrameworkQuantizationCapabilities object.
 
  """
  self.tpc = None
 
  def set(self, target_platform_capabilities):
  """
- Set and wrap a TargetPlatformCapabilities as the current TargetPlatformCapabilities.
+ Set and wrap a FrameworkQuantizationCapabilities as the current FrameworkQuantizationCapabilities.
 
  Args:
- target_platform_capabilities: TargetPlatformCapabilities to set as the current TargetPlatformCapabilities
+ target_platform_capabilities: FrameworkQuantizationCapabilities to set as the current FrameworkQuantizationCapabilities
  to access and use.
 
  """