mct-nightly 2.2.0.20250113.527__py3-none-any.whl → 2.2.0.20250114.84821__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106) hide show
  1. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/RECORD +103 -105
  3. model_compression_toolkit/__init__.py +2 -2
  4. model_compression_toolkit/core/common/framework_info.py +1 -3
  5. model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
  6. model_compression_toolkit/core/common/graph/base_graph.py +20 -21
  7. model_compression_toolkit/core/common/graph/base_node.py +44 -17
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
  9. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
  10. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +26 -135
  11. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
  12. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +667 -0
  13. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
  14. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +164 -470
  15. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +30 -7
  16. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
  17. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  18. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
  19. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
  20. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
  21. model_compression_toolkit/core/common/pruning/pruner.py +5 -3
  22. model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
  23. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
  24. model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
  25. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
  26. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  27. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
  28. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
  29. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
  30. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
  31. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
  32. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
  33. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
  34. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
  35. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
  36. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  37. model_compression_toolkit/core/graph_prep_runner.py +12 -11
  38. model_compression_toolkit/core/keras/data_util.py +24 -5
  39. model_compression_toolkit/core/keras/default_framework_info.py +1 -1
  40. model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
  41. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
  42. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  43. model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
  44. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
  45. model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
  46. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
  47. model_compression_toolkit/core/runner.py +33 -60
  48. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
  49. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
  50. model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
  51. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  52. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  53. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  54. model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
  55. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  56. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  57. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  58. model_compression_toolkit/metadata.py +11 -10
  59. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
  60. model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
  61. model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
  62. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
  63. model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
  64. model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
  65. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  66. model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
  67. model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
  68. model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
  69. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  70. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
  71. model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
  72. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  73. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
  74. model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
  75. model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
  76. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
  77. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
  78. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
  79. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
  80. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
  81. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
  82. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
  83. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
  84. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
  85. model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
  86. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
  87. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
  88. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
  89. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
  90. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
  91. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
  92. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
  93. model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
  94. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
  95. model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
  96. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
  97. model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
  98. model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
  99. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
  100. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
  101. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
  102. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
  103. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/LICENSE.md +0 -0
  104. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/WHEEL +0 -0
  105. {mct_nightly-2.2.0.20250113.527.dist-info → mct_nightly-2.2.0.20250114.84821.dist-info}/top_level.txt +0 -0
  106. /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
@@ -32,8 +32,9 @@ from model_compression_toolkit.core.common.collectors.statistics_collector impor
32
32
  from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
33
33
  from model_compression_toolkit.core.common.user_info import UserInformation
34
34
  from model_compression_toolkit.logger import Logger
35
- from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
36
- TargetPlatformCapabilities, LayerFilterParams
35
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
36
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
37
+ FrameworkQuantizationCapabilities
37
38
 
38
39
  OutTensor = namedtuple('OutTensor', 'node node_out_index')
39
40
 
@@ -86,29 +87,29 @@ class Graph(nx.MultiDiGraph, GraphSearches):
86
87
 
87
88
  self.fw_info = fw_info
88
89
 
89
- def set_tpc(self,
90
- tpc: TargetPlatformCapabilities):
90
+ def set_fqc(self,
91
+ fqc: FrameworkQuantizationCapabilities):
91
92
  """
92
- Set the graph's TPC.
93
+ Set the graph's FQC.
93
94
  Args:
94
- tpc: TargetPlatformCapabilities object.
95
+ fqc: FrameworkQuantizationCapabilities object.
95
96
  """
96
- # validate graph nodes are either from the framework or a custom layer defined in the TPC
97
- # Validate graph nodes are either built-in layers from the framework or custom layers defined in the TPC
98
- tpc_layers = tpc.op_sets_to_layers.get_layers()
99
- tpc_filtered_layers = [layer for layer in tpc_layers if isinstance(layer, LayerFilterParams)]
97
+ # validate graph nodes are either from the framework or a custom layer defined in the FQC
98
+ # Validate graph nodes are either built-in layers from the framework or custom layers defined in the FQC
99
+ fqc_layers = fqc.op_sets_to_layers.get_layers()
100
+ fqc_filtered_layers = [layer for layer in fqc_layers if isinstance(layer, LayerFilterParams)]
100
101
  for n in self.nodes:
