mct-nightly 2.2.0.20250113.527__py3-none-any.whl → 2.2.0.20250114.84821__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106):
  1. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/RECORD +103 -105
  3. model_compression_toolkit/__init__.py +2 -2
  4. model_compression_toolkit/core/common/framework_info.py +1 -3
  5. model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
  6. model_compression_toolkit/core/common/graph/base_graph.py +20 -21
  7. model_compression_toolkit/core/common/graph/base_node.py +44 -17
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
  9. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
  10. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +26 -135
  11. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
  12. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +667 -0
  13. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
  14. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +164 -470
  15. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +30 -7
  16. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
  17. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  18. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
  19. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
  20. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
  21. model_compression_toolkit/core/common/pruning/pruner.py +5 -3
  22. model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
  23. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
  24. model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
  25. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
  26. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  27. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
  28. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
  29. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
  30. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
  31. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
  32. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
  33. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
  34. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
  35. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
  36. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  37. model_compression_toolkit/core/graph_prep_runner.py +12 -11
  38. model_compression_toolkit/core/keras/data_util.py +24 -5
  39. model_compression_toolkit/core/keras/default_framework_info.py +1 -1
  40. model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
  41. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
  42. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  43. model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
  44. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
  45. model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
  46. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
  47. model_compression_toolkit/core/runner.py +33 -60
  48. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
  49. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
  50. model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
  51. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  52. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  53. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  54. model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
  55. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  56. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  57. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  58. model_compression_toolkit/metadata.py +11 -10
  59. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
  60. model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
  61. model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
  62. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
  63. model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
  64. model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
  65. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  66. model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
  67. model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
  68. model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
  69. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  70. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
  71. model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
  72. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  73. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
  74. model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
  75. model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
  76. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
  77. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
  78. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
  79. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
  80. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
  81. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
  82. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
  83. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
  84. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
  85. model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
  86. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
  87. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
  88. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
  89. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
  90. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
  91. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
  92. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
  93. model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
  94. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
  95. model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
  96. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
  97. model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
  98. model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
  99. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
  100. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
  101. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
  102. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
  103. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/LICENSE.md +0 -0
  104. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/WHEEL +0 -0
  105. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/top_level.txt +0 -0
  106. /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
@@ -414,7 +414,7 @@ class QuantizationConfigOptions(BaseModel):
414
414
 
415
415
  class TargetPlatformModelComponent(BaseModel):
416
416
  """
417
- Component of TargetPlatformModel (Fusing, OperatorsSet, etc.).
417
+ Component of TargetPlatformCapabilities (Fusing, OperatorsSet, etc.).
418
418
  """
419
419
  class Config:
420
420
  frozen = True
@@ -433,7 +433,7 @@ class OperatorsSet(OperatorsSetBase):
433
433
  Set of operators that are represented by a unique label.
434
434
 
435
435
  Attributes:
436
- name (Union[str, OperatorSetNames]): The set's label (must be unique within a TargetPlatformModel).
436
+ name (Union[str, OperatorSetNames]): The set's label (must be unique within a TargetPlatformCapabilities).
437
437
  qc_options (Optional[QuantizationConfigOptions]): Configuration options to use for this set of operations.
438
438
  If None, it represents a fusing set.
439
439
  type (Literal["OperatorsSet"]): Fixed type identifier.
@@ -457,7 +457,7 @@ class OperatorsSet(OperatorsSetBase):
457
457
  return {"name": self.name}
458
458
 
459
459
 
460
- class OperatorSetConcat(OperatorsSetBase):
460
+ class OperatorSetGroup(OperatorsSetBase):
461
461
  """
462
462
  Concatenate a tuple of operator sets to treat them similarly in different places (like fusing).
463
463
 
@@ -469,7 +469,7 @@ class OperatorSetConcat(OperatorsSetBase):
469
469
  name: Optional[str] = None # Will be set in the validator if not given
470
470
 
471
471
  # Define a private attribute _type
472
- type: Literal["OperatorSetConcat"] = "OperatorSetConcat"
472
+ type: Literal["OperatorSetGroup"] = "OperatorSetGroup"
473
473
 
474
474
  class Config:
475
475
  frozen = True
@@ -518,11 +518,11 @@ class Fusing(TargetPlatformModelComponent):
518
518
  hence no quantization is applied between them.
519
519
 
520
520
  Attributes:
521
- operator_groups (Tuple[Union[OperatorsSet, OperatorSetConcat], ...]): A tuple of operator groups,
522
- each being either an OperatorSetConcat or an OperatorsSet.
521
+ operator_groups (Tuple[Union[OperatorsSet, OperatorSetGroup], ...]): A tuple of operator groups,
522
+ each being either an OperatorSetGroup or an OperatorsSet.
523
523
  name (Optional[str]): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
524
524
  """
