mct-nightly 2.3.0.20250224.520__py3-none-any.whl → 2.3.0.20250225.512__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. {mct_nightly-2.3.0.20250224.520.dist-info → mct_nightly-2.3.0.20250225.512.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.3.0.20250224.520.dist-info → mct_nightly-2.3.0.20250225.512.dist-info}/RECORD +23 -22
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/collectors/histogram_collector.py +19 -20
  5. model_compression_toolkit/core/common/collectors/statistics_collector.py +7 -3
  6. model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py +114 -0
  7. model_compression_toolkit/core/common/framework_implementation.py +9 -4
  8. model_compression_toolkit/core/common/graph/base_node.py +16 -6
  9. model_compression_toolkit/core/common/hessian/hessian_info_service.py +31 -15
  10. model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py +1 -1
  11. model_compression_toolkit/core/common/hessian/hessian_scores_request.py +7 -2
  12. model_compression_toolkit/core/common/model_collector.py +115 -17
  13. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +2 -0
  14. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +110 -33
  15. model_compression_toolkit/core/keras/keras_implementation.py +35 -27
  16. model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +23 -61
  17. model_compression_toolkit/core/pytorch/pytorch_implementation.py +34 -18
  18. model_compression_toolkit/core/quantization_prep_runner.py +1 -0
  19. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py +2 -2
  20. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py +2 -1
  21. {mct_nightly-2.3.0.20250224.520.dist-info → mct_nightly-2.3.0.20250225.512.dist-info}/LICENSE.md +0 -0
  22. {mct_nightly-2.3.0.20250224.520.dist-info → mct_nightly-2.3.0.20250225.512.dist-info}/WHEEL +0 -0
  23. {mct_nightly-2.3.0.20250224.520.dist-info → mct_nightly-2.3.0.20250225.512.dist-info}/top_level.txt +0 -0
model_compression_toolkit/core/common/model_collector.py

@@ -15,13 +15,15 @@
 
 
  import numpy as np
- from typing import List
+ from typing import List, Union, Tuple, Optional
 
  from networkx.algorithms.dag import topological_sort
- from model_compression_toolkit.core import FrameworkInfo
+ from model_compression_toolkit.core import FrameworkInfo, QuantizationErrorMethod
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
  from model_compression_toolkit.core.common.graph.base_graph import Graph
+ from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresGranularity, HessianMode, \
+     HessianScoresRequest
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
@@ -66,12 +68,67 @@ def create_tensor2node(graph: common.Graph,
 
      """
      current_sc = graph.get_out_stats_collector(node)
-     is_list_nostat_collectors = isinstance(current_sc, list) and len([sc for sc in current_sc if not isinstance(sc, common.NoStatsCollector)]) == 0
+     is_list_nostat_collectors = isinstance(current_sc, list) and len(
+         [sc for sc in current_sc if not isinstance(sc, common.NoStatsCollector)]) == 0
      if isinstance(current_sc, common.NoStatsCollector) or current_sc is None or is_list_nostat_collectors:
          stats_collector = common.StatsCollector(fw_info.out_channel_axis_mapping.get(node.type))
          graph.set_out_stats_collector_to_node(node, stats_collector)
 
 
+ def ensure_matching_data_lengths(
+         stats_collector: Union[List[BaseStatsCollector], Tuple[BaseStatsCollector, ...]],
+         tensor_data: Union[List, Tuple],
+         hessian_data: Union[List, Tuple]
+ ):
+     """
+     Ensures that the lengths of `tensor_data`, `hessian_data`, and `stats_collector` match.
+     If the types or lengths do not match, a critical error is logged.
+
+     Args:
+         stats_collector: A list or tuple of statistics collectors.
+         tensor_data: A list or tuple of tensors corresponding to the statistics collectors.
+         hessian_data: A list or tuple of Hessian tensors corresponding to the statistics collectors.
+
+     Raises:
+         Logs a critical error and halts execution if there is a type mismatch or
+         if the lengths of the inputs do not match.
+     """
+
+     if not isinstance(tensor_data, (list, tuple)):
+         Logger.critical(
+             f"'tensor_data' is of type {type(tensor_data)}, but must be of the same type as 'stats_collector' ({type(stats_collector)})."
+         )  # pragma: no cover
+
+     if len(stats_collector) != len(tensor_data):
+         Logger.critical(
+             "'tensor_data' and 'stats_collector' must have matching lengths."
+         )  # pragma: no cover
+
+     if not isinstance(hessian_data, (list, tuple)):
+         Logger.critical(
+             f"'hessian_data' is of type {type(hessian_data)}, but must be of the same type as 'stats_collector' ({type(stats_collector)})."
+         )  # pragma: no cover
+
+     if len(stats_collector) != len(hessian_data):
+         Logger.critical(
+             "'hessian_data' and 'stats_collector' must have matching lengths."
+         )  # pragma: no cover
+
+
+ def convert_to_numpy_and_abs(tensor: Optional[np.ndarray], fw_impl: FrameworkImplementation) -> Optional[np.ndarray]:
+     """
+     Converts a tensor to a NumPy array and applies the absolute value operation if the tensor is not None.
+
+     Args:
+         tensor: Input tensor to be converted to a NumPy array.
+         fw_impl: Framework implementation that provides the 'to_numpy' method for tensor conversion.
+
+     Returns:
+         A NumPy array of the input tensor with absolute values applied. If the input tensor is None, returns None.
+     """
+     return tensor if tensor is None else np.abs(fw_impl.to_numpy(tensor))
+
+
  class ModelCollector:
      """
      Build a model from a graph for statistics collection purposes.
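A minimal behavior sketch for `convert_to_numpy_and_abs`, assuming a toy stand-in whose `to_numpy` is a NumPy passthrough (hypothetical; the real `FrameworkImplementation` is framework-specific). The `None` passthrough matters later: model outputs get `None` instead of Hessian weights, and the collectors must accept that.

    import numpy as np

    class NumpyPassthrough:
        # Hypothetical stand-in for FrameworkImplementation; tensors are already NumPy arrays.
        def to_numpy(self, t):
            return t

    fw = NumpyPassthrough()
    print(convert_to_numpy_and_abs(None, fw))                   # None: placeholders pass through untouched
    print(convert_to_numpy_and_abs(np.array([-1.5, 2.0]), fw))  # [1.5 2. ]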
@@ -83,6 +140,7 @@ class ModelCollector:
      def __init__(self, graph: Graph,
                   fw_impl: FrameworkImplementation,
                   fw_info: FrameworkInfo,
+                  hessian_info_service: HessianInfoService = None,
                   qc: common.QuantizationConfig = common.DEFAULTCONFIG):
          """
          Build a model from a graph per framework for statistics collection.
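Since `hessian_info_service` defaults to `None`, existing call sites keep working. A sketch of a call that enables Hessian-weighted statistics collection (names are illustrative, not from the diff):

    mc = ModelCollector(graph, fw_impl, fw_info,
                        hessian_info_service=hessian_service,  # only needed when qc uses HMSE
                        qc=qc)
    mc.infer(representative_batch)  # List[np.ndarray], one array per model input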
@@ -96,14 +154,18 @@ class ModelCollector:
 
          self.fw_impl = fw_impl
          self.fw_info = fw_info
+         self.hessian_service = hessian_info_service
+         self.qc = qc
+         self.model_outputs = [out.node for out in graph.get_outputs()]
 
-         # Assign statisitcs collectors to nodes
+         # Assign statistics collectors to nodes
          for n in graph.get_topo_sorted_nodes():
              sc = create_stats_collector_for_node(n, fw_info=fw_info)  # Get statistics collector for the node
              # If we use bias correction, and the node has kernel weights to quantize, we need to make sure
              # its previous nodes' tensors are consistent with this node.
              kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
-             if qc.weights_bias_correction and kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr):
+             if qc.weights_bias_correction and kernel_attr is not None and n.is_weights_quantization_enabled(
+                     kernel_attr):
                  for ie in graph.incoming_edges(n):
                      input_node = ie.source_node
                      create_tensor2node(graph,
@@ -112,7 +174,6 @@ class ModelCollector:
              if sc is not None:
                  graph.set_out_stats_collector_to_node(n, sc)
 
-
          outputs_nodes = []  # List of graph nodes whose outputs the model should output.
          self.stats_containers_list = []  # List of output statistics containers of nodes, ordered
          # the same as outputs_nodes, so statistics of outputs can be gathered for the correct statistics container.
@@ -135,11 +196,19 @@ class ModelCollector:
                  outputs_nodes.append(n)
                  self.stats_containers_list.append(out_stats_container)
 
+         self.intermediate_output_tensors = [n for n in outputs_nodes if n not in self.model_outputs]
+
+         # Append nodes from graph.get_outputs() that are not already in outputs_nodes for Hessian
+         # calculation for output nodes that don't collect statistics, such as "permute", "transpose", etc.
+         # TODO: Add integration test for this case
+         append2output = outputs_nodes + [n for n in self.model_outputs if n not in outputs_nodes]
+
+
          # Build a float model and output all layers' outputs
          # (that should be collected) as the model's outputs
          self.model, _ = self.fw_impl.model_builder(graph,
                                                     mode=ModelBuilderMode.FLOAT,
-                                                    append2output=outputs_nodes,
+                                                    append2output=append2output,
                                                     fw_info=self.fw_info)
 
      def infer(self, inputs_list: List[np.ndarray]):
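The `append2output` expression above keeps ordering stable: statistics-collecting nodes first, then any true model outputs not already present. A toy illustration:

    outputs_nodes = ["a", "b"]   # nodes that collect statistics
    model_outputs = ["b", "c"]   # true graph outputs
    append2output = outputs_nodes + [n for n in model_outputs if n not in outputs_nodes]
    print(append2output)         # ['a', 'b', 'c']: order preserved, duplicates dropped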
@@ -154,14 +223,43 @@ class ModelCollector:
 
          # TODO: Thinking about delegating collections to framework
          # TODO: migrate datasets to framework datasets
-         tensor_data = self.fw_impl.run_model_inference(self.model, inputs_list)
-         for td, sc in zip(tensor_data, self.stats_containers_list):
-             if isinstance(sc, (list, tuple)):
-                 if not isinstance(td, (list, tuple)):
-                     Logger.critical(f"\'tensor_data\' is of type {type(td)} but must be of the same type as \'stats_containers_list\', which is of type {type(sc)}")  # pragma: no cover
-                 if len(sc) != len(td):
-                     Logger.critical('\'tensor_data\' and \'stats_containers_list\' must have matching lengths')  # pragma: no cover
-                 for tdi, sci in zip(td, sc):
-                     sci.update_statistics(self.fw_impl.to_numpy(tdi))
+         compute_hessians = self.qc.activation_error_method == QuantizationErrorMethod.HMSE
+
+         # Retrieve intermediate layer activations for statistical analysis.
+         # Enable gradient computation if Hessian calculations are required.
+         activation_tensors = self.fw_impl.run_model_inference(self.model, inputs_list, requires_grad=compute_hessians)
+
+         if compute_hessians:
+             if self.hessian_service is None:
+                 Logger.critical(
+                     "Hessian computation is enabled but `hessian_service` is not initialized. "
+                     "Ensure that `hessian_service` is properly set."
+                 )  # pragma: no cover
+             request = HessianScoresRequest(
+                 mode=HessianMode.ACTIVATION,
+                 granularity=HessianScoresGranularity.PER_ELEMENT,
+                 target_nodes=self.intermediate_output_tensors,
+                 data_loader=None,
+                 n_samples=None,
+                 compute_from_tensors=True
+             )
+             hessian_tensors = self.hessian_service.fetch_hessian(request=request,
+                                                                  activation_tensors=activation_tensors)
+             hessian_tensors = list(hessian_tensors.values())
+         else:
+             hessian_tensors = []
+
+         # Hessian is not calculated for the model outputs; add "None" as weights for output tensors.
+         hessian_tensors += [None for _ in range(len(activation_tensors) - len(hessian_tensors))]
+
+         for activation_tensor, hessian_tensor, stats_container in zip(activation_tensors, hessian_tensors, self.stats_containers_list):
+             if isinstance(stats_container, (list, tuple)):
+                 if hessian_tensor is None:
+                     hessian_tensor = [None for _ in range(len(activation_tensor))]
+                 ensure_matching_data_lengths(stats_container, activation_tensor, hessian_tensor)
+                 for activation_tensor_i, hessian_tensor_i, sci in zip(activation_tensor, hessian_tensor, stats_container):
+                     sci.update_statistics(self.fw_impl.to_numpy(activation_tensor_i),
+                                           convert_to_numpy_and_abs(hessian_tensor_i, self.fw_impl))
              else:
-             sc.update_statistics(self.fw_impl.to_numpy(td))
+                 stats_container.update_statistics(self.fw_impl.to_numpy(activation_tensor),
+                                                   convert_to_numpy_and_abs(hessian_tensor, self.fw_impl))
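The per-element Hessian scores fetched here become the weights consumed by the new `weighted_histogram_collector.py` (file 6 above). Conceptually this is `np.histogram` with per-element weights; a minimal sketch, assuming flattened activations and matching |Hessian| scores:

    import numpy as np

    activations = np.random.randn(1000)
    hessian_scores = np.abs(np.random.randn(1000))  # stand-in for per-element Hessian approximations

    # Plain histogram: every element contributes a count of 1.
    counts, bin_edges = np.histogram(activations, bins=64)

    # Weighted histogram: every element contributes its |Hessian| score, so the threshold
    # search penalizes quantization error most where the loss is most sensitive.
    weighted_counts, _ = np.histogram(activations, bins=64, weights=hessian_scores)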
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py

@@ -514,6 +514,8 @@ def get_threshold_selection_histogram_error_function(quantization_method: Quanti
      quant_method_error_function_mapping = {
          qc.QuantizationErrorMethod.MSE: lambda q_bins, q_count, bins, counts, threshold, _range:
              _mse_error_histogram(q_bins, q_count, bins, counts),
+         qc.QuantizationErrorMethod.HMSE: lambda q_bins, q_count, bins, counts, threshold, _range:
+             _mse_error_histogram(q_bins, q_count, bins, counts),  # HMSE needs the same functionality as MSE
          qc.QuantizationErrorMethod.MAE: lambda q_bins, q_count, bins, counts, threshold, _range:
              _mae_error_histogram(q_bins, q_count, bins, counts),
          qc.QuantizationErrorMethod.LP: lambda q_bins, q_count, bins, counts, threshold, _range:
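Because the sensitivity weighting is already baked into the weighted histogram's counts, HMSE can map to the same error function as MSE. A sketch of the idea behind `_mse_error_histogram` (not the library's exact implementation): the error is the count-weighted mean of squared displacement between original and quantized bin values.

    import numpy as np

    def mse_error_histogram_sketch(q_bins, bins, counts):
        # Count-weighted mean squared displacement of bin values under quantization.
        # With HMSE, the counts already carry the |Hessian| weights.
        sq_err = (np.asarray(q_bins[:-1]) - np.asarray(bins[:-1])) ** 2
        return np.sum(sq_err * counts) / np.sum(counts)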
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py

@@ -13,66 +13,143 @@
  # limitations under the License.
  # ==============================================================================
  import numpy as np
- from typing import Dict, Union
+ from typing import Dict, Union, Optional, Tuple
 
  from mct_quantizers import QuantizationMethod
+ from model_compression_toolkit.core import QuantizationErrorMethod
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
  from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
  from model_compression_toolkit.core.common.quantization import quantization_params_generation
  from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
  from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
 
-
- def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
-                             nodes_prior_info: NodePriorInfo,
-                             out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
+ def get_histogram_data(
+         activation_quant_cfg: NodeActivationQuantizationConfig,
+         out_stats_container: BaseStatsCollector
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
      """
-     Compute the activations params for a given node in a graph according to a params function.
+     Extract and filter the histogram data from the statistics container.
 
      Args:
-         activation_quant_cfg: node's activation quantization configuration.
-         nodes_prior_info: Prior info collected for the node that is being quantized.
-         out_stats_container: Tensor containing output statistics of the node.
+         activation_quant_cfg: Node's activation quantization configuration.
+         out_stats_container: Statistics container with histogram data.
 
      Returns:
-         The computed activation quantization params.
+         A tuple containing the filtered bins_values and bins_counts.
      """
-
      bins_values, bins_counts = None, None
 
      # If the statistics container collected the histogram, we start by filtering outliers using z threshold
      # filtering, and then computing the threshold based on the filtered histogram.
      if out_stats_container.require_collection():
-         bins_values, bins_counts = out_stats_container.hc.get_histogram()
-         bins_counts = quantization_params_generation.z_score_filter(activation_quant_cfg.z_threshold,
-                                                                     bins_values,
-                                                                     bins_counts)
-     min_value, max_value = out_stats_container.get_min_max_values()
+         if activation_quant_cfg.activation_error_method == QuantizationErrorMethod.HMSE:
+             bins_values, bins_counts = out_stats_container.weighted_hc.get_histogram()
+         else:
+             bins_values, bins_counts = out_stats_container.hc.get_histogram()
+         bins_counts = quantization_params_generation.z_score_filter(
+             activation_quant_cfg.z_threshold,
+             bins_values,
+             bins_counts
+         )
+     return bins_values, bins_counts
+
+ def determine_signedness(
+         activation_quant_cfg: NodeActivationQuantizationConfig,
+         nodes_prior_info: NodePriorInfo,
+         min_value: float,
+         bins_values: Optional[np.ndarray],
+         bins_counts: Optional[np.ndarray]
+ ) -> bool:
+     """
+     Determine if the activations should be considered signed based on the quantization configuration,
+     node prior information, and histogram statistics.
+
+     Args:
+         activation_quant_cfg: Node's activation quantization configuration.
+         nodes_prior_info: Prior info collected for the node that is being quantized.
+         min_value: Minimum value from the statistics container.
+         bins_values: Numpy array of histogram bin values.
+         bins_counts: Numpy array of histogram bin counts.
 
+     Returns:
+         A boolean indicating if the activations are signed.
+     """
      if activation_quant_cfg.signedness in [Signedness.SIGNED, Signedness.UNSIGNED]:
-         signed = activation_quant_cfg.signedness == Signedness.SIGNED
-     elif nodes_prior_info.is_output_bounded():
-         signed = min_value < 0
-     else:
-         signed = np.any(bins_values[:-1][bins_counts > 0] < 0)
+         return activation_quant_cfg.signedness == Signedness.SIGNED
 
+     if nodes_prior_info.is_output_bounded():
+         return min_value < 0
+
+     return np.any(bins_values[:-1][bins_counts > 0] < 0)
+
+
+ def update_activation_quantization_params_fn(
+         activation_quant_cfg: NodeActivationQuantizationConfig,
+         nodes_prior_info: NodePriorInfo):
+     """
+     Update the activation quantization parameters function based on the quantization method
+     and whether the node's output is bounded.
+
+     Args:
+         activation_quant_cfg: Node's activation quantization configuration.
+         nodes_prior_info: Prior info collected for the node that is being quantized.
+     """
      if nodes_prior_info.is_output_bounded():
          if activation_quant_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
-             activation_quant_cfg.activation_quantization_params_fn = \
+             activation_quant_cfg.set_activation_quantization_params_fn(
                  quantization_params_generation.power_of_two_no_clipping_selection_min_max
+             )
          elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.SYMMETRIC:
-             activation_quant_cfg.activation_quantization_params_fn = \
+             activation_quant_cfg.set_activation_quantization_params_fn(
                  quantization_params_generation.symmetric_no_clipping_selection_min_max
+             )
          elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.UNIFORM:
-             activation_quant_cfg.activation_quantization_params_fn = \
+             activation_quant_cfg.set_activation_quantization_params_fn(
                  quantization_params_generation.uniform_no_clipping_selection_min_max
+             )
+
+
+ def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
+                             nodes_prior_info: NodePriorInfo,
+                             out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
+     """
+     Compute the activations params for a given node in a graph according to a params function.
+
+     Args:
+         activation_quant_cfg: node's activation quantization configuration.
+         nodes_prior_info: Prior info collected for the node that is being quantized.
+         out_stats_container: Tensor containing output statistics of the node.
+
+     Returns:
+         The computed activation quantization params.
+     """
+     # Update quantization parameters function based on output bounds and quantization method.
+     update_activation_quantization_params_fn(activation_quant_cfg, nodes_prior_info)
+
+     # Extract and filter histogram data from the statistics container.
+     bins_values, bins_counts = get_histogram_data(activation_quant_cfg, out_stats_container)
+
+     # Retrieve the minimum and maximum values from the statistics container.
+     min_value, max_value = out_stats_container.get_min_max_values()
+
+     # Determine if the activations should be considered signed.
+     signed = determine_signedness(
+         activation_quant_cfg,
+         nodes_prior_info,
+         min_value,
+         bins_values,
+         bins_counts
+     )
 
-     return activation_quant_cfg.activation_quantization_params_fn(bins_values,
-                                                                   bins_counts,
-                                                                   activation_quant_cfg.l_p_value,
-                                                                   activation_quant_cfg.activation_n_bits,
-                                                                   min_value,
-                                                                   max_value,
-                                                                   min_threshold=activation_quant_cfg.min_threshold,
-                                                                   quant_error_method=activation_quant_cfg.activation_error_method,
-                                                                   is_signed=signed)
+     # Compute and return the activation quantization parameters.
+     return activation_quant_cfg.activation_quantization_params_fn(
+         bins_values,
+         bins_counts,
+         activation_quant_cfg.l_p_value,
+         activation_quant_cfg.activation_n_bits,
+         min_value,
+         max_value,
+         min_threshold=activation_quant_cfg.min_threshold,
+         quant_error_method=activation_quant_cfg.activation_error_method,
+         is_signed=signed
+     )
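A quick check of the unbounded-output branch of `determine_signedness`: activations are treated as signed only if some non-empty bin starts below zero.

    import numpy as np

    bins_values = np.array([-1.0, 0.0, 1.0, 2.0])  # bin edges
    bins_counts = np.array([0, 5, 3])              # nothing observed in the negative bin

    signed = bool(np.any(bins_values[:-1][bins_counts > 0] < 0))
    print(signed)  # False: the only edge below zero belongs to an empty bin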
model_compression_toolkit/core/keras/keras_implementation.py

@@ -188,18 +188,30 @@ class KerasImplementation(FrameworkImplementation):
 
      def run_model_inference(self,
                              model: Any,
-                             input_list: List[Any]) -> Tuple[tf.Tensor]:
+                             input_list: List[Any],
+                             requires_grad: bool = False) -> Tuple[tf.Tensor]:
          """
-         Run the model logic on the given the inputs.
+         Runs inference on the given Keras model with the provided inputs.
+
+         This method executes the model on the given input data. If `requires_grad` is False,
+         the model is run without a gradient tape, so no operations are recorded for differentiation.
 
          Args:
-             model: Keras model.
-             input_list: List of inputs for the model.
+             model: The Keras model to execute.
+             input_list: A list of inputs for the model.
+             requires_grad: If False, prevents gradient computation (default: False).
 
          Returns:
-             The Keras model's output.
-         """
-         return model(input_list)
+             A tuple containing the model's output tensors.
+         """
+         if requires_grad:
+             # Record operations on the watched inputs for automatic differentiation
+             with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
+                 g.watch(input_list)
+                 return model(input_list)
+         else:
+             return model(input_list)
 
      def shift_negative_correction(self,
                                    graph: Graph,
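For reference, a minimal sketch of the taping pattern used in the `requires_grad` branch (toy model, not the MCT graph-built one). With `watch_accessed_variables=False` the tape records nothing unless tensors are watched explicitly:

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
    x = tf.random.normal((4, 3))

    with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
        tape.watch(x)   # watch the input tensor, not the trainable variables
        y = model(x)

    dy_dx = tape.gradient(tf.reduce_sum(y), x)  # defined because x was watched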
@@ -553,28 +565,24 @@ class KerasImplementation(FrameworkImplementation):
 
          Returns: The MAC count of the operation
          """
-
-         output_shape = node.output_shape
-         kernel_shape = node.get_weights_by_keys(fw_info.get_kernel_op_attributes(node.type)[0]).shape
-         output_channel_axis, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
-
-         if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose):
-             # (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel)
-             return np.prod([x for x in output_shape if x is not None]) * \
-                    kernel_shape[input_channel_axis] * \
-                    (kernel_shape[0] * kernel_shape[1])
-         elif node.is_match_type(DepthwiseConv2D):
-             # Depth * (W_out * H_out) * C_in * (W_kernel * H_kernel)
-             return node.framework_attr.get(DEPTH_MULTIPLIER) * \
-                    np.prod([x for x in output_shape if x is not None]) / output_shape[output_channel_axis] * \
-                    kernel_shape[input_channel_axis] * \
-                    (kernel_shape[0] * kernel_shape[1])
-         elif node.is_match_type(Dense):
-             # IN * OUT
-             return kernel_shape[0] * kernel_shape[1]
-         else:
+         kernels = fw_info.get_kernel_op_attributes(node.type)
+         if not kernels or kernels[0] is None:
              return 0
 
+         assert len(kernels) == 1
+         kernel_shape = node.get_weights_by_keys(kernels[0]).shape
+
+         if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose) or node.is_match_type(DepthwiseConv2D):
+             h, w = node.get_output_shapes_list()[0][-3:-1]
+             return np.prod(kernel_shape) * h * w
+
+         if node.is_match_type(Dense):
+             # IN * OUT * (all previous dims[:-1])
+             _, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
+             return node.get_total_output_params() * kernel_shape[input_channel_axis]
+
+         return 0
+
 
      def apply_second_moment_correction(self,
                                         quantized_model: Any,
model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py

@@ -72,19 +72,21 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
              input_tensor.requires_grad_()
              input_tensor.retain_grad()
 
-         outputs = model(*self.input_images)
+         model_output_tensors = model(*self.input_images)
 
-         if len(outputs) != len(grad_model_outputs):  # pragma: no cover
+         if len(model_output_tensors) != len(grad_model_outputs):  # pragma: no cover
              Logger.critical(f"Mismatch in expected and actual model outputs for activation Hessian approximation. "
-                             f"Expected {len(grad_model_outputs)} outputs, received {len(outputs)}.")
+                             f"Expected {len(grad_model_outputs)} outputs, received {len(model_output_tensors)}.")
+         return model_output_tensors
 
+     def _prep_tensors_for_compute(self, model_output_tensors):
          # Extracting the intermediate activation tensors and the model real output.
          # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap.
          num_target_nodes = len(self.hessian_request.target_nodes)
          # Extract activation tensors of nodes for which we want to compute Hessian
-         target_activation_tensors = outputs[:num_target_nodes]
+         target_activation_tensors = model_output_tensors[:num_target_nodes]
          # Extract the model outputs
-         output_tensors = outputs[num_target_nodes:]
+         output_tensors = model_output_tensors[num_target_nodes:]
 
          # Concat outputs
          # First, we need to unfold all outputs that are given as list, to extract the actual output tensors
@@ -98,79 +100,39 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
          Returns:
              List[np.ndarray]: Scores based on the approximated Hessian for the requested nodes.
          """
-         output, target_activation_tensors = self.forward_pass()
-
-         if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
-             hessian_scores = self._compute_per_tensor(output, target_activation_tensors)
-         elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
-             hessian_scores = self._compute_per_channel(output, target_activation_tensors)
+         if self.hessian_request.compute_from_tensors:
+             model_output_tensors = self.input_images
          else:
-             raise NotImplementedError(f'{self.hessian_request.granularity} is not supported')  # pragma: no cover
-
-         # Convert results to list of numpy arrays
-         hessian_results = [torch_tensor_to_numpy(h) for h in hessian_scores]
-         return hessian_results
+             model_output_tensors = self.forward_pass()
+         output, target_activation_tensors = self._prep_tensors_for_compute(model_output_tensors)
 
-     def _compute_per_tensor(self, output, target_activation_tensors):
-         assert self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR
-         ipts_hessian_approx_scores = [torch.tensor([0.0], requires_grad=True, device=output.device)
+         ipts_hessian_approx_scores = [torch.tensor(0.0, requires_grad=True, device=output.device)
                                        for _ in range(len(target_activation_tensors))]
-         prev_mean_results = None
-         for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"):  # Approximation iterations
-             # Getting a random vector
+
+         for j in tqdm(range(self.num_iterations_for_approximation),
+                       "Hessian random iterations"):  # Approximation iterations
              v = self._generate_random_vectors_batch(output.shape, output.device)
              f_v = torch.sum(v * output)
              for i, ipt_tensor in enumerate(target_activation_tensors):  # Per Interest point activation tensor
-                 # Computing the hessian-approximation scores by getting the gradient of (output * v)
                  hess_v = autograd.grad(outputs=f_v,
                                         inputs=ipt_tensor,
                                         retain_graph=True,
                                         allow_unused=True)[0]
-
                  if hess_v is None:
                      # In case we have an output node, which is an interest point, but it is not differentiable,
                      # we consider its Hessian to be the initial value 0.
                      continue  # pragma: no cover
 
-                 # Mean over all dims but the batch (CXHXW for conv)
-                 hessian_approx_scores = torch.sum(hess_v ** 2.0, dim=tuple(d for d in range(1, len(hess_v.shape))))
-
-                 # Update node Hessian approximation mean over random iterations
-                 ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1)
-
-             # If the change to the maximal mean Hessian approximation is insignificant we stop the calculation
-             if j > MIN_HESSIAN_ITER:
-                 if prev_mean_results is not None:
-                     new_mean_res = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1)
-                     relative_delta_per_node = (torch.abs(new_mean_res - prev_mean_results) /
-                                                (torch.abs(new_mean_res) + 1e-6))
-                     max_delta = torch.max(relative_delta_per_node)
-                     if max_delta < HESSIAN_COMP_TOLERANCE:
-                         break
-             prev_mean_results = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1)
-
-         # add extra dimension to preserve previous behaviour
-         ipts_hessian_approx_scores = [torch.unsqueeze(t, -1) for t in ipts_hessian_approx_scores]
-         return ipts_hessian_approx_scores
-
-     def _compute_per_channel(self, output, target_activation_tensors):
-         assert self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL
-         ipts_hessian_approx_scores = [torch.tensor(0.0, requires_grad=True, device=output.device)
-                                       for _ in range(len(target_activation_tensors))]
-
-         for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"):  # Approximation iterations
-             v = self._generate_random_vectors_batch(output.shape, output.device)
-             f_v = torch.sum(v * output)
-             for i, ipt_tensor in enumerate(target_activation_tensors):  # Per Interest point activation tensor
-                 hess_v = autograd.grad(outputs=f_v,
-                                        inputs=ipt_tensor,
-                                        retain_graph=True)[0]
                  hessian_approx_scores = hess_v ** 2
-                 rank = len(hess_v.shape)
-                 if rank > 2:
-                     hessian_approx_scores = torch.mean(hessian_approx_scores, dim=tuple(range(2, rank)))
+                 num_dims = len(hess_v.shape)
+                 if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
+                     hessian_approx_scores = torch.sum(hessian_approx_scores, dim=tuple(range(1, num_dims))).unsqueeze(-1)
+                 elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL and num_dims > 2:
+                     hessian_approx_scores = torch.mean(hessian_approx_scores, dim=tuple(range(2, num_dims)))
 
                  # Update node Hessian approximation mean over random iterations
                  ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1)
 
-         return ipts_hessian_approx_scores
+         # Convert results to list of numpy arrays
+         hessian_results = [torch_tensor_to_numpy(h) for h in ipts_hessian_approx_scores]
+         return hessian_results
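The unified loop above is a Hutchinson-style estimator: for a random vector v, the gradient of sum(v * output) with respect to an intermediate activation is a random projection of the Jacobian, and its elementwise square, averaged over iterations, approximates the diagonal of the Gauss-Newton surrogate for the Hessian. A self-contained sketch with a toy model and a single interest-point tensor (hypothetical names, not the MCT API):

    import torch
    from torch import autograd

    x = torch.randn(8, 4)
    layer1, layer2 = torch.nn.Linear(4, 4), torch.nn.Linear(4, 2)

    ipt = layer1(x)                # interest-point activation
    ipt.retain_grad()
    out = layer2(torch.relu(ipt))  # model output

    approx = torch.zeros_like(ipt)
    n_iters = 50
    for j in range(n_iters):
        v = torch.randn_like(out)                             # random projection vector
        f_v = torch.sum(v * out)
        hess_v, = autograd.grad(f_v, ipt, retain_graph=True)
        approx = (j * approx + hess_v ** 2) / (j + 1)         # running mean, as in the loop above

    per_tensor = approx.sum(dim=tuple(range(1, approx.dim())))  # PER_TENSOR reduction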