101
- is_node_in_tpc = any([n.is_match_type(_type) for _type in tpc_layers]) or \
102
- any([n.is_match_filter_params(filtered_layer) for filtered_layer in tpc_filtered_layers])
102
+ is_node_in_fqc = any([n.is_match_type(_type) for _type in fqc_layers]) or \
103
+ any([n.is_match_filter_params(filtered_layer) for filtered_layer in fqc_filtered_layers])
103
104
  if n.is_custom:
104
- if not is_node_in_tpc:
105
+ if not is_node_in_fqc:
105
106
  Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
106
- ' Please add the custom layer to Target Platform Capabilities (TPC), or file a feature '
107
+ ' Please add the custom layer to Framework Quantization Capabilities (FQC), or file a feature '
107
108
  'request or an issue if you believe this should be supported.') # pragma: no cover
108
- if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_configurations]):
109
+ if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(fqc).quantization_configurations]):
109
110
  Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.') # pragma: no cover
110
111
 
111
- self.tpc = tpc
112
+ self.fqc = fqc
112
113
 
113
114
  def get_topo_sorted_nodes(self):
114
115
  """
@@ -544,10 +545,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
544
545
  potential_conf_nodes = [n for n in list(self) if fw_info.is_kernel_op(n.type)]
545
546
 
546
547
  def is_configurable(n):
547
- kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
548
- return (n.is_weights_quantization_enabled(kernel_attr) and
549
- not n.is_all_weights_candidates_equal(kernel_attr) and
550
- (not n.reuse or include_reused_nodes))
548
+ kernel_attrs = fw_info.get_kernel_op_attributes(n.type)
549
+ return any(n.is_configurable_weight(attr) for attr in kernel_attrs) and (not n.reuse or include_reused_nodes)
551
550
 
552
551
  return [n for n in potential_conf_nodes if is_configurable(n)]
553
552
 
@@ -576,7 +575,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
576
575
  Returns:
577
576
  A list of nodes that their activation can be configured (namely, has one or more activation qc candidate).
578
577
  """
579
- return [n for n in list(self) if n.is_activation_quantization_enabled() and not n.is_all_activation_candidates_equal()]
578
+ return [n for n in list(self) if n.has_configurable_activation()]
580
579
 
581
580
  def get_sorted_activation_configurable_nodes(self) -> List[BaseNode]:
582
581
  """
@@ -25,7 +25,9 @@ from model_compression_toolkit.logger import Logger
25
25
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
26
26
  OpQuantizationConfig
27
27
  from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
28
- from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, LayerFilterParams
28
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
29
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
30
+ FrameworkQuantizationCapabilities
29
31
 
30
32
 
31
33
  class BaseNode:
@@ -150,6 +152,27 @@ class BaseNode:
150
152
 
151
153
  return False
152
154
 
155
+ def is_configurable_weight(self, attr_name: str) -> bool:
156
+ """
157
+ Checks whether the specific weight attribute has a configurable quantization.
158
+
159
+ Args:
160
+ attr_name: weight attribute name.
161
+
162
+ Returns:
163
+ Whether the weight attribute is configurable.
164
+ """
165
+ return self.is_weights_quantization_enabled(attr_name) and not self.is_all_weights_candidates_equal(attr_name)
166
+
167
+ def has_configurable_activation(self) -> bool:
168
+ """
169
+ Checks whether the activation has a configurable quantization.
170
+
171
+ Returns:
172
+ Whether the activation has a configurable quantization.
173
+ """
174
+ return self.is_activation_quantization_enabled() and not self.is_all_activation_candidates_equal()
175
+
153
176
  def __repr__(self):
154
177
  """
155
178
 
@@ -420,11 +443,15 @@ class BaseNode:
420
443
 
421
444
  Returns: Output size.