525
- operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetConcat], Field(discriminator='type')], ...]
525
+ operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetGroup], Field(discriminator='type')], ...]
526
526
  name: Optional[str] = None # Will be set in the validator if not given.
527
527
 
528
528
  class Config:
@@ -591,7 +591,7 @@ class Fusing(TargetPlatformModelComponent):
591
591
  for i in range(len(self.operator_groups) - len(other.operator_groups) + 1):
592
592
  for j in range(len(other.operator_groups)):
593
593
  if self.operator_groups[i + j] != other.operator_groups[j] and not (
594
- isinstance(self.operator_groups[i + j], OperatorSetConcat) and (
594
+ isinstance(self.operator_groups[i + j], OperatorSetGroup) and (
595
595
  other.operator_groups[j] in self.operator_groups[i + j].operators_set)):
596
596
  break
597
597
  else:
@@ -621,7 +621,7 @@ class Fusing(TargetPlatformModelComponent):
621
621
  for x in self.operator_groups
622
622
  ])
623
623
 
624
- class TargetPlatformModel(BaseModel):
624
+ class TargetPlatformCapabilities(BaseModel):
625
625
  """
626
626
  Represents the hardware configuration used for quantized model inference.
627
627
 
@@ -644,7 +644,7 @@ class TargetPlatformModel(BaseModel):
644
644
  tpc_patch_version: Optional[int]
645
645
  tpc_platform_type: Optional[str]
646
646
  add_metadata: bool = True
647
- name: Optional[str] = "default_tp_model"
647
+ name: Optional[str] = "default_tpc"
648
648
  is_simd_padding: bool = False
649
649
 
650
650
  SCHEMA_VERSION: int = 1
@@ -682,10 +682,10 @@ class TargetPlatformModel(BaseModel):
682
682
 
683
683
  def get_info(self) -> Dict[str, Any]:
684
684
  """
685
- Get a dictionary summarizing the TargetPlatformModel properties.
685
+ Get a dictionary summarizing the TargetPlatformCapabilities properties.
686
686
 
687
687
  Returns:
688
- Dict[str, Any]: Summary of the TargetPlatformModel properties.
688
+ Dict[str, Any]: Summary of the TargetPlatformCapabilities properties.
689
689
  """
690
690
  return {
691
691
  "Model name": self.name,
@@ -695,6 +695,6 @@ class TargetPlatformModel(BaseModel):
695
695
 
696
696
  def show(self):
697
697
  """
698
- Display the TargetPlatformModel.
698
+ Display the TargetPlatformCapabilities.
699
699
  """
700
700
  pprint.pprint(self.get_info(), sort_dicts=False)
@@ -13,13 +13,13 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import get_current_tpc
17
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities import TargetPlatformCapabilities
18
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import \
19
- Eq, GreaterEq, NotEq, SmallerEq, Greater, Smaller
20
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import \
16
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import get_current_tpc
17
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import FrameworkQuantizationCapabilities
18
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.layer_filter_params import \
21
19
  LayerFilterParams
22
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.operations_to_layers import \
20
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import \
21
+ Eq, GreaterEq, NotEq, SmallerEq, Greater, Smaller
22
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.operations_to_layers import \
23
23
  OperationsToLayers, OperationsSetToLayers
24
24
 
25
25
 
@@ -1,12 +1,12 @@
1
1
  from typing import Dict, Optional
2
2
 
3
3
  from model_compression_toolkit.logger import Logger
4
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, \
4
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities, \
5
5
  OperatorsSet
6
- from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, \
7
- OperationsSetToLayers
8
6
 
9
7
  from model_compression_toolkit.core.common.quantization.quantization_config import CustomOpsetLayers
8
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import \
9
+ FrameworkQuantizationCapabilities, OperationsSetToLayers
10
10
 
11
11
 
12
12
  class AttachTpcToFramework:
@@ -19,25 +19,25 @@ class AttachTpcToFramework:
19
19
  # in the operation set are provided in the mapping, a DefaultDict should be supplied to handle missing entries.
20
20
  self._opset2attr_mapping = None # Mapping of operation sets to their corresponding framework-specific layers
21
21
 
22
- def attach(self, tpc_model: TargetPlatformModel,
22
+ def attach(self, tpc_model: TargetPlatformCapabilities,
23
23
  custom_opset2layer: Optional[Dict[str, 'CustomOpsetLayers']] = None
24
- ) -> TargetPlatformCapabilities:
24
+ ) -> FrameworkQuantizationCapabilities:
25
25
  """
