mct-nightly 2.2.0.20250113.134913__py3-none-any.whl → 2.2.0.20250114.134534__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/RECORD +102 -104
  3. model_compression_toolkit/__init__.py +2 -2
  4. model_compression_toolkit/core/common/framework_info.py +1 -3
  5. model_compression_toolkit/core/common/fusion/layer_fusing.py +6 -5
  6. model_compression_toolkit/core/common/graph/base_graph.py +20 -21
  7. model_compression_toolkit/core/common/graph/base_node.py +44 -17
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +7 -6
  9. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +187 -0
  10. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +0 -6
  11. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +35 -162
  12. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +36 -62
  13. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +668 -0
  14. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +25 -202
  15. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +74 -51
  16. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +3 -5
  17. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +2 -2
  18. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +7 -6
  19. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +0 -1
  20. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +0 -1
  21. model_compression_toolkit/core/common/pruning/pruner.py +5 -3
  22. model_compression_toolkit/core/common/quantization/bit_width_config.py +6 -12
  23. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +1 -2
  24. model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
  25. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
  26. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -1
  27. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +1 -1
  28. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +1 -1
  29. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +1 -1
  30. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +1 -1
  31. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +1 -1
  32. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +1 -1
  33. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +15 -14
  34. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
  35. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
  36. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -5
  37. model_compression_toolkit/core/graph_prep_runner.py +12 -11
  38. model_compression_toolkit/core/keras/default_framework_info.py +1 -1
  39. model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +1 -2
  40. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +5 -6
  41. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  42. model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
  43. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
  44. model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -1
  45. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +4 -5
  46. model_compression_toolkit/core/runner.py +33 -60
  47. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +1 -1
  48. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +1 -1
  49. model_compression_toolkit/gptq/keras/quantization_facade.py +8 -9
  50. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  51. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  52. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  53. model_compression_toolkit/gptq/pytorch/quantization_facade.py +8 -9
  54. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +1 -1
  55. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +1 -1
  56. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  57. model_compression_toolkit/metadata.py +11 -10
  58. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -6
  59. model_compression_toolkit/pruning/pytorch/pruning_facade.py +6 -7
  60. model_compression_toolkit/ptq/keras/quantization_facade.py +8 -9
  61. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -9
  62. model_compression_toolkit/qat/keras/quantization_facade.py +5 -6
  63. model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +1 -1
  64. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +1 -1
  65. model_compression_toolkit/qat/pytorch/quantization_facade.py +5 -9
  66. model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +1 -1
  67. model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +1 -1
  68. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +1 -1
  69. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +1 -1
  70. model_compression_toolkit/target_platform_capabilities/__init__.py +9 -0
  71. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  72. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +2 -2
  73. model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +18 -18
  74. model_compression_toolkit/target_platform_capabilities/schema/v1.py +13 -13
  75. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/__init__.py +6 -6
  76. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2fw.py +10 -10
  77. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2keras.py +3 -3
  78. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attach2pytorch.py +3 -2
  79. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/current_tpc.py +8 -8
  80. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities.py → targetplatform2framework/framework_quantization_capabilities.py} +40 -40
  81. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework/target_platform_capabilities_component.py → targetplatform2framework/framework_quantization_capabilities_component.py} +2 -2
  82. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/layer_filter_params.py +0 -1
  83. model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/operations_to_layers.py +8 -8
  84. model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +24 -24
  85. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +18 -18
  86. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +3 -3
  87. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/{tp_model.py → tpc.py} +31 -32
  88. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +3 -3
  89. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/{tp_model.py → tpc.py} +27 -27
  90. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +4 -4
  91. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/{tp_model.py → tpc.py} +27 -27
  92. model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +1 -2
  93. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +2 -1
  94. model_compression_toolkit/trainable_infrastructure/keras/activation_quantizers/lsq/symmetric_lsq.py +1 -2
  95. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +1 -1
  96. model_compression_toolkit/xquant/common/model_folding_utils.py +7 -6
  97. model_compression_toolkit/xquant/keras/keras_report_utils.py +4 -4
  98. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -3
  99. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +0 -105
  100. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +0 -33
  101. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +0 -528
  102. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -23
  103. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/LICENSE.md +0 -0
  104. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/WHEEL +0 -0
  105. {mct_nightly-2.2.0.20250113.134913.dist-info → mct_nightly-2.2.0.20250114.134534.dist-info}/top_level.txt +0 -0
  106. /model_compression_toolkit/target_platform_capabilities/{target_platform/targetplatform2framework → targetplatform2framework}/attribute_filter.py +0 -0
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py