422
445
  """
423
- output_shapes = self.output_shape if isinstance(self.output_shape, List) else [self.output_shape]
446
+ # shape can be tuple or list, and multiple shapes can be packed in list or tuple
447
+ if self.output_shape and isinstance(self.output_shape[0], (tuple, list)):
448
+ output_shapes = self.output_shape
449
+ else:
450
+ output_shapes = [self.output_shape]
424
451
 
425
452
  # remove batch size (first element) from output shape
426
453
  output_shapes = [s[1:] for s in output_shapes]
427
-
454
+ # for scalar shape (None,) prod returns 1
428
455
  return sum([np.prod([x for x in output_shape if x is not None]) for output_shape in output_shapes])
429
456
 
430
457
  def find_min_candidates_indices(self) -> List[int]:
@@ -536,34 +563,34 @@ class BaseNode:
536
563
  # the inner method would log an exception.
537
564
  return [c.weights_quantization_cfg.get_attr_config(attr) for c in self.candidates_quantization_cfg]
538
565
 
539
- def get_qco(self, tpc: TargetPlatformCapabilities) -> QuantizationConfigOptions:
566
+ def get_qco(self, fqc: FrameworkQuantizationCapabilities) -> QuantizationConfigOptions:
540
567
  """
541
568
  Get the QuantizationConfigOptions of the node according
542
- to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformModel.
569
+ to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformCapabilities.
543
570
 
544
571
  Args:
545
- tpc: TPC to extract the QuantizationConfigOptions for the node.
572
+ fqc: FQC to extract the QuantizationConfigOptions for the node.
546
573
 
547
574
  Returns:
548
575
  QuantizationConfigOptions of the node.
549
576
  """
550
577
 
551
- if tpc is None:
552
- Logger.critical(f'Can not retrieve QC options for None TPC') # pragma: no cover
578
+ if fqc is None:
579
+ Logger.critical(f'Can not retrieve QC options for None FQC') # pragma: no cover
553
580
 
554
- for fl, qco in tpc.filterlayer2qco.items():
581
+ for fl, qco in fqc.filterlayer2qco.items():
555
582
  if self.is_match_filter_params(fl):
556
583
  return qco
557
584
  # Extract qco with is_match_type to overcome mismatch of function types in TF 2.15
558
- matching_qcos = [_qco for _type, _qco in tpc.layer2qco.items() if self.is_match_type(_type)]
585
+ matching_qcos = [_qco for _type, _qco in fqc.layer2qco.items() if self.is_match_type(_type)]
559
586
  if matching_qcos:
560
587
  if all([_qco == matching_qcos[0] for _qco in matching_qcos]):
561
588
  return matching_qcos[0]
562
589
  else:
563
590
  Logger.critical(f"Found duplicate qco types for node '{self.name}' of type '{self.type}'!") # pragma: no cover
564
- return tpc.tp_model.default_qco
591
+ return fqc.tpc.default_qco
565
592
 
566
- def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
593
+ def filter_node_qco_by_graph(self, fqc: FrameworkQuantizationCapabilities,
567
594
  next_nodes: List, node_qc_options: QuantizationConfigOptions
568
595
  ) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
569
596
  """
@@ -573,7 +600,7 @@ class BaseNode:
573
600
  filters out quantization config that don't comply to these attributes.
574
601
 
575
602
  Args:
576
- tpc: TPC to extract the QuantizationConfigOptions for the next nodes.
603
+ fqc: FQC to extract the QuantizationConfigOptions for the next nodes.
577
604
  next_nodes: Output nodes of current node.
578
605
  node_qc_options: Node's QuantizationConfigOptions.
579
606
 
@@ -584,7 +611,7 @@ class BaseNode:
584
611
  _base_config = node_qc_options.base_config
585
612
  _node_qc_options = node_qc_options.quantization_configurations
586
613
  if len(next_nodes):
587
- next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
614
+ next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
588
615
  next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
589
616
  for qc_opts in next_nodes_qc_options
590
617
  for op_cfg in qc_opts.quantization_configurations])