26
- Attaching a TargetPlatformModel which includes a platform capabilities description to specific
26
+ Attaching a TargetPlatformCapabilities which includes a platform capabilities description to specific
27
27
  framework's operators.
28
28
 
29
29
  Args:
30
- tpc_model: a TargetPlatformModel object.
30
+ tpc_model: a TargetPlatformCapabilities object.
31
31
  custom_opset2layer: optional set of custom operator sets which allows to add/override the built-in set
32
32
  of framework operator, to define a specific behavior for those operators. This dictionary should map
33
33
  an operator set unique name to a pair of: a list of framework operators and an optional
34
34
  operator's attributes names mapping.
35
35
 
36
- Returns: a TargetPlatformCapabilities object.
36
+ Returns: a FrameworkQuantizationCapabilities object.
37
37
 
38
38
  """
39
39
 
40
- tpc = TargetPlatformCapabilities(tpc_model)
40
+ tpc = FrameworkQuantizationCapabilities(tpc_model)
41
41
  custom_opset2layer = custom_opset2layer if custom_opset2layer is not None else {}
42
42
 
43
43
  with tpc:
@@ -59,7 +59,7 @@ class AttachTpcToFramework:
59
59
  attr_mapping = self._opset2attr_mapping.get(opset.name)
60
60
  OperationsSetToLayers(opset.name, layers, attr_mapping=attr_mapping)
61
61
  else:
62
- Logger.critical(f'{opset.name} is defined in TargetPlatformModel, '
62
+ Logger.critical(f'{opset.name} is defined in TargetPlatformCapabilities, '
63
63
  f'but is not defined in the framework set of operators or in the provided '
64
64
  f'custom operator sets mapping.')
65
65
 
@@ -16,6 +16,9 @@
16
16
  import tensorflow as tf
17
17
  from packaging import version
18
18
 
19
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
20
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
21
+ AttachTpcToFramework
19
22
  from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS
20
23
 
21
24
  if FOUND_SONY_CUSTOM_LAYERS:
@@ -34,9 +37,6 @@ from model_compression_toolkit import DefaultDict
34
37
  from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS, \
35
38
  BIAS_ATTR, KERAS_KERNEL, KERAS_DEPTHWISE_KERNEL
36
39
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
37
- from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
38
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
39
- AttachTpcToFramework
40
40
 
41
41
 
42
42
  class AttachTpcToKeras(AttachTpcToFramework):
@@ -28,9 +28,10 @@ from model_compression_toolkit import DefaultDict
28
28
  from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS, \
29
29
  BIAS_ATTR
30
30
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
31
- from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams, Eq
32
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
31
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
32
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
33
33
  AttachTpcToFramework
34
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import Eq
34
35
 
35
36
 
36
37
  class AttachTpcToPytorch(AttachTpcToFramework):
@@ -18,7 +18,7 @@ from model_compression_toolkit.logger import Logger
18
18
  def get_current_tpc():
19
19
  """
20
20
 
21
- Returns: The current TargetPlatformCapabilities that is being used and accessed.
21
+ Returns: The current FrameworkQuantizationCapabilities that is being used and accessed.
22
22
 
23
23
  """
24
24
  return _current_tpc.get()
@@ -26,7 +26,7 @@ def get_current_tpc():
26
26
 
27
27
  class _CurrentTPC(object):
28
28
  """