@@ -13,23 +13,24 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import Callable, Tuple
-from typing import Dict, List
+from typing import Callable, Dict, List
+
 import numpy as np
 
 from model_compression_toolkit.core.common import BaseNode
-from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
     VirtualSplitWeightsNode, VirtualSplitActivationNode
-from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget, ResourceUtilization
-from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import RuFunctions
-from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_aggregation_methods import MpRuAggregation
-from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric, calc_graph_cuts
-from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import Cut
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
+    RUTarget, ResourceUtilization
+from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
+    TargetInclusionCriterion, BitwidthMode
+from model_compression_toolkit.core.common.mixed_precision.mixed_precision_ru_helper import \
+    MixedPrecisionRUHelper
 from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
+from model_compression_toolkit.logger import Logger
 
 
 class MixedPrecisionSearchManager:
@@ -42,7 +43,6 @@ class MixedPrecisionSearchManager:
                  fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation,
                  sensitivity_evaluator: SensitivityEvaluation,
-                 ru_functions: Dict[RUTarget, RuFunctions],
                  target_resource_utilization: ResourceUtilization,
                  original_graph: Graph = None):
         """
@@ -53,8 +53,6 @@ class MixedPrecisionSearchManager:
            fw_impl: FrameworkImplementation object with specific framework methods implementation.
            sensitivity_evaluator: A SensitivityEvaluation which provides a function that evaluates the sensitivity of
                a bit-width configuration for the MP model.
-           ru_functions: A dictionary with pairs of (MpRuMethod, MpRuAggregationMethod) mapping a RUTarget to
-               a couple of resource utilization metric function and resource utilization aggregation function.
            target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
            original_graph: In case we have a search over a virtual graph (if we have BOPS utilization target), then this argument
                will contain the original graph (for config reconstruction purposes).
@@ -69,29 +67,23 @@ class MixedPrecisionSearchManager:
         self.compute_metric_fn = self.get_sensitivity_metric()
         self._cuts = None
 
-        ru_types = [ru_target for ru_target, ru_value in
-                    target_resource_utilization.get_resource_utilization_dict().items() if ru_value < np.inf]
-        self.compute_ru_functions = {ru_target: ru_fn for ru_target, ru_fn in ru_functions.items() if ru_target in ru_types}
+        # To define RU Total constraints we need to compute weights and activations even if they have no constraints
+        # TODO currently this logic is duplicated in linear_programming.py
+        targets = target_resource_utilization.get_restricted_metrics()
+        if RUTarget.TOTAL in targets:
+            targets = targets.union({RUTarget.ACTIVATION, RUTarget.WEIGHTS}) - {RUTarget.TOTAL}
+        self.ru_targets_to_compute = targets
+
+        self.ru_helper = MixedPrecisionRUHelper(graph, fw_info, fw_impl)
         self.target_resource_utilization = target_resource_utilization
         self.min_ru_config = self.graph.get_min_candidates_config(fw_info)
         self.max_ru_config = self.graph.get_max_candidates_config(fw_info)
-        self.min_ru = self.compute_min_ru()
-        self.non_conf_ru_dict = self._non_configurable_nodes_ru()
+        self.min_ru = self.ru_helper.compute_utilization(self.ru_targets_to_compute, self.min_ru_config)
+        self.non_conf_ru_dict = self.ru_helper.compute_utilization(self.ru_targets_to_compute, None)
 
         self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.graph,
                                                                        original_graph=self.original_graph)
 
-    @property
-    def cuts(self) -> List[Cut]:
-        """
-        Calculates graph cuts. Written as property, so it will only be calculated once and
-        only if cuts are needed.
-
-        """
-        if self._cuts is None:
-            self._cuts = calc_graph_cuts(self.original_graph)
-        return self._cuts
-
     def get_search_space(self) -> Dict[int, List[int]]:
         """
         The search space is a mapping from a node's index to a list of integers (possible bitwidths candidates indeces
@@ -122,55 +114,17 @@
 
         return self.sensitivity_evaluator.compute_metric
 