@@ -593,7 +620,7 @@ class BaseNode:
593
620
  _node_qc_options = [_option for _option in _node_qc_options
594
621
  if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
595
622
  if len(_node_qc_options) == 0:
596
- Logger.critical(f"Graph doesn't match TPC bit configurations: {self} -> {next_nodes}.") # pragma: no cover
623
+ Logger.critical(f"Graph doesn't match FQC bit configurations: {self} -> {next_nodes}.") # pragma: no cover
597
624
 
598
625
  # Verify base config match
599
626
  if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
@@ -603,9 +630,9 @@ class BaseNode:
603
630
  if len(_node_qc_options) > 0:
604
631
  output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
605
632
  _base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
606
- Logger.warning(f"Node {self} base quantization config changed to match Graph and TPC configuration.\nCause: {self} -> {next_nodes}.")
633
+ Logger.warning(f"Node {self} base quantization config changed to match Graph and FQC configuration.\nCause: {self} -> {next_nodes}.")
607
634
  else:
608
- Logger.critical(f"Graph doesn't match TPC bit configurations: {self} -> {next_nodes}.") # pragma: no cover
635
+ Logger.critical(f"Graph doesn't match FQC bit configurations: {self} -> {next_nodes}.") # pragma: no cover
609
636
 
610
637
  return _base_config, _node_qc_options
611
638
 
@@ -17,18 +17,19 @@ import numpy as np
17
17
  from model_compression_toolkit.core import ResourceUtilization, FrameworkInfo
18
18
  from model_compression_toolkit.core.common import Graph
19
19
  from model_compression_toolkit.logger import Logger
20
- from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
20
+ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
21
+ FrameworkQuantizationCapabilities
21
22
 
22
23
 
23
24
  def filter_candidates_for_mixed_precision(graph: Graph,
24
25
  target_resource_utilization: ResourceUtilization,
25
26
  fw_info: FrameworkInfo,
26
- tpc: TargetPlatformCapabilities):
27
+ fqc: FrameworkQuantizationCapabilities):
27
28
  """
28
29
  Filters out candidates in case of mixed precision search for only weights or activation compression.
29
30
  For instance, if running only weights compression - filters out candidates of activation configurable nodes
30
31
  such that only a single candidate would remain, with the bitwidth equal to the one defined in the matching layer's
31
- base config in the TPC.
32
+ base config in the FQC.
32
33
 
33
34
  Note: This function modifies the graph inplace!
34
35
 
@@ -36,7 +37,7 @@ def filter_candidates_for_mixed_precision(graph: Graph,
36
37
  graph: A graph representation of the model to be quantized.
37
38
  target_resource_utilization: The resource utilization of the target device.
38
39
  fw_info: fw_info: Information needed for quantization about the specific framework.
39
- tpc: TargetPlatformCapabilities object that describes the desired inference target platform.
40
+ fqc: FrameworkQuantizationCapabilities object that describes the desired inference target platform.
40
41
 
41
42
  """
42
43
 
