mct-nightly 2.3.0.20250511.614__py3-none-any.whl → 2.3.0.20250513.611__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. {mct_nightly-2.3.0.20250511.614.dist-info → mct_nightly-2.3.0.20250513.611.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.3.0.20250511.614.dist-info → mct_nightly-2.3.0.20250513.611.dist-info}/RECORD +23 -23
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +6 -33
  5. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +22 -3
  6. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +8 -5
  7. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +69 -58
  8. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +82 -79
  9. model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py +32 -26
  10. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -4
  11. model_compression_toolkit/core/common/quantization/node_quantization_config.py +7 -0
  12. model_compression_toolkit/core/common/similarity_analyzer.py +1 -1
  13. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +37 -73
  14. model_compression_toolkit/core/keras/keras_implementation.py +8 -45
  15. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +7 -5
  16. model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +6 -5
  17. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +46 -78
  18. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +7 -9
  19. model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +12 -10
  20. model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -41
  21. {mct_nightly-2.3.0.20250511.614.dist-info → mct_nightly-2.3.0.20250513.611.dist-info}/WHEEL +0 -0
  22. {mct_nightly-2.3.0.20250511.614.dist-info → mct_nightly-2.3.0.20250513.611.dist-info}/licenses/LICENSE.md +0 -0
  23. {mct_nightly-2.3.0.20250511.614.dist-info → mct_nightly-2.3.0.20250513.611.dist-info}/top_level.txt +0 -0
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py
@@ -12,16 +12,18 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ import contextlib
  import copy
+ import itertools

  import numpy as np
- from typing import Callable, Any, List, Tuple
+ from typing import Callable, Any, List, Tuple, Dict, Optional

- from model_compression_toolkit.constants import AXIS
  from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig
  from model_compression_toolkit.core.common import Graph, BaseNode
+ from model_compression_toolkit.core.common.mixed_precision.set_layer_to_bitwidth import \
+ set_activation_quant_layer_to_bitwidth, set_weights_quant_layer_to_bitwidth
  from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
- from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
  from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
  from model_compression_toolkit.logger import Logger
@@ -41,7 +43,6 @@ class SensitivityEvaluation:
  representative_data_gen: Callable,
  fw_info: FrameworkInfo,
  fw_impl: Any,
- set_layer_to_bitwidth: Callable,
  disable_activation_for_metric: bool = False,
  hessian_info_service: HessianInfoService = None
  ):
@@ -63,8 +64,6 @@ class SensitivityEvaluation:
  quant_config: MP Quantization configuration for how the graph should be quantized.
  representative_data_gen: Dataset used for getting batches for inference.
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
- set_layer_to_bitwidth: A fw-dependent function that allows to configure a configurable MP model
- with a specific bit-width configuration.
  disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
  hessian_info_service: HessianInfoService to fetch Hessian approximation information.

@@ -74,10 +73,9 @@ class SensitivityEvaluation:
  self.representative_data_gen = representative_data_gen
  self.fw_info = fw_info
  self.fw_impl = fw_impl
- self.set_layer_to_bitwidth = set_layer_to_bitwidth
  self.disable_activation_for_metric = disable_activation_for_metric
  if self.quant_config.use_hessian_based_scores:
- if not isinstance(hessian_info_service, HessianInfoService):
+ if not isinstance(hessian_info_service, HessianInfoService): # pragma: no cover
  Logger.critical(
  f"When using Hessian-based approximations for sensitivity evaluation, a valid HessianInfoService object is required; found {type(hessian_info_service)}.")
  self.hessian_info_service = hessian_info_service
@@ -159,44 +157,44 @@ class SensitivityEvaluation:
  axis_list.append(axis if distance_fn == compute_kl_divergence else None)
  return distance_fns_list, axis_list

- def compute_metric(self,
- mp_model_configuration: List[int],
- node_idx: List[int] = None,
- baseline_mp_configuration: List[int] = None) -> float:
+ def compute_metric(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]) -> float:
  """
  Compute the sensitivity metric of the MP model for a given configuration (the sensitivity
  is computed based on the similarity of the interest points' outputs between the MP model
  and the float model or a custom metric if given).
+ Quantization for any configurable activation / weight that were not passed is disabled.

  Args:
- mp_model_configuration: Bitwidth configuration to use to configure the MP model.
- node_idx: A list of nodes' indices to configure (instead of using the entire mp_model_configuration).
- baseline_mp_configuration: A mixed-precision configuration to set the model back to after modifying it to
- compute the metric for the given configuration.
+ mp_a_cfg: Bitwidth activations configuration for the MP model.
+ mp_w_cfg: Bitwidth weights configuration for the MP model.

  Returns:
  The sensitivity metric of the MP model for a given configuration.
  """

- # Configure MP model with the given configuration.
- self._configure_bitwidths_model(mp_model_configuration,
- node_idx)
+ with self._configured_mp_model(mp_a_cfg, mp_w_cfg):
+ sensitivity_metric = self._compute_metric()

- # Compute the distance metric
- if self.quant_config.custom_metric_fn is None:
- ipts_distances, out_pts_distances = self._compute_distance()
- sensitivity_metric = self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
- self.quant_config.distance_weighting_method)
- else:
- sensitivity_metric = self.quant_config.custom_metric_fn(self.model_mp)
- if not isinstance(sensitivity_metric, (float, np.floating)):
- raise TypeError(f'The custom_metric_fn is expected to return float or numpy float, got {type(sensitivity_metric).__name__}')
+ return sensitivity_metric

- # Configure MP model back to the same configuration as the baseline model if baseline provided
- if baseline_mp_configuration is not None:
- self._configure_bitwidths_model(baseline_mp_configuration,
- node_idx)
+ def _compute_metric(self) -> float:
+ """
+ Compute sensitivity metric on a configured mp model.

+ Returns:
+ Sensitivity metric.
+ """
+ if self.quant_config.custom_metric_fn:
+ sensitivity_metric = self.quant_config.custom_metric_fn(self.model_mp)
+ if not isinstance(sensitivity_metric, (float, np.floating)):
+ raise TypeError(
+ f'The custom_metric_fn is expected to return float or numpy float, got {type(sensitivity_metric).__name__}')
+ return sensitivity_metric
+
+ # compute default metric
+ ipts_distances, out_pts_distances = self._compute_distance()
+ sensitivity_metric = self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
+ self.quant_config.distance_weighting_method)
  return sensitivity_metric

  def _init_baseline_tensors_list(self):
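Note on the API change above: compute_metric previously took a flat list of bit-width indices (plus optional node indices and a baseline configuration to restore), while it now takes two dictionaries keyed by node name, one for activations and one for weights. A minimal usage sketch, assuming a SensitivityEvaluation instance named se; the node names and candidate indices below are hypothetical, not taken from this diff:

    # Candidate index per configurable node name; anything omitted (or set to None) stays un-quantized.
    mp_a_cfg = {'conv1': 0}              # activation candidate index
    mp_w_cfg = {'conv1': 2, 'conv2': 1}  # weights candidate indices
    score = se.compute_metric(mp_a_cfg=mp_a_cfg, mp_w_cfg=mp_w_cfg)
    # On return, the context manager has already reverted the touched quantizers to disabled.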
@@ -217,17 +215,31 @@ class SensitivityEvaluation:

  evaluation_graph = copy.deepcopy(self.graph)

- if self.disable_activation_for_metric:
- for n in evaluation_graph.get_topo_sorted_nodes():
+ # Disable quantization for non-configurable nodes, and, if requested, for all activations (quantizers won't
+ # be added to the model).
+ for n in evaluation_graph.get_topo_sorted_nodes():
+ if self.disable_activation_for_metric or not n.has_configurable_activation():
  for c in n.candidates_quantization_cfg:
  c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
+ if not n.has_any_configurable_weight():
+ for c in n.candidates_quantization_cfg:
+ c.weights_quantization_cfg.disable_all_weights_quantization()

  model_mp, _, conf_node2layers = self.fw_impl.model_builder(evaluation_graph,
  mode=ModelBuilderMode.MIXEDPRECISION,
  append2output=self.interest_points + self.output_points,
  fw_info=self.fw_info)

- # Build a baseline model.
+ # Disable all configurable quantizers. They will be activated one at a time during sensitivity evaluation.
+ # Note: from this point mp_model is not in sync with graph quantization configuration for configurable nodes.
+ for layer in itertools.chain(*conf_node2layers.values()):
+ if isinstance(layer, self.fw_impl.activation_quant_layer_cls):
+ set_activation_quant_layer_to_bitwidth(layer, None, self.fw_impl)
+ else:
+ assert isinstance(layer, self.fw_impl.weights_quant_layer_cls)
+ set_weights_quant_layer_to_bitwidth(layer, None, self.fw_impl)
+
+ # Build a baseline model (to compute distances from).
  baseline_model, _ = self.fw_impl.model_builder(evaluation_graph,
  mode=ModelBuilderMode.FLOAT,
  append2output=self.interest_points + self.output_points)
@@ -259,55 +271,46 @@ class SensitivityEvaluation:
  # Return the mean approximation value across all images for each interest point
  return np.mean(approx_by_image, axis=0)

- def _configure_bitwidths_model(self,
- mp_model_configuration: List[int],
- node_idx: List[int]):
- """
- Configure a dynamic model (namely, model with layers that their weights and activation
- bit-width can be configured) using an MP model configuration mp_model_configuration.
-
- Args:
- mp_model_configuration: Configuration of bit-width indices to set to the model.
- node_idx: List of nodes' indices to configure (the rest layers are configured as the baseline model).
+ @contextlib.contextmanager
+ def _configured_mp_model(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]):
  """
+ Context manager to configure specific configurable layers of the mp model. At exit, configuration is
+ automatically restored to un-quantized.

- # Configure model
- # Note: Not all nodes in the graph are included in the MP model that is returned by the model builder.
- # Thus, the last configurable layer must be included in the interest points for evaluating the metric,
- # otherwise, not all configurable nodes will be considered throughout the MP optimization search (since
- # they will not affect the metric value).
- if node_idx is not None: # configure specific layers in the mp model
- for node_idx_to_configure in node_idx:
- self._configure_node_bitwidth(self.sorted_configurable_nodes_names,
- mp_model_configuration, node_idx_to_configure)
- else: # use the entire mp_model_configuration to configure the model
- for node_idx_to_configure, bitwidth_idx in enumerate(mp_model_configuration):
- self._configure_node_bitwidth(self.sorted_configurable_nodes_names,
- mp_model_configuration, node_idx_to_configure)
-
- def _configure_node_bitwidth(self,
- sorted_configurable_nodes_names: List[str],
- mp_model_configuration: List[int],
- node_idx_to_configure: int):
- """
- Configures a node with multiple quantization candidates to the bitwidth candidate in the given index.
  Args:
- sorted_configurable_nodes_names: A list of configurable nodes names sorted according to the graph
- topological sort order.
- mp_model_configuration: Configuration of bit-width indices to set to the model.
- node_idx_to_configure: Quantization configuration candidate to configure.
-
- Returns:
+ mp_a_cfg: Nodes bitwidth indices to configure activation quantizers to.
+ mp_w_cfg: Nodes bitwidth indices to configure weights quantizers to.

  """
- node_name = sorted_configurable_nodes_names[node_idx_to_configure]
- layers_to_config = self.conf_node2layers.get(node_name, None)
- if layers_to_config is None:
- Logger.critical(
- f"Matching layers for node {node_name} not found in the mixed precision model configuration.") # pragma: no cover
-
- for current_layer in layers_to_config:
- self.set_layer_to_bitwidth(current_layer, mp_model_configuration[node_idx_to_configure])
+ if not (mp_a_cfg and any(v is not None for v in mp_a_cfg.values()) or
+ mp_w_cfg and any(v is not None for v in mp_w_cfg.values())):
+ raise ValueError(f'Requested configuration is either empty or contain only None values.')
+
+ # defined here so that it can't be used directly
+ def apply_bitwidth_config(a_cfg, w_cfg):
+ node_names = set(a_cfg.keys()).union(set(w_cfg.keys()))
+ for n in node_names:
+ node_quant_layers = self.conf_node2layers.get(n)
+ if node_quant_layers is None: # pragma: no cover
+ raise ValueError(f"Matching layers for node {n} not found in the mixed precision model configuration.")
+ for qlayer in node_quant_layers:
+ assert isinstance(qlayer, (self.fw_impl.activation_quant_layer_cls,
+ self.fw_impl.weights_quant_layer_cls)), f'Unexpected {type(qlayer)} of node {n}'
+ if isinstance(qlayer, self.fw_impl.activation_quant_layer_cls) and n in a_cfg:
+ set_activation_quant_layer_to_bitwidth(qlayer, a_cfg[n], self.fw_impl)
+ a_cfg.pop(n)
+ elif isinstance(qlayer, self.fw_impl.weights_quant_layer_cls) and n in w_cfg:
+ set_weights_quant_layer_to_bitwidth(qlayer, w_cfg[n], self.fw_impl)
+ w_cfg.pop(n)
+ if a_cfg or w_cfg:
+ raise ValueError(f'Not all mp configs were consumed, remaining activation config {a_cfg}, '
+ f'weights config {w_cfg}.')
+
+ apply_bitwidth_config(mp_a_cfg.copy(), mp_w_cfg.copy())
+ try:
+ yield
+ finally:
+ apply_bitwidth_config({n: None for n in mp_a_cfg}, {n: None for n in mp_w_cfg})

  def _compute_points_distance(self,
  baseline_tensors: List[Any],
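The explicit configure / evaluate / restore-to-baseline sequence is replaced here by a context manager, so the restore cannot be skipped if the metric computation raises. A stripped-down sketch of the same pattern, where layers maps node names to placeholder quantizer objects and apply_cfg and set_active_quantizer are hypothetical helpers (not MCT functions):

    import contextlib

    def apply_cfg(layers, cfg):
        # placeholder: push each requested candidate index (None = quantization disabled) into its quantizer
        for name, idx in cfg.items():
            layers[name].set_active_quantizer(idx)

    @contextlib.contextmanager
    def configured(layers, cfg):
        apply_cfg(layers, cfg)                         # set requested bit-width indices
        try:
            yield                                      # caller computes the metric here
        finally:
            apply_cfg(layers, {k: None for k in cfg})  # always revert to un-quantized on exit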
model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py
@@ -12,39 +12,45 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Any
+ import typing
+ from typing import Any, Optional

+ if typing.TYPE_CHECKING: # pragma: no cover
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation

- def set_layer_to_bitwidth(quantization_layer: Any,
- bitwidth_idx: int,
- weights_quantizer_type: type,
- activation_quantizer_type: type,
- weights_quant_layer_type: type,
- activation_quant_layer_type: type):
+
+ def set_activation_quant_layer_to_bitwidth(quantization_layer: Any,
+ bitwidth_idx: Optional[int],
+ fw_impl: 'FrameworkImplementation'):
  """
- Configures a layer's configurable quantizer to work with a different bit-width.
+ Configures a layer's configurable activation quantizer to work with a different bit-width.
  The bit-width_idx is the index of the actual quantizer the quantizer object in the quantization_layer wraps/holds.

  Args:
  quantization_layer: Layer to change its bit-width.
- bitwidth_idx: Index of the bit-width the layer should work with.
- weights_quantizer_type: A class of weights quantizer with configurable bitwidth options.
- activation_quantizer_type: A class of activation quantizer with configurable bitwidth options.
- weights_quant_layer_type: A class of a weights layer wrapper.
- activation_quant_layer_type: A class of an activation quantization holder.
+ bitwidth_idx: Index of the bit-width the layer should work with, or None to disable quantization.
+ fw_impl: framework implementation object.
  """
+ assert isinstance(quantization_layer, fw_impl.activation_quant_layer_cls)
+ assert isinstance(quantization_layer.activation_holder_quantizer, fw_impl.configurable_activation_quantizer_cls)
+ quantization_layer.activation_holder_quantizer.set_active_activation_quantizer(bitwidth_idx)
+

- if isinstance(quantization_layer, weights_quant_layer_type):
- for _, quantizer in quantization_layer.weights_quantizers.items():
- if isinstance(quantizer, weights_quantizer_type):
- # Setting bitwidth only for configurable layers. There might be wrapped layers that aren't configurable,
- # for instance, if only activations are quantized with mixed precision and weights are quantized with
- # fixed precision
- quantizer.set_weights_bit_width_index(bitwidth_idx)
+ def set_weights_quant_layer_to_bitwidth(quantization_layer: Any,
+ bitwidth_idx: Optional[int],
+ fw_impl: 'FrameworkImplementation'):
+ """
+ Configures a layer's configurable weights quantizer to work with a different bit-width.
+ The bit-width_idx is the index of the actual quantizer the quantizer object in the quantization_layer wraps/holds.

- if isinstance(quantization_layer, activation_quant_layer_type):
- if isinstance(quantization_layer.activation_holder_quantizer, activation_quantizer_type):
- # Setting bitwidth only for configurable layers. There might be activation layers that isn't configurable,
- # for instance, if only weights are quantized with mixed precision and activation are quantized with
- # fixed precision
- quantization_layer.activation_holder_quantizer.set_active_activation_quantizer(bitwidth_idx)
+ Args:
+ quantization_layer: Layer to change its bit-width.
+ bitwidth_idx: Index of the bit-width the layer should work with, or None to disable quantization.
+ fw_impl: framework implementation object.
+ """
+ assert isinstance(quantization_layer, fw_impl.weights_quant_layer_cls)
+ configurable_quantizers = [q for q in quantization_layer.weights_quantizers.values()
+ if isinstance(q, fw_impl.configurable_weights_quantizer_cls)]
+ assert configurable_quantizers
+ for quantizer in configurable_quantizers:
+ quantizer.set_weights_bit_width_index(bitwidth_idx)
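The typing.TYPE_CHECKING guard added above lets the annotations name FrameworkImplementation without importing it at runtime, which avoids a circular import between this module and framework_implementation.py. The pattern in isolation, with heavy_module and Thing as placeholder names:

    import typing

    if typing.TYPE_CHECKING:             # evaluated by static type checkers only
        from heavy_module import Thing   # never executed at runtime, so no circular import

    def use(obj: 'Thing') -> None:       # string annotation resolves only for the type checker
        obj.run()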
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py
@@ -104,10 +104,11 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
  new_solution[node_idx_to_upgrade] = nodes_next_candidate[node_idx_to_upgrade]
  changed = True

- if any([mp_solution[n] != new_solution[n] for n in mp_solution]):
- Logger.info(f'Greedy MP algorithm changed configuration from (numbers represent indices of the '
- f'chosen bit-width candidate for each layer):\n{mp_solution}\nto\n{new_solution}')
-
+ changed_solutions = {n: (sol, new_solution[n]) for n, sol in mp_solution.items() if sol != new_solution[n]}
+ if changed_solutions:
+ msg = '\n'.join(f'{n.name}: {mp_solution[n]} -> {new_solution[n]}' for n in changed_solutions)
+ Logger.info(f'Greedy MP algorithm changed configuration for {len(changed_solutions)} out of {len(mp_solution)} '
+ f'layers (numbers represent indices of the chosen bit-width candidate for each layer):\n{msg}')
  return new_solution

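Illustrative only (string keys instead of BaseNode objects, hypothetical layer names): the reworked logging reports one 'name: old -> new' line per changed layer instead of dumping both full solution dicts:

    mp_solution = {'conv1': 0, 'conv2': 1}
    new_solution = {'conv1': 2, 'conv2': 1}
    changed = {n: (s, new_solution[n]) for n, s in mp_solution.items() if s != new_solution[n]}
    msg = '\n'.join(f'{n}: {mp_solution[n]} -> {new_solution[n]}' for n in changed)
    # msg == 'conv1: 0 -> 2'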
 
model_compression_toolkit/core/common/quantization/node_quantization_config.py
@@ -549,6 +549,13 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
  """
  return {attr: self.get_attr_config(attr) for attr in self.all_weight_attrs}

+ def disable_all_weights_quantization(self):
+ """ Disable quantization for all weights. """
+ for w_cfg in self.pos_attributes_config_mapping.values():
+ w_cfg.enable_weights_quantization = False
+ for w_cfg in self.attributes_config_mapping.values():
+ w_cfg.enable_weights_quantization = False
+
  def _extract_config_for_attributes_with_name(self, attr_name) -> Dict[str, WeightsAttrQuantizationConfig]:
  """
  Extract the saved attributes that contain the given attribute name.
model_compression_toolkit/core/common/similarity_analyzer.py
@@ -194,7 +194,7 @@ def compute_cs(float_tensor: np.ndarray,
  cs = np.sum(float_flat * fxp_flat, axis=axis) / ((float_norm * fxp_norm) + eps)

  # Return a non-negative float (smaller value -> more similarity)
- return (1.0 - cs) / 2.0
+ return np.maximum((1.0 - cs) / 2.0, 0)


  def compute_lp_norm(float_tensor: np.ndarray,
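The added clamp guards against floating-point round-off: for near-identical tensors the computed cosine similarity can land marginally above 1, making (1.0 - cs) / 2.0 a tiny negative number even though the docstring promises a non-negative result. A quick numpy illustration (values are arbitrary, chosen only to show the clamp):

    import numpy as np

    cs = np.float32(1.0) + np.float32(1e-7)   # stand-in for a cosine similarity that overshoots 1
    d = (1.0 - cs) / 2.0                      # slightly negative
    d_clamped = np.maximum(d, 0)              # 0.0, matching the documented contract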
model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py
@@ -14,17 +14,16 @@
  # ==============================================================================
  from typing import Tuple, Any, Dict, Union, List

- from packaging import version
  import tensorflow as tf
+ from packaging import version
+
  if version.parse(tf.__version__) >= version.parse("2.13"):
  from keras.src.engine.base_layer import Layer
  else:
  from keras.engine.base_layer import Layer # pragma: no cover

  from keras.models import Model
- from mct_quantizers import KerasQuantizationWrapper, KerasActivationQuantizationHolder, QuantizationTarget
- from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
- from mct_quantizers.keras.quantizers import BaseKerasInferableQuantizer
+ from mct_quantizers import KerasQuantizationWrapper, KerasActivationQuantizationHolder

  from model_compression_toolkit.core.common import BaseNode
  from model_compression_toolkit.core.common.user_info import UserInformation
@@ -34,9 +33,6 @@ from model_compression_toolkit.core.keras.mixed_precision.configurable_activatio
  from model_compression_toolkit.core.keras.mixed_precision.configurable_weights_quantizer import \
  ConfigurableWeightsQuantizer

- from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
- get_inferable_quantizer_kwargs
-
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
@@ -75,6 +71,7 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
  n: common.BaseNode,
  layer: Layer) -> Union[KerasQuantizationWrapper, Layer]:
  """
+
  A function which takes a computational graph node and a keras layer and perform the quantization
  wrapping for mixed precision.

@@ -82,40 +79,21 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):
  n: A node of mct graph.
  layer: A keras layer

- Returns: Wrapped layer with a configurable quantizer if the layer should quantized in mixed precision,
- otherwise returns either the layer wrapped with a fixed precision inferable quantizer or the layer as is if it's
- not supposed to be quantized.
+ Returns:
+ Wrapped layer with a configurable quantizer if the layer should be quantized in mixed precision, or the
+ layer as is.

+ Raises:
+ ValueError: if kernel attribute is quantized but not configurable.
  """

  kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
- if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr):
- weights_conf_nodes_names = [node.name for node in self.graph.get_weights_configurable_nodes(self.fw_info)]
- if n.name in weights_conf_nodes_names:
- wq = ConfigurableWeightsQuantizer(**self._get_weights_configurable_quantizer_kwargs(n, kernel_attr))
- return KerasQuantizationWrapper(layer, weights_quantizers={kernel_attr: wq})
- else:
- # TODO: Do we want to include other quantized attributes that are not
- # the kernel attribute in the mixed precision model?
- # Currently, we only consider kernel attribute quantization (whether it is in mixed precision
- # or single precision).
- node_weights_qc = n.get_unique_weights_candidates(kernel_attr)
- if not len(node_weights_qc) == 1:
- Logger.critical(f"Expected a unique weights configuration for node {n.name}, but found {len(node_weights_qc)} configurations.")# pragma: no cover
-
- weights_quant_cfg = node_weights_qc[0].weights_quantization_cfg
- weights_quant_method = weights_quant_cfg.get_attr_config(kernel_attr).weights_quantization_method
- quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights,
- weights_quant_method,
- BaseKerasInferableQuantizer)
- kwargs = get_inferable_quantizer_kwargs(weights_quant_cfg,
- QuantizationTarget.Weights,
- kernel_attr)
-
- return KerasQuantizationWrapper(layer,
- weights_quantizers={kernel_attr: quantier_for_node(**kwargs)})
-
- return layer
+ if kernel_attr is None or not n.is_weights_quantization_enabled(kernel_attr):
+ return layer
+ if not n.is_configurable_weight(kernel_attr): # pragma: no cover
+ raise ValueError(f'Weight wrapper is not expected to be created for non-configurable weight of node {n}.')
+ wq = ConfigurableWeightsQuantizer(**self._get_weights_configurable_quantizer_kwargs(n, kernel_attr))
+ return KerasQuantizationWrapper(layer, weights_quantizers={kernel_attr: wq})

  def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) -> Dict[str, Any]:
  """
@@ -147,50 +125,36 @@ class MixedPrecisionKerasModelBuilder(KerasModelBuilder):

  def mixed_precision_activation_holder(self, n: BaseNode) -> KerasActivationQuantizationHolder:
  """
- Retrieve a KerasActivationQuantizationHolder layer to use for activation quantization for a node.
- The layer should hold either a configurable activation quantizer, if it is quantized with mixed precision,
- or an inferable quantizer for fixed single bit-width quantization.
+ Builds KerasActivationQuantizationHolder layer with a configurable quantizer for mixed precision for a node
+ with a configurable activation.

  Args:
  n: Node to get KerasActivationQuantizationHolder to attach in its output.

  Returns:
  A KerasActivationQuantizationHolder layer for the node activation quantization.
+
+ Raises:
+ ValueError: if node's activation is not configurable.
  """
+ if not n.has_configurable_activation(): # pragma: no cover
+ raise ValueError(f'Activation holder is not expected to be created for a non-configurable activation of '
+ f'node {n}')
+ num_of_outputs = len(n.output_shape) if isinstance(n.output_shape, list) else 1
+ node_q_cfg_candidates = n.candidates_quantization_cfg
+
+ # sorting the candidates by kernel attribute weights number of bits first and then by
+ # activation number of bits (in reversed order).
+ # since only kernel attribute is quantized in weights mixed precision,
+ # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
+ n.sort_node_candidates(self.fw_info)

- activation_conf_nodes_names = [n.name for n in self.graph.get_activation_configurable_nodes()]
-
- activation_quantizers = []
- if n.is_activation_quantization_enabled():
- num_of_outputs = len(n.output_shape) if isinstance(n.output_shape, list) else 1
-
- if n.name in activation_conf_nodes_names:
- assert n.candidates_quantization_cfg is not None, f"Node {n.name} candidates_quantization_cfg is None"
- node_q_cfg_candidates = n.candidates_quantization_cfg
-
- # sorting the candidates by kernel attribute weights number of bits first and then by
- # activation number of bits (in reversed order).
- # since only kernel attribute is quantized in weights mixed precision,
- # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
- n.sort_node_candidates(self.fw_info)
-
- max_candidate_idx = n.find_max_candidate_index()
- kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
- activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
- 'max_candidate_idx': max_candidate_idx,
- 'kernel_attr': kernel_attr})] \
- * num_of_outputs
- else:
- node_act_qc = n.get_unique_activation_candidates()
- assert len(node_act_qc) == 1, f"Expecting node {n.name} to have a unique activation configuration, " \
- f"but {len(node_act_qc)} different configurations exist."
- quantizer_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
- node_act_qc[0].activation_quantization_cfg.activation_quantization_method,
- BaseKerasInferableQuantizer)
- kwargs = get_inferable_quantizer_kwargs(node_act_qc[0].activation_quantization_cfg,
- QuantizationTarget.Activation)
-
- activation_quantizers = [quantizer_for_node(**kwargs)] * num_of_outputs
+ max_candidate_idx = n.find_max_candidate_index()
+ kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
+ activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
+ 'max_candidate_idx': max_candidate_idx,
+ 'kernel_attr': kernel_attr})] \
+ * num_of_outputs

  # Holder by definition uses a single quantizer for the activation quantization
  # thus we make sure this is the only possible case (unless it's a node with no activation
model_compression_toolkit/core/keras/keras_implementation.py
@@ -13,7 +13,7 @@
  # limitations under the License.
  # ==============================================================================
  from functools import partial
- from typing import List, Any, Tuple, Callable, Dict, Union, Generator
+ from typing import List, Any, Tuple, Callable, Union, Generator

  import numpy as np
  import tensorflow as tf
@@ -22,7 +22,7 @@ from tensorflow.keras.models import Model

  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
- from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, HessianInfoService
+ from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode
  from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.remove_identity import RemoveIdentity
  from model_compression_toolkit.core.keras.hessian.activation_hessian_scores_calculator_keras import \
@@ -35,8 +35,6 @@ from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_
  from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
  get_weights_quantizer_for_node, get_activations_quantizer_for_node
  from model_compression_toolkit.logger import Logger
- from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
- from model_compression_toolkit.core.common.mixed_precision.set_layer_to_bitwidth import set_layer_to_bitwidth
  from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence, compute_cs, compute_mse
  from model_compression_toolkit.core.keras.constants import ACTIVATION, SOFTMAX, SIGMOID, ARGMAX, LAYER_NAME, \
  COMBINED_NMS
@@ -61,7 +59,7 @@
  from keras.layers import Dense, Activation, Conv2D, DepthwiseConv2D, Conv2DTranspose, Concatenate, Add # pragma: no cover
  from keras.layers.core import TFOpLambda # pragma: no cover

- from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfig
+ from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import Graph, BaseNode
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -95,7 +93,7 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.mult
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.scale_equalization import \
  ScaleEqualization, ScaleEqualizationWithPad, ScaleEqualizationMidActivation, ScaleEqualizationMidActivationWithPad
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.separableconv_decomposition import \
- SeparableConvDecomposition, DEPTH_MULTIPLIER
+ SeparableConvDecomposition
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.shift_negative_activation import \
  keras_apply_shift_negative_correction
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.dwconv_to_conv import DwconvToConv
@@ -110,9 +108,10 @@ class KerasImplementation(FrameworkImplementation):
  """
  A class with implemented methods to support optimizing Keras models.
  """
-
- def __init__(self):
- super().__init__()
+ weights_quant_layer_cls = KerasQuantizationWrapper
+ activation_quant_layer_cls = KerasActivationQuantizationHolder
+ configurable_weights_quantizer_cls = ConfigurableWeightsQuantizer
+ configurable_activation_quantizer_cls = ConfigurableActivationQuantizer

  @property
  def constants(self):
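These new class attributes (mirrored for PyTorch in pytorch_implementation.py per the file list above) are what sensitivity_evaluation.py and set_layer_to_bitwidth.py now read via fw_impl, replacing the per-framework get_sensitivity_evaluator override and its pre-bound set_layer_to_bitwidth partial that the next hunk removes. A hedged sketch of the declarations a shared FrameworkImplementation base could expose for these lookups (the base-class change itself is not shown in this section):

    class FrameworkImplementation:
        # Framework-specific classes, filled in by each subclass (e.g. the Keras values above).
        weights_quant_layer_cls: type
        activation_quant_layer_cls: type
        configurable_weights_quantizer_cls: type
        configurable_activation_quantizer_cls: type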
@@ -401,42 +400,6 @@ class KerasImplementation(FrameworkImplementation):
  substitutions_list.append(keras_batchnorm_refusing())
  return substitutions_list

- def get_sensitivity_evaluator(self,
- graph: Graph,
- quant_config: MixedPrecisionQuantizationConfig,
- representative_data_gen: Callable,
- fw_info: FrameworkInfo,
- disable_activation_for_metric: bool = False,
- hessian_info_service: HessianInfoService = None) -> SensitivityEvaluation:
- """
- Creates and returns an object which handles the computation of a sensitivity metric for a mixed-precision
- configuration (comparing to the float model).
-
- Args:
- graph: Graph to build its float and mixed-precision models.
- quant_config: QuantizationConfig of how the model should be quantized.
- representative_data_gen: Dataset to use for retrieving images for the models inputs.
- fw_info: FrameworkInfo object with information about the specific framework's model.
- disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
- hessian_info_service: HessianScoresService to fetch scores based on a Hessian-approximation for the float model.
-
- Returns:
- A SensitivityEvaluation object.
- """
-
- return SensitivityEvaluation(graph=graph,
- quant_config=quant_config,
- representative_data_gen=representative_data_gen,
- fw_info=fw_info,
- fw_impl=self,
- set_layer_to_bitwidth=partial(set_layer_to_bitwidth,
- weights_quantizer_type=ConfigurableWeightsQuantizer,
- activation_quantizer_type=ConfigurableActivationQuantizer,
- weights_quant_layer_type=KerasQuantizationWrapper,
- activation_quant_layer_type=KerasActivationQuantizationHolder),
- disable_activation_for_metric=disable_activation_for_metric,
- hessian_info_service=hessian_info_service)
-
  def get_node_prior_info(self,
  node: BaseNode,
  fw_info: FrameworkInfo,