-    def _calc_ru_fn(self, ru_target, ru_fn, mp_cfg) -> np.ndarray:
-        """
-        Computes a resource utilization for a certain mixed precision configuration.
-        The method computes a resource utilization vector for specific target resource utilization.
-
-        Returns: resource utilization value.
-
-        """
-        # ru_fn is a pair of resource utilization computation method and
-        # resource utilization aggregation method (in this method we only need the first one)
-        if ru_target is RUTarget.ACTIVATION:
-            return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl, self.cuts)
-        else:
-            return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl)
-
-    def compute_min_ru(self) -> Dict[RUTarget, np.ndarray]:
-        """
-        Computes a resource utilization vector with the values matching to the minimal mp configuration
-        (i.e., each node is configured with the quantization candidate that would give the minimal size of the
-        node's resource utilization).
-        The method computes the minimal resource utilization vector for each target resource utilization.
-
-        Returns: A dictionary mapping each target resource utilization to its respective minimal
-        resource utilization values.
-
-        """
-        min_ru = {}
-        for ru_target, ru_fn in self.compute_ru_functions.items():
-            # ru_fns is a pair of resource utilization computation method and
-            # resource utilization aggregation method (in this method we only need the first one)
-            min_ru[ru_target] = self._calc_ru_fn(ru_target, ru_fn, self.min_ru_config)
-
-        return min_ru
-
     def compute_resource_utilization_matrix(self, target: RUTarget) -> np.ndarray:
         """
         Computes and builds a resource utilization matrix, to be used for the mixed-precision search problem formalization.
-        The matrix is constructed as follows (for a given target):
-        - Each row represents the set of resource utilization values for a specific resource utilization
-            measure (number of rows should be equal to the length of the output of the respective target compute_ru function).
-        - Each entry in a specific column represents the resource utilization value of a given configuration
-            (single layer is configured with specific candidate, all other layer are at the minimal resource
-            utilization configuration) for the resource utilization measure of the respective row.
+        Utilization is computed relative to the minimal configuration, i.e. utilization for it will be 0.
 
         Args:
            target: The resource target for which the resource utilization is calculated (a RUTarget value).
 
-        Returns: A resource utilization matrix.
-
+        Returns:
+            A resource utilization matrix of shape (num configurations, num memory elements). Num memory elements
+            depends on the target, e.g. num nodes or num cuts, for which utilization is computed.
        """
         assert isinstance(target, RUTarget), f"{target} is not a valid resource target"
 
@@ -180,54 +134,14 @@
         for c, c_n in enumerate(configurable_sorted_nodes):
             for candidate_idx in range(len(c_n.candidates_quantization_cfg)):
                 if candidate_idx == self.min_ru_config[c]:
-                    # skip ru computation for min configuration. Since we compute the difference from min_ru it'll
-                    # always be 0 for all entries in the results vector.
-                    candidate_rus = np.zeros(shape=self.min_ru[target].shape)
+                    candidate_rus = self.min_ru[target]
                 else:
-                    candidate_rus = self.compute_candidate_relative_ru(c, candidate_idx, target)
-                ru_matrix.append(np.asarray(candidate_rus))
-
-        # We need to transpose the calculated ru matrix to allow later multiplication with
-        # the indicators' diagonal matrix.
-        # We only move the first axis (num of configurations) to be last,
-        # the remaining axes include the metric specific nodes (rows dimension of the new tensor)
-        # and the ru metric values (if they are non-scalars)
-        np_ru_matrix = np.array(ru_matrix)
-        return np.moveaxis(np_ru_matrix, source=0, destination=len(np_ru_matrix.shape) - 1)
-
-    def compute_candidate_relative_ru(self,
-                                      conf_node_idx: int,
-                                      candidate_idx: int,
-                                      target: RUTarget) -> np.ndarray:
-        """
-        Computes a resource utilization vector for a given candidates of a given configurable node,
-        i.e., the matching resource utilization vector which is obtained by computing the given target's
-        resource utilization function on a minimal configuration in which the given
-        layer's candidates is changed to the new given one.
-        The result is normalized by subtracting the target's minimal resource utilization vector.
+                    candidate_rus = self.compute_node_ru_for_candidate(c, candidate_idx, target)
 
-        Args:
-            conf_node_idx: The index of a node in a sorted configurable nodes list.
-            candidate_idx: The index of a node's quantization configuration candidate.
-            target: The target for which the resource utilization is calculated (a RUTarget value).
-
-        Returns: Normalized node's resource utilization vector
-
-        """
-        return self.compute_node_ru_for_candidate(conf_node_idx, candidate_idx, target) - \
-               self.get_min_target_resource_utilization(target)
-
-    def get_min_target_resource_utilization(self, target: RUTarget) -> np.ndarray:
-        """
-        Returns the minimal resource utilization vector (pre-calculated on initialization) of a specific target.
-
-        Args:
-            target: The target for which the resource utilization is calculated (a RUTarget value).
-
-        Returns: Minimal resource utilization vector.
+                ru_matrix.append(np.asarray(candidate_rus))
 
-        """
-        return self.min_ru[target]
+        np_ru_matrix = np.array(ru_matrix) - self.min_ru[target]  # num configurations X num elements
+        return np_ru_matrix
 
     def compute_node_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int, target: RUTarget) -> np.ndarray:
         """
@@ -243,7 +157,7 @@
 
         """
         cfg = self.replace_config_in_index(self.min_ru_config, conf_node_idx, candidate_idx)