@@ -50,7 +51,7 @@ def filter_candidates_for_mixed_precision(graph: Graph,
50
51
  weights_conf = graph.get_weights_configurable_nodes(fw_info)
51
52
  activation_configurable_nodes = [n for n in graph.get_activation_configurable_nodes() if n not in weights_conf]
52
53
  for n in activation_configurable_nodes:
53
- base_cfg_nbits = n.get_qco(tpc).base_config.activation_n_bits
54
+ base_cfg_nbits = n.get_qco(fqc).base_config.activation_n_bits
54
55
  filtered_conf = [c for c in n.candidates_quantization_cfg if
55
56
  c.activation_quantization_cfg.enable_activation_quantization and
56
57
  c.activation_quantization_cfg.activation_n_bits == base_cfg_nbits]
@@ -67,7 +68,7 @@ def filter_candidates_for_mixed_precision(graph: Graph,
67
68
  weight_configurable_nodes = [n for n in graph.get_weights_configurable_nodes(fw_info) if n not in activation_conf]
68
69
  for n in weight_configurable_nodes:
69
70
  kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
70
- base_cfg_nbits = n.get_qco(tpc).base_config.attr_weights_configs_mapping[kernel_attr].weights_n_bits
71
+ base_cfg_nbits = n.get_qco(fqc).base_config.attr_weights_configs_mapping[kernel_attr].weights_n_bits
71
72
  filtered_conf = [c for c in n.candidates_quantization_cfg if
72
73
  c.weights_quantization_cfg.get_attr_config(kernel_attr).enable_weights_quantization and
73
74
  c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == base_cfg_nbits]
@@ -22,7 +22,6 @@ from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
22
22
  from model_compression_toolkit.core.common import Graph
23
23
  from model_compression_toolkit.core.common.hessian import HessianInfoService
24
24
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
25
- from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import ru_functions_mapping
26
25
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
27
26
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager
28
27
  from model_compression_toolkit.core.common.mixed_precision.search_methods.linear_programming import \
@@ -105,16 +104,11 @@ def search_bit_width(graph_to_search_cfg: Graph,
105
104
  disable_activation_for_metric=disable_activation_for_metric,
106
105
  hessian_info_service=hessian_info_service)
107
106
 
108
- # Each pair of (resource utilization method, resource utilization aggregation) should match to a specific
109
- # provided target resource utilization
110
- ru_functions = ru_functions_mapping
111
-
112
107
  # Instantiate a manager object
113
108
  search_manager = MixedPrecisionSearchManager(graph,
114
109
  fw_info,
115
110
  fw_impl,
116
111
  se,
117
- ru_functions,
118
112
  target_resource_utilization,
119
113
  original_graph=graph_to_search_cfg)
120
114
 
@@ -13,23 +13,24 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Callable, Tuple
17
- from typing import Dict, List
16
+ from typing import Callable, Dict, List
17
+
18
18
  import numpy as np
19
19
 
20
20
  from model_compression_toolkit.core.common import BaseNode
21
- from model_compression_toolkit.logger import Logger
22
21
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
22
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
23
23
  from model_compression_toolkit.core.common.graph.base_graph import Graph
24
24
  from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
25
25
  VirtualSplitWeightsNode, VirtualSplitActivationNode
26
- from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget, ResourceUtilization
27
- from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import RuFunctions
28
- from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_aggregation_methods import MpRuAggregation
29
- from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric, calc_graph_cuts
30
- from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import Cut
31
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
26
+ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
27
+ RUTarget, ResourceUtilization
28
+ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
29
+ ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode
30
+ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import \
31
+ MixedPrecisionRUHelper
32
32
  from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
33
+ from model_compression_toolkit.logger import Logger
33
34
 
34
35
 
35
36
  class MixedPrecisionSearchManager:
@@ -42,7 +43,6 @@ class MixedPrecisionSearchManager:
42
43
  fw_info: FrameworkInfo,
43
44
  fw_impl: FrameworkImplementation,
44
45
  sensitivity_evaluator: SensitivityEvaluation,
45
- ru_functions: Dict[RUTarget, RuFunctions],
46
46
  target_resource_utilization: ResourceUtilization,
47
47
  original_graph: Graph = None):
48
48
  """
@@ -53,8 +53,6 @@ class MixedPrecisionSearchManager:
53
53
  fw_impl: FrameworkImplementation object with specific framework methods implementation.
54
54
  sensitivity_evaluator: A SensitivityEvaluation which provides a function that evaluates the sensitivity of
55
55
  a bit-width configuration for the MP model.
56
- ru_functions: A dictionary with pairs of (MpRuMethod, MpRuAggregationMethod) mapping a RUTarget to
57
- a couple of resource utilization metric function and resource utilization aggregation function.
58
56
  target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
59
57
  original_graph: In case we have a search over a virtual graph (if we have BOPS utilization target), then this argument
60
58
  will contain the original graph (for config reconstruction purposes).
@@ -69,29 +67,17 @@ class MixedPrecisionSearchManager:
69
67
  self.compute_metric_fn = self.get_sensitivity_metric()
70
68
  self._cuts = None
71
69
 
72
- ru_types = [ru_target for ru_target, ru_value in
73
- target_resource_utilization.get_resource_utilization_dict().items() if ru_value < np.inf]
74
- self.compute_ru_functions = {ru_target: ru_fn for ru_target, ru_fn in ru_functions.items() if ru_target in ru_types}
70
+ self.ru_metrics = target_resource_utilization.get_restricted_metrics()
71
+ self.ru_helper = MixedPrecisionRUHelper(graph, fw_info, fw_impl)
75
72
  self.target_resource_utilization = target_resource_utilization
76
73
  self.min_ru_config = self.graph.get_min_candidates_config(fw_info)
77
74
  self.max_ru_config = self.graph.get_max_candidates_config(fw_info)
78
- self.min_ru = self.compute_min_ru()
75
+ self.min_ru = self.ru_helper.compute_utilization(self.ru_metrics, self.min_ru_config)
79
76
  self.non_conf_ru_dict = self._non_configurable_nodes_ru()
80
77
 
81
78
  self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.graph,
82
79
  original_graph=self.original_graph)
83
80
 
84
- @property
85
- def cuts(self) -> List[Cut]:
86
- """
87
- Calculates graph cuts. Written as property, so it will only be calculated once and
88
- only if cuts are needed.
89
-
90
- """
91
- if self._cuts is None:
92
- self._cuts = calc_graph_cuts(self.original_graph)
93
- return self._cuts
94
-
95
81
  def get_search_space(self) -> Dict[int, List[int]]:
96
82
  """
97
83
  The search space is a mapping from a node's index to a list of integers (possible bitwidths candidates indeces
@@ -122,40 +108,6 @@ class MixedPrecisionSearchManager:
122
108
 
123
109
  return self.sensitivity_evaluator.compute_metric
124
110
 
125
- def _calc_ru_fn(self, ru_target, ru_fn, mp_cfg) -> np.ndarray:
126
- """
127
- Computes a resource utilization for a certain mixed precision configuration.
128
- The method computes a resource utilization vector for specific target resource utilization.
129
-
130
- Returns: resource utilization value.
131
-
132
- """
133
- # ru_fn is a pair of resource utilization computation method and
134
- # resource utilization aggregation method (in this method we only need the first one)
135
- if ru_target is RUTarget.ACTIVATION:
136
- return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl, self.cuts)
137
- else:
138
- return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl)
139
-
140
- def compute_min_ru(self) -> Dict[RUTarget, np.ndarray]:
141
- """
142
- Computes a resource utilization vector with the values matching to the minimal mp configuration
143
- (i.e., each node is configured with the quantization candidate that would give the minimal size of the
144
- node's resource utilization).
145
- The method computes the minimal resource utilization vector for each target resource utilization.
146
-
147
- Returns: A dictionary mapping each target resource utilization to its respective minimal
148
- resource utilization values.
149
-
150
- """
151
- min_ru = {}
152
- for ru_target, ru_fn in self.compute_ru_functions.items():
153
- # ru_fns is a pair of resource utilization computation method and
154
- # resource utilization aggregation method (in this method we only need the first one)
155
- min_ru[ru_target] = self._calc_ru_fn(ru_target, ru_fn, self.min_ru_config)
156
-
157
- return min_ru
158
-
159
111
  def compute_resource_utilization_matrix(self, target: RUTarget) -> np.ndarray:
160
112
  """
161
113
  Computes and builds a resource utilization matrix, to be used for the mixed-precision search problem formalization.
@@ -184,7 +136,8 @@ class MixedPrecisionSearchManager:
184
136
  # always be 0 for all entries in the results vector.
185
137
  candidate_rus = np.zeros(shape=self.min_ru[target].shape)
186
138
  else:
187
- candidate_rus = self.compute_candidate_relative_ru(c, candidate_idx, target)
139
+ candidate_rus = self.compute_node_ru_for_candidate(c, candidate_idx, target) - self.min_ru[target]
140
+
188
141
  ru_matrix.append(np.asarray(candidate_rus))
189
142
 
190
143
  # We need to transpose the calculated ru matrix to allow later multiplication with
@@ -195,40 +148,6 @@ class MixedPrecisionSearchManager:
195
148
  np_ru_matrix = np.array(ru_matrix)
196
149
  return np.moveaxis(np_ru_matrix, source=0, destination=len(np_ru_matrix.shape) - 1)
197
150
 
198
- def compute_candidate_relative_ru(self,
199
- conf_node_idx: int,
200
- candidate_idx: int,
201
- target: RUTarget) -> np.ndarray:
202
- """
203
- Computes a resource utilization vector for a given candidates of a given configurable node,
204
- i.e., the matching resource utilization vector which is obtained by computing the given target's
205
- resource utilization function on a minimal configuration in which the given
206
- layer's candidates is changed to the new given one.
207
- The result is normalized by subtracting the target's minimal resource utilization vector.
208
-
209
- Args:
210
- conf_node_idx: The index of a node in a sorted configurable nodes list.
211
- candidate_idx: The index of a node's quantization configuration candidate.
212
- target: The target for which the resource utilization is calculated (a RUTarget value).
213
-
214
- Returns: Normalized node's resource utilization vector
215
-
216
- """
217
- return self.compute_node_ru_for_candidate(conf_node_idx, candidate_idx, target) - \
218
- self.get_min_target_resource_utilization(target)
219
-
220
- def get_min_target_resource_utilization(self, target: RUTarget) -> np.ndarray:
221
- """
222
- Returns the minimal resource utilization vector (pre-calculated on initialization) of a specific target.
223
-
224
- Args:
225
- target: The target for which the resource utilization is calculated (a RUTarget value).
226
-
227
- Returns: Minimal resource utilization vector.
228
-
229
- """
230
- return self.min_ru[target]
231
-
232
151
  def compute_node_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int, target: RUTarget) -> np.ndarray:
233
152
  """