29
- Wrapper of the current TargetPlatformCapabilities object that is being accessed and defined.
29
+ Wrapper of the current FrameworkQuantizationCapabilities object that is being accessed and defined.
30
30
  """
31
31
  def __init__(self):
32
32
  super(_CurrentTPC, self).__init__()
@@ -35,28 +35,28 @@ class _CurrentTPC(object):
35
35
  def get(self):
36
36
  """
37
37
 
38
- Returns: The current TargetPlatformCapabilities that is being defined.
38
+ Returns: The current FrameworkQuantizationCapabilities that is being defined.
39
39
 
40
40
  """
41
41
  if self.tpc is None:
42
- Logger.critical("'TargetPlatformCapabilities' (TPC) instance is not initialized.")
42
+ Logger.critical("'FrameworkQuantizationCapabilities' (TPC) instance is not initialized.")
43
43
  return self.tpc
44
44
 
45
45
  def reset(self):
46
46
  """
47
47
 
48
- Reset the current TargetPlatformCapabilities so a new TargetPlatformCapabilities can be wrapped and
49
- used as the current TargetPlatformCapabilities object.
48
+ Reset the current FrameworkQuantizationCapabilities so a new FrameworkQuantizationCapabilities can be wrapped and
49
+ used as the current FrameworkQuantizationCapabilities object.
50
50
 
51
51
  """
52
52
  self.tpc = None
53
53
 
54
54
  def set(self, target_platform_capabilities):
55
55
  """
56
- Set and wrap a TargetPlatformCapabilities as the current TargetPlatformCapabilities.
56
+ Set and wrap a FrameworkQuantizationCapabilities as the current FrameworkQuantizationCapabilities.
57
57
 
58
58
  Args:
59
- target_platform_capabilities: TargetPlatformCapabilities to set as the current TargetPlatformCapabilities
59
+ target_platform_capabilities: FrameworkQuantizationCapabilities to set as the current FrameworkQuantizationCapabilities
60
60
  to access and use.
61
61
 
62
62
  """
@@ -21,38 +21,38 @@ from typing import List, Any, Dict, Tuple
21
21
  from model_compression_toolkit.logger import Logger
22
22
  from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \
23
23
  get_config_options_by_operators_set, get_default_op_quantization_config, get_opset_by_name
24
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.operations_to_layers import \
25
- OperationsToLayers, OperationsSetToLayers
26
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
27
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import LayerFilterParams
24
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.operations_to_layers import OperationsToLayers, \
25
+ OperationsSetToLayers
26
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities_component import \
27
+ FrameworkQuantizationCapabilitiesComponent
28
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.layer_filter_params import LayerFilterParams
28
29
  from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
29
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, OperatorsSetBase, \
30
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities, OperatorsSetBase, \
30
31
  OpQuantizationConfig, QuantizationConfigOptions
31
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
32
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc
32
33
 
33
-
34
- class TargetPlatformCapabilities(ImmutableClass):
34
+ class FrameworkQuantizationCapabilities(ImmutableClass):
35
35
  """
36
36
  Attach framework information to a modeled hardware.
37
37
  """
38
38
  def __init__(self,
39
- tp_model: TargetPlatformModel,
39
+ tpc: TargetPlatformCapabilities,
40
40
  name: str = "base"):
41
41
  """
42
42
 
43
43
  Args:
44
- tp_model (TargetPlatformModel): Modeled hardware to attach framework information to.
45
- name (str): Name of the TargetPlatformCapabilities.
44
+ tpc (TargetPlatformCapabilities): Modeled hardware to attach framework information to.
45
+ name (str): Name of the FrameworkQuantizationCapabilities.
46
46
  """
47
47
 
48
48
  super().__init__()
49
49
  self.name = name
50
- assert isinstance(tp_model, TargetPlatformModel), f'Target platform model that was passed to TargetPlatformCapabilities must be of type TargetPlatformModel, but has type of {type(tp_model)}'
51
- self.tp_model = tp_model
50
+ assert isinstance(tpc, TargetPlatformCapabilities), f'Target platform model that was passed to FrameworkQuantizationCapabilities must be of type TargetPlatformCapabilities, but has type of {type(tpc)}'
51
+ self.tpc = tpc
52
52
  self.op_sets_to_layers = OperationsToLayers() # Init an empty OperationsToLayers
53
53
  self.layer2qco, self.filterlayer2qco = {}, {} # Init empty mappings from layers/LayerFilterParams to QC options
54
54
  # Track the unused opsets for warning purposes.
55
- self.__tp_model_opsets_not_used = [s.name for s in tp_model.operator_set]
55
+ self.__tpc_opsets_not_used = [s.name for s in tpc.operator_set]
56
56
  self.remove_fusing_names_from_not_used_list()
57
57
 
58
58
  def get_layers_by_opset_name(self, opset_name: str) -> List[Any]:
@@ -66,9 +66,9 @@ class TargetPlatformCapabilities(ImmutableClass):
66
66
  Returns:
67
67
  List of layers/LayerFilterParams that are attached to the opset name.
68
68
  """