-        return self._calc_ru_fn(target, self.compute_ru_functions[target], cfg)
+        return self.ru_helper.compute_utilization({target}, cfg)[target]
 
     @staticmethod
     def replace_config_in_index(mp_cfg: List[int], idx: int, value: int) -> List[int]:
@@ -263,29 +177,6 @@
         updated_cfg[idx] = value
         return updated_cfg
 
-    def _non_configurable_nodes_ru(self) -> Dict[RUTarget, np.ndarray]:
-        """
-        Computes a resource utilization vector of all non-configurable nodes in the given graph for each of the
-        resource utilization targets.
-
-        Returns: A mapping between a RUTarget and its non-configurable nodes' resource utilization vector.
-        """
-
-        non_conf_ru_dict = {}
-        for target, ru_fns in self.compute_ru_functions.items():
-            # Call for the ru method of the given target - empty quantization configuration list is passed since we
-            # compute for non-configurable nodes
-            if target == RUTarget.BOPS:
-                ru_vector = None
-            elif target == RUTarget.ACTIVATION:
-                ru_vector = ru_fns.metric_fn([], self.graph, self.fw_info, self.fw_impl, self.cuts)
-            else:
-                ru_vector = ru_fns.metric_fn([], self.graph, self.fw_info, self.fw_impl)
-
-            non_conf_ru_dict[target] = ru_vector
-
-        return non_conf_ru_dict
-
     def compute_resource_utilization_for_config(self, config: List[int]) -> ResourceUtilization:
         """
         Computes the resource utilization values for a given mixed-precision configuration.
@@ -297,29 +188,11 @@
            with the given config.
 
         """
-
-        ru_dict = {}
-        for ru_target, ru_fns in self.compute_ru_functions.items():
-            # Passing False to ru methods and aggregations to indicates that the computations
-            # are not for constraints setting
-            if ru_target == RUTarget.BOPS:
-                configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.original_graph, self.fw_info, self.fw_impl, False)
-            elif ru_target == RUTarget.ACTIVATION:
-                configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.graph, self.fw_info, self.fw_impl, self.cuts)
-            else:
-                configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.original_graph, self.fw_info, self.fw_impl)
-            non_configurable_nodes_ru_vector = self.non_conf_ru_dict.get(ru_target)
-            if non_configurable_nodes_ru_vector is None or len(non_configurable_nodes_ru_vector) == 0:
-                ru_ru = self.compute_ru_functions[ru_target].aggregate_fn(configurable_nodes_ru_vector, False)
-            else:
-                ru_ru = self.compute_ru_functions[ru_target].aggregate_fn(
-                    np.concatenate([configurable_nodes_ru_vector, non_configurable_nodes_ru_vector]), False)
-
-            ru_dict[ru_target] = ru_ru[0]
-
-        config_ru = ResourceUtilization()
-        config_ru.set_resource_utilization_by_target(ru_dict)
-        return config_ru
+        act_qcs, w_qcs = self.ru_helper.get_quantization_candidates(config)
+        ru = self.ru_helper.ru_calculator.compute_resource_utilization(
+            target_criterion=TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs,
+            w_qcs=w_qcs)
+        return ru
 
     def finalize_distance_metric(self, layer_to_metrics_mapping: Dict[int, Dict[int, float]]):
        """
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py

@@ -12,29 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, Any
+from typing import Dict, Any, Set
 
 import numpy as np
 
 
 class RUTarget(Enum):
     """
-    Targets for which we define Resource Utilization metrics for mixed-precision search.
-    For each target that we care to consider in a mixed-precision search, there should be defined a set of
-    resource utilization computation function, resource utilization aggregation function,
-    and resource utilization target (within a ResourceUtilization object).
-
-    Whenever adding a resource utilization metric to ResourceUtilization class we should add a matching target to this enum.
-
-    WEIGHTS - Weights memory ResourceUtilization metric.
-
-    ACTIVATION - Activation memory ResourceUtilization metric.
-
-    TOTAL - Total memory ResourceUtilization metric.
-
-    BOPS - Total Bit-Operations ResourceUtilization Metric.
+    Resource Utilization targets for mixed-precision search.
 