234
153
  Computes a resource utilization vector after replacing the given node's configuration candidate in the minimal
@@ -243,7 +162,8 @@ class MixedPrecisionSearchManager:
243
162
 
244
163
  """
245
164
  cfg = self.replace_config_in_index(self.min_ru_config, conf_node_idx, candidate_idx)
246
- return self._calc_ru_fn(target, self.compute_ru_functions[target], cfg)
165
+ # TODO compute for all targets at once. Currently the way up to add_set_of_ru_constraints is per target.
166
+ return self.ru_helper.compute_utilization({target}, cfg)[target]
247
167
 
248
168
  @staticmethod
249
169
  def replace_config_in_index(mp_cfg: List[int], idx: int, value: int) -> List[int]:
@@ -270,21 +190,10 @@ class MixedPrecisionSearchManager:
270
190
 
271
191
  Returns: A mapping between a RUTarget and its non-configurable nodes' resource utilization vector.
272
192
  """
273
-
274
- non_conf_ru_dict = {}
275
- for target, ru_fns in self.compute_ru_functions.items():
276
- # Call for the ru method of the given target - empty quantization configuration list is passed since we
277
- # compute for non-configurable nodes
278
- if target == RUTarget.BOPS:
279
- ru_vector = None
280
- elif target == RUTarget.ACTIVATION:
281
- ru_vector = ru_fns.metric_fn([], self.graph, self.fw_info, self.fw_impl, self.cuts)
282
- else:
283
- ru_vector = ru_fns.metric_fn([], self.graph, self.fw_info, self.fw_impl)
284
-
285
- non_conf_ru_dict[target] = ru_vector
286
-
287
- return non_conf_ru_dict
193
+ ru_metrics = self.ru_metrics - {RUTarget.BOPS}
194
+ ru = self.ru_helper.compute_utilization(ru_targets=ru_metrics, mp_cfg=None)
195
+ ru[RUTarget.BOPS] = None
196
+ return ru
288
197
 
289
198
  def compute_resource_utilization_for_config(self, config: List[int]) -> ResourceUtilization:
290
199
  """
@@ -297,29 +206,11 @@ class MixedPrecisionSearchManager:
297
206
  with the given config.
298
207
 
299
208
  """
300
-
301
- ru_dict = {}
302
- for ru_target, ru_fns in self.compute_ru_functions.items():
303
- # Passing False to ru methods and aggregations to indicates that the computations
304
- # are not for constraints setting
305
- if ru_target == RUTarget.BOPS:
306
- configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.original_graph, self.fw_info, self.fw_impl, False)
307
- elif ru_target == RUTarget.ACTIVATION:
308
- configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.graph, self.fw_info, self.fw_impl, self.cuts)
309
- else:
310
- configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.original_graph, self.fw_info, self.fw_impl)
311
- non_configurable_nodes_ru_vector = self.non_conf_ru_dict.get(ru_target)
312
- if non_configurable_nodes_ru_vector is None or len(non_configurable_nodes_ru_vector) == 0:
313
- ru_ru = self.compute_ru_functions[ru_target].aggregate_fn(configurable_nodes_ru_vector, False)
314
- else:
315
- ru_ru = self.compute_ru_functions[ru_target].aggregate_fn(
316
- np.concatenate([configurable_nodes_ru_vector, non_configurable_nodes_ru_vector]), False)
317
-
318
- ru_dict[ru_target] = ru_ru[0]
319
-
320
- config_ru = ResourceUtilization()
321
- config_ru.set_resource_utilization_by_target(ru_dict)
322
- return config_ru
209
+ act_qcs, w_qcs = self.ru_helper.get_configurable_qcs(config)
210
+ ru = self.ru_helper.ru_calculator.compute_resource_utilization(
211
+ target_criterion=TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs,
212
+ w_qcs=w_qcs)
213
+ return ru
323
214
 
324
215
  def finalize_distance_metric(self, layer_to_metrics_mapping: Dict[int, Dict[int, float]]):
325
216
  """