69
- opset = get_opset_by_name(self.tp_model, opset_name)
69
+ opset = get_opset_by_name(self.tpc, opset_name)
70
70
  if opset is None:
71
- Logger.warning(f'{opset_name} was not found in TargetPlatformCapabilities.')
71
+ Logger.warning(f'{opset_name} was not found in FrameworkQuantizationCapabilities.')
72
72
  return None
73
73
  return self.get_layers_by_opset(opset)
74
74
 
@@ -100,9 +100,9 @@ class TargetPlatformCapabilities(ImmutableClass):
100
100
 
101
101
  """
102
102
  res = []
103
- if self.tp_model.fusing_patterns is None:
103
+ if self.tpc.fusing_patterns is None:
104
104
  return res
105
- for p in self.tp_model.fusing_patterns:
105
+ for p in self.tpc.fusing_patterns:
106
106
  ops = [self.get_layers_by_opset(x) for x in p.operator_groups]
107
107
  res.extend(itertools.product(*ops))
108
108
  return [list(x) for x in res]
@@ -111,47 +111,47 @@ class TargetPlatformCapabilities(ImmutableClass):
111
111
  def get_info(self) -> Dict[str, Any]:
112
112
  """
113
113
 
114
- Returns: Summarization of information in the TargetPlatformCapabilities.
114
+ Returns: Summarization of information in the FrameworkQuantizationCapabilities.
115
115
 
116
116
  """
117
117
  return {"Target Platform Capabilities": self.name,
118
- "Minor version": self.tp_model.tpc_minor_version,
119
- "Patch version": self.tp_model.tpc_patch_version,
120
- "Platform type": self.tp_model.tpc_platform_type,
121
- "Target Platform Model": self.tp_model.get_info(),
118
+ "Minor version": self.tpc.tpc_minor_version,
119
+ "Patch version": self.tpc.tpc_patch_version,
120
+ "Platform type": self.tpc.tpc_platform_type,
121
+ "Target Platform Model": self.tpc.get_info(),
122
122
  "Operations to layers": {op2layer.name:[l.__name__ for l in op2layer.layers] for op2layer in self.op_sets_to_layers.op_sets_to_layers}}
123
123
 
124
124
  def show(self):
125
125
  """
126
126
 
127
- Display the TargetPlatformCapabilities.
127
+ Display the FrameworkQuantizationCapabilities.
128
128
 
129
129
  """
130
130
  pprint.pprint(self.get_info(), sort_dicts=False, width=110)
131
131
 
132
- def append_component(self, tpc_component: TargetPlatformCapabilitiesComponent):
132
+ def append_component(self, tpc_component: FrameworkQuantizationCapabilitiesComponent):
133
133
  """
134
- Append a Component (like OperationsSetToLayers) to the TargetPlatformCapabilities.
134
+ Append a Component (like OperationsSetToLayers) to the FrameworkQuantizationCapabilities.
135
135
 
136
136
  Args:
137
- tpc_component: Component to append to TargetPlatformCapabilities.
137
+ tpc_component: Component to append to FrameworkQuantizationCapabilities.
138
138
 
139
139
  """
140
140
  if isinstance(tpc_component, OperationsSetToLayers):
141
141
  self.op_sets_to_layers += tpc_component
142
142
  else:
143
- Logger.critical(f"Attempt to append an unrecognized 'TargetPlatformCapabilitiesComponent' of type: '{type(tpc_component)}'. Ensure the component is compatible.") # pragma: no cover
143
+ Logger.critical(f"Attempt to append an unrecognized 'FrameworkQuantizationCapabilitiesComponent' of type: '{type(tpc_component)}'. Ensure the component is compatible.") # pragma: no cover
144
144
 
145
145
  def __enter__(self):
146
146
  """