+    WEIGHTS - Weights memory.
+    ACTIVATION - Activation memory.
+    TOTAL - Total memory.
+    BOPS - Total Bit-Operations.
     """
 
     WEIGHTS = 'weights'
@@ -43,34 +35,20 @@ class RUTarget(Enum):
     BOPS = 'bops'
 
 
+@dataclass
 class ResourceUtilization:
     """
     Class to represent measurements of performance.
-    """
-
-    def __init__(self,
-                 weights_memory: float = np.inf,
-                 activation_memory: float = np.inf,
-                 total_memory: float = np.inf,
-                 bops: float = np.inf):
-        """
-
-        Args:
-            weights_memory: Memory of a model's weights in bytes. Note that this includes only coefficients that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value, while the bias will not).
-            activation_memory: Memory of a model's activation in bytes, according to the given activation resource utilization metric.
-            total_memory: The sum of model's activation and weights memory in bytes, according to the given total resource utilization metric.
-            bops: The total bit-operations in the model.
-        """
-        self.weights_memory = weights_memory
-        self.activation_memory = activation_memory
-        self.total_memory = total_memory
-        self.bops = bops
 
-    def __repr__(self):
-        return f"Weights_memory: {self.weights_memory}, " \
-               f"Activation_memory: {self.activation_memory}, " \
-               f"Total_memory: {self.total_memory}, " \
-               f"BOPS: {self.bops}"
+    weights_memory: Memory of a model's weights in bytes.
+    activation_memory: Memory of a model's activation in bytes.
+    total_memory: The sum of model's activation and weights memory in bytes.
+    bops: The total bit-operations in the model.
+    """
+    weights_memory: float = np.inf
+    activation_memory: float = np.inf
+    total_memory: float = np.inf
+    bops: float = np.inf
 
     def weight_restricted(self):
         return self.weights_memory < np.inf
@@ -93,34 +71,30 @@ class ResourceUtilization:
                 RUTarget.TOTAL: self.total_memory,
                 RUTarget.BOPS: self.bops}
 
-    def set_resource_utilization_by_target(self, ru_mapping: Dict[RUTarget, float]):
+    def is_satisfied_by(self, ru: 'ResourceUtilization') -> bool:
         """
-        Setting a ResourceUtilization object values for each ResourceUtilization target in the given dictionary.
+        Checks whether another ResourceUtilization object satisfies the constraints defined by the current object.
 
        Args:
-            ru_mapping: A mapping from a RUTarget to a matching resource utilization value.
+            ru: A ResourceUtilization object to check against the current object.
 
+        Returns:
+            Whether all constraints are satisfied.
        """
-        self.weights_memory = ru_mapping.get(RUTarget.WEIGHTS, np.inf)
-        self.activation_memory = ru_mapping.get(RUTarget.ACTIVATION, np.inf)
-        self.total_memory = ru_mapping.get(RUTarget.TOTAL, np.inf)
-        self.bops = ru_mapping.get(RUTarget.BOPS, np.inf)
+        return bool(ru.weights_memory <= self.weights_memory and \
+                    ru.activation_memory <= self.activation_memory and \
+                    ru.total_memory <= self.total_memory and \
+                    ru.bops <= self.bops)
 
-    def holds_constraints(self, ru: Any) -> bool:
-        """
-        Checks whether the given ResourceUtilization object holds a set of ResourceUtilization constraints defined by
-        the current ResourceUtilization object.
+    def get_restricted_metrics(self) -> Set[RUTarget]:
+        d = self.get_resource_utilization_dict()
+        return {k for k, v in d.items() if v < np.inf}
 
-        Args:
-            ru: A ResourceUtilization object to check if it holds the constraints.
-
-        Returns: True if all the given resource utilization values are not greater than the referenced resource utilization values.
+    def is_any_restricted(self) -> bool:
+        return bool(self.get_restricted_metrics())
 
-        """
-        if not isinstance(ru, ResourceUtilization):
-            return False
-
-        return ru.weights_memory <= self.weights_memory and \
-               ru.activation_memory <= self.activation_memory and \
-               ru.total_memory <= self.total_memory and \
-               ru.bops <= self.bops
+    def __repr__(self):
+        return f"Weights_memory: {self.weights_memory}, " \
+               f"Activation_memory: {self.activation_memory}, " \
+               f"Total_memory: {self.total_memory}, " \
+               f"BOPS: {self.bops}"