147
- Init a TargetPlatformCapabilities object.
147
+ Init a FrameworkQuantizationCapabilities object.
148
148
  """
149
149
  _current_tpc.set(self)
150
150
  return self
151
151
 
152
152
  def __exit__(self, exc_type, exc_value, tb):
153
153
  """
154
- Finalize a TargetPlatformCapabilities object.
154
+ Finalize a FrameworkQuantizationCapabilities object.
155
155
  """
156
156
  if exc_value is not None:
157
157
  print(exc_value, exc_value.args)
@@ -164,11 +164,11 @@ class TargetPlatformCapabilities(ImmutableClass):
164
164
  def get_default_op_qc(self) -> OpQuantizationConfig:
165
165
  """
166
166
 
167
- Returns: The default OpQuantizationConfig of the TargetPlatformModel that is attached
168
- to the TargetPlatformCapabilities.
167
+ Returns: The default OpQuantizationConfig of the TargetPlatformCapabilities that is attached
168
+ to the FrameworkQuantizationCapabilities.
169
169
 
170
170
  """
171
- return get_default_op_quantization_config(self.tp_model)
171
+ return get_default_op_quantization_config(self.tpc)
172
172
 
173
173
 
174
174
  def _get_config_options_mapping(self) -> Tuple[Dict[Any, QuantizationConfigOptions],
@@ -184,9 +184,9 @@ class TargetPlatformCapabilities(ImmutableClass):
184
184
  filterlayer2qco = {}
185
185
  for op2layers in self.op_sets_to_layers.op_sets_to_layers:
186
186
  for l in op2layers.layers:
187
- qco = get_config_options_by_operators_set(self.tp_model, op2layers.name)
187
+ qco = get_config_options_by_operators_set(self.tpc, op2layers.name)
188
188
  if qco is None:
189
- qco = self.tp_model.default_qco
189
+ qco = self.tpc.default_qco
190
190
 
191
191
  # here, we need to take care of mapping a general attribute name into a framework and
192
192
  # layer type specific attribute name.
@@ -208,8 +208,8 @@ class TargetPlatformCapabilities(ImmutableClass):
208
208
  Remove OperatorSets names from the list of the unused sets (so a warning
209
209
  will not be displayed).
210
210
  """
211
- if self.tp_model.fusing_patterns is not None:
212
- for f in self.tp_model.fusing_patterns:
211
+ if self.tpc.fusing_patterns is not None:
212
+ for f in self.tpc.fusing_patterns:
213
213
  for s in f.operator_groups:
214
214
  self.remove_opset_from_not_used_list(s.name)
215
215
 
@@ -222,8 +222,8 @@ class TargetPlatformCapabilities(ImmutableClass):
222
222
  opset_to_remove: OperatorsSet name to remove.
223
223
 
224
224
  """
225
- if opset_to_remove in self.__tp_model_opsets_not_used:
226
- self.__tp_model_opsets_not_used.remove(opset_to_remove)
225
+ if opset_to_remove in self.__tpc_opsets_not_used:
226
+ self.__tpc_opsets_not_used.remove(opset_to_remove)
227
227
 
228
228
  @property
229
229
  def is_simd_padding(self) -> bool:
@@ -232,4 +232,4 @@ class TargetPlatformCapabilities(ImmutableClass):
232
232
  Returns: Check if the TP model defines that padding due to SIMD constrains occurs.
233
233
 
234
234
  """
235
- return self.tp_model.is_simd_padding
235
+ return self.tpc.is_simd_padding
@@ -13,10 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
16
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc
17
17
 
18
18
 
19
- class TargetPlatformCapabilitiesComponent:
19
+ class FrameworkQuantizationCapabilitiesComponent:
20
20
  def __init__(self, name: str):
21
21
  self.name = name
22
22
  _current_tpc.get().append_component(self)
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from typing import Any
17
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
18
17
 
19
18
 
20
19
  class LayerFilterParams:
@@ -18,13 +18,13 @@ from typing import List, Any, Dict
18
18
  from model_compression_toolkit.logger import Logger
19
19
  from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \
20
20
  get_config_options_by_operators_set, is_opset_in_model
21
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
22
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
23
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorsSetBase, OperatorSetConcat
21
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc
22
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities_component import FrameworkQuantizationCapabilitiesComponent
23
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorsSetBase, OperatorSetGroup
24
24
  from model_compression_toolkit import DefaultDict
25
25
 
26
26
 
27
- class OperationsSetToLayers(TargetPlatformCapabilitiesComponent):
27
+ class OperationsSetToLayers(FrameworkQuantizationCapabilitiesComponent):
28
28
  """
29
29
  Associate an OperatorsSet to a list of framework's layers.
30
30
  """
@@ -57,7 +57,7 @@ class OperationsSetToLayers(TargetPlatformCapabilitiesComponent):
57
57
 
58
58
  class OperationsToLayers:
59
59
  """
60
- Gather multiple OperationsSetToLayers to represent mapping of framework's layers to TargetPlatformModel OperatorsSet.
60
+ Gather multiple OperationsSetToLayers to represent mapping of framework's layers to TargetPlatformCapabilities OperatorsSet.
61
61
  """
62
62
  def __init__(self,
63
63
  op_sets_to_layers: List[OperationsSetToLayers]=None):
@@ -88,7 +88,7 @@ class OperationsToLayers:
88
88
  for o in self.op_sets_to_layers:
89
89
  if op.name == o.name:
90
90
  return o.layers
91
- if isinstance(op, OperatorSetConcat): # If its a concat - return all layers from all OperatorsSets that in the OperatorSetConcat
91
+ if isinstance(op, OperatorSetGroup): # If its a concat - return all layers from all OperatorsSets that in the OperatorSetGroup
92
92
  layers = []
93
93
  for o in op.operators_set:
94
94
  layers.extend(self.get_layers_by_op(o))
@@ -142,9 +142,9 @@ class OperationsToLayers:
142
142
  assert ops2layers.name not in existing_opset_names, f'OperationsSetToLayers names should be unique, but {ops2layers.name} appears to violate it.'
143
143
  existing_opset_names.append(ops2layers.name)
144
144
 
145
- # Assert that a layer does not appear in more than a single OperatorsSet in the TargetPlatformModel.
145
+ # Assert that a layer does not appear in more than a single OperatorsSet in the TargetPlatformCapabilities.
146
146
  for layer in ops2layers.layers:
147
- qco_by_opset_name = get_config_options_by_operators_set(_current_tpc.get().tp_model, ops2layers.name)
147
+ qco_by_opset_name = get_config_options_by_operators_set(_current_tpc.get().tpc, ops2layers.name)
148
148
  if layer in existing_layers:
149
149
  Logger.critical(f'Found layer {layer.__name__} in more than one '
150
150
  f'OperatorsSet') # pragma: no cover
@@ -16,34 +16,34 @@ from pathlib import Path
16
16
  from typing import Union
17
17
 
18
18
  from model_compression_toolkit.logger import Logger
19
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
19
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
20
20
  import json
21
21
 
22
22
 
23
- def load_target_platform_model(tp_model_or_path: Union[TargetPlatformModel, str]) -> TargetPlatformModel:
23
+ def load_target_platform_model(tpc_obj_or_path: Union[TargetPlatformCapabilities, str]) -> TargetPlatformCapabilities:
24
24
  """
25
- Parses the tp_model input, which can be either a TargetPlatformModel object
25
+ Parses the tpc input, which can be either a TargetPlatformCapabilities object
26
26
  or a string path to a JSON file.
27
27
 
28
28
  Parameters:
29
- tp_model_or_path (Union[TargetPlatformModel, str]): Input target platform model or path to .JSON file.
29
+ tpc_obj_or_path (Union[TargetPlatformModel, str]): Input target platform model or path to .JSON file.
30
30
 
31
31
  Returns:
32
- TargetPlatformModel: The parsed TargetPlatformModel.
32
+ TargetPlatformCapabilities: The parsed TargetPlatformCapabilities.
33
33
 
34
34
  Raises:
35
35
  FileNotFoundError: If the JSON file does not exist.
36
- ValueError: If the JSON content is invalid or cannot initialize the TargetPlatformModel.
37
- TypeError: If the input is neither a TargetPlatformModel nor a valid JSON file path.
36
+ ValueError: If the JSON content is invalid or cannot initialize the TargetPlatformCapabilities.
37
+ TypeError: If the input is neither a TargetPlatformCapabilities nor a valid JSON file path.
38
38
  """
39
- if isinstance(tp_model_or_path, TargetPlatformModel):
40
- return tp_model_or_path
39
+ if isinstance(tpc_obj_or_path, TargetPlatformCapabilities):
40
+ return tpc_obj_or_path
41
41
 
42
- if isinstance(tp_model_or_path, str):
43
- path = Path(tp_model_or_path)
42
+ if isinstance(tpc_obj_or_path, str):
43
+ path = Path(tpc_obj_or_path)
44
44
 
45
45
  if not path.exists() or not path.is_file():
46
- raise FileNotFoundError(f"The path '{tp_model_or_path}' is not a valid file.")
46
+ raise FileNotFoundError(f"The path '{tpc_obj_or_path}' is not a valid file.")
47
47
  # Verify that the file has a .json extension
48
48
  if path.suffix.lower() != '.json':
49
49
  raise ValueError(f"The file '{path}' does not have a '.json' extension.")
@@ -51,35 +51,35 @@ def load_target_platform_model(tp_model_or_path: Union[TargetPlatformModel, str]
51
51
  with path.open('r', encoding='utf-8') as file:
52
52
  data = file.read()
53
53
  except OSError as e:
54
- raise ValueError(f"Error reading the file '{tp_model_or_path}': {e.strerror}.") from e
54
+ raise ValueError(f"Error reading the file '{tpc_obj_or_path}': {e.strerror}.") from e
55
55
 
56
56
  try:
57
- return TargetPlatformModel.parse_raw(data)
57
+ return TargetPlatformCapabilities.parse_raw(data)
58
58
  except ValueError as e:
59
- raise ValueError(f"Invalid JSON for loading TargetPlatformModel in '{tp_model_or_path}': {e}.") from e
59
+ raise ValueError(f"Invalid JSON for loading TargetPlatformCapabilities in '{tpc_obj_or_path}': {e}.") from e
60
60
  except Exception as e:
61
- raise ValueError(f"Unexpected error while initializing TargetPlatformModel: {e}.") from e
61
+ raise ValueError(f"Unexpected error while initializing TargetPlatformCapabilities: {e}.") from e
62
62
 
63
63
  raise TypeError(
64
- f"tp_model_or_path must be either a TargetPlatformModel instance or a string path to a JSON file, "
65
- f"but received type '{type(tp_model_or_path).__name__}'."
64
+ f"tpc_obj_or_path must be either a TargetPlatformCapabilities instance or a string path to a JSON file, "
65
+ f"but received type '{type(tpc_obj_or_path).__name__}'."
66
66
  )
67
67
 
68
68
 
69
- def export_target_platform_model(model: TargetPlatformModel, export_path: Union[str, Path]) -> None:
69
+ def export_target_platform_model(model: TargetPlatformCapabilities, export_path: Union[str, Path]) -> None:
70
70
  """
71
- Exports a TargetPlatformModel instance to a JSON file.
71
+ Exports a TargetPlatformCapabilities instance to a JSON file.
72
72
 
73
73
  Parameters:
74
- model (TargetPlatformModel): The TargetPlatformModel instance to export.
74
+ model (TargetPlatformCapabilities): The TargetPlatformCapabilities instance to export.
75
75
  export_path (Union[str, Path]): The file path to export the model to.
76
76
 
77
77
  Raises:
78
- ValueError: If the model is not an instance of TargetPlatformModel.
78
+ ValueError: If the model is not an instance of TargetPlatformCapabilities.
79
79
  OSError: If there is an issue writing to the file.
80
80
  """
81
- if not isinstance(model, TargetPlatformModel):
82
- raise ValueError("The provided model is not a valid TargetPlatformModel instance.")
81
+ if not isinstance(model, TargetPlatformCapabilities):
82
+ raise ValueError("The provided model is not a valid TargetPlatformCapabilities instance.")
83
83
 
84
84
  path = Path(export_path)
85
85
  try: