mct-nightly 2.1.0.20240708.453__py3-none-any.whl → 2.1.0.20240710.440__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31):
  1. {mct_nightly-2.1.0.20240708.453.dist-info → mct_nightly-2.1.0.20240710.440.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240708.453.dist-info → mct_nightly-2.1.0.20240710.440.dist-info}/RECORD +31 -31
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +12 -12
  5. model_compression_toolkit/core/common/hessian/__init__.py +1 -1
  6. model_compression_toolkit/core/common/hessian/hessian_info_service.py +74 -69
  7. model_compression_toolkit/core/common/hessian/hessian_info_utils.py +1 -1
  8. model_compression_toolkit/core/common/hessian/{trace_hessian_calculator.py → hessian_scores_calculator.py} +11 -11
  9. model_compression_toolkit/core/common/hessian/{trace_hessian_request.py → hessian_scores_request.py} +15 -15
  10. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  11. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -8
  12. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -5
  13. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +4 -4
  14. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -5
  15. model_compression_toolkit/core/keras/hessian/{activation_trace_hessian_calculator_keras.py → activation_hessian_scores_calculator_keras.py} +26 -26
  16. model_compression_toolkit/core/keras/hessian/{trace_hessian_calculator_keras.py → hessian_scores_calculator_keras.py} +14 -14
  17. model_compression_toolkit/core/keras/hessian/{weights_trace_hessian_calculator_keras.py → weights_hessian_scores_calculator_keras.py} +27 -27
  18. model_compression_toolkit/core/keras/keras_implementation.py +30 -30
  19. model_compression_toolkit/core/pytorch/hessian/{activation_trace_hessian_calculator_pytorch.py → activation_hessian_scores_calculator_pytorch.py} +25 -25
  20. model_compression_toolkit/core/pytorch/hessian/{trace_hessian_calculator_pytorch.py → hessian_scores_calculator_pytorch.py} +14 -14
  21. model_compression_toolkit/core/pytorch/hessian/{weights_trace_hessian_calculator_pytorch.py → weights_hessian_scores_calculator_pytorch.py} +25 -25
  22. model_compression_toolkit/core/pytorch/pytorch_implementation.py +30 -30
  23. model_compression_toolkit/core/quantization_prep_runner.py +1 -1
  24. model_compression_toolkit/gptq/common/gptq_training.py +30 -30
  25. model_compression_toolkit/gptq/keras/gptq_training.py +1 -1
  26. model_compression_toolkit/gptq/pytorch/gptq_training.py +1 -1
  27. model_compression_toolkit/gptq/runner.py +2 -2
  28. model_compression_toolkit/qat/pytorch/quantization_facade.py +1 -1
  29. {mct_nightly-2.1.0.20240708.453.dist-info → mct_nightly-2.1.0.20240710.440.dist-info}/LICENSE.md +0 -0
  30. {mct_nightly-2.1.0.20240708.453.dist-info → mct_nightly-2.1.0.20240710.440.dist-info}/WHEEL +0 -0
  31. {mct_nightly-2.1.0.20240708.453.dist-info → mct_nightly-2.1.0.20240710.440.dist-info}/top_level.txt +0 -0
@@ -18,16 +18,16 @@ from typing import List, Any
18
18
 
19
19
  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
20
20
  from model_compression_toolkit.core.common import Graph
21
- from model_compression_toolkit.core.common.hessian import TraceHessianRequest
21
+ from model_compression_toolkit.core.common.hessian import HessianScoresRequest
22
22
  from model_compression_toolkit.logger import Logger
23
23
 
24
24
 
25
- class TraceHessianCalculator(ABC):
25
+ class HessianScoresCalculator(ABC):
26
26
  """
27
- Abstract base class for computing an approximation of the trace of the Hessian.
27
+ Abstract base class for computing scores based on the Hessian matrix approximation.
28
28
 
29
29
  This class provides a structure for implementing different methods to compute
30
- the trace of the Hessian approximation based on the provided configuration,
30
+ scores based on Hessian-approximation according to the provided configuration,
31
31
  input images, and other parameters.
32
32
  """
33
33
 
@@ -35,15 +35,15 @@ class TraceHessianCalculator(ABC):
35
35
  graph: Graph,
36
36
  input_images: List[Any],
37
37
  fw_impl,
38
- trace_hessian_request: TraceHessianRequest,
38
+ hessian_scores_request: HessianScoresRequest,
39
39
  num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
40
40
  """
41
41
  Args:
42
42
  graph: Computational graph for the float model.
43
43
  input_images: List of input images for the computation.
44
- fw_impl: Framework-specific implementation for trace Hessian computation.
45
- trace_hessian_request: Configuration request for which to compute the trace Hessian approximation.
46
- num_iterations_for_approximation: Number of iterations to use when approximating the Hessian trace.
44
+ fw_impl: Framework-specific implementation for Hessian-approximation scores computation.
45
+ hessian_scores_request: Configuration request for which to compute the Hessian-based approximation.
46
+ num_iterations_for_approximation: Number of iterations to use when approximating the Hessian-approximation scores.
47
47
 
48
48
  """
49
49
  self.graph = graph
@@ -60,15 +60,15 @@ class TraceHessianCalculator(ABC):
60
60
  Logger.critical(f"The graph requires {len(graph.get_inputs())} inputs, but the provided representative dataset contains {len(self.input_images)} inputs.")
61
61
 
62
62
  self.fw_impl = fw_impl
63
- self.hessian_request = trace_hessian_request
63
+ self.hessian_request = hessian_scores_request
64
64
 
65
65
  @abstractmethod
66
66
  def compute(self) -> List[float]:
67
67
  """
68
- Abstract method to compute the approximation of the trace of the Hessian.
68
+ Abstract method to compute the scores based on the Hessian-approximation matrix.
69
69
 
70
70
  This method should be implemented by subclasses to provide the specific
71
- computation method for the trace Hessian approximation.
71
+ computation method for the Hessian-approximation scores.
72
72
  """
73
73
  raise NotImplemented(f'{self.__class__.__name__} have to implement compute method.') # pragma: no cover
74
74
 
@@ -28,11 +28,11 @@ class HessianMode(Enum):
28
28
  ACTIVATION = 1 # Hessian approximation based on activations
29
29
 
30
30
 
31
- class HessianInfoGranularity(Enum):
31
+ class HessianScoresGranularity(Enum):
32
32
  """
33
- Enum representing the granularity level for Hessian information computation.
33
+ Enum representing the granularity level for Hessian scores computation.
34
34
 
35
- This determines the number the Hessian approximations is computed for some node.
35
+ This determines the number the Hessian scores is computed for some node.
36
36
  Note: This is not the actual Hessian but an approximation.
37
37
  """
38
38
  PER_ELEMENT = 0
@@ -40,25 +40,25 @@ class HessianInfoGranularity(Enum):
40
40
  PER_TENSOR = 2
41
41
 
42
42
 
43
- class TraceHessianRequest:
43
+ class HessianScoresRequest:
44
44
  """
45
- Request configuration for the trace of the Hessian approximation.
45
+ Request configuration for the Hessian-approximation scores.
46
46
 
47
- This class defines the parameters for the approximation of the trace of the Hessian matrix.
47
+ This class defines the parameters for the scores based on the Hessian matrix approximation.
48
48
  It specifies the mode (weights/activations), granularity (element/channel/tensor), and the target node.
49
- Note: This does not compute the actual Hessian's trace but approximates it.
49
+
50
+ Note: This does not compute scores using the actual Hessian matrix but an approximation.
50
51
  """
51
52
 
52
53
  def __init__(self,
53
54
  mode: HessianMode,
54
- granularity: HessianInfoGranularity,
55
- target_nodes: List,
56
- ):
55
+ granularity: HessianScoresGranularity,
56
+ target_nodes: List):
57
57
  """
58
58
  Attributes:
59
- mode (HessianMode): Mode of Hessian's trace approximation (w.r.t weights or activations).
60
- granularity (HessianInfoGranularity): Granularity level for the approximation.
61
- target_nodes (List[BaseNode]): The node in the float graph for which the Hessian's trace approximation is targeted.
59
+ mode (HessianMode): Mode of Hessian-approximation score (w.r.t weights or activations).
60
+ granularity (HessianScoresGranularity): Granularity level for the approximation.
61
+ target_nodes (List[BaseNode]): The node in the float graph for which the Hessian's approximation scores is targeted.
62
62
  """
63
63
 
64
64
  self.mode = mode # w.r.t activations or weights
@@ -66,9 +66,9 @@ class TraceHessianRequest:
66
66
  self.target_nodes = target_nodes
67
67
 
68
68
  def __eq__(self, other):
69
- # Checks if the other object is an instance of TraceHessianRequest
69
+ # Checks if the other object is an instance of HessianScoresRequest
70
70
  # and then checks if all attributes are equal.
71
- return isinstance(other, TraceHessianRequest) and \
71
+ return isinstance(other, HessianScoresRequest) and \
72
72
  self.mode == other.mode and \
73
73
  self.granularity == other.granularity and \
74
74
  self.target_nodes == other.target_nodes
@@ -51,7 +51,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
51
51
  mp_config: MixedPrecisionQuantizationConfig,
52
52
  representative_data_gen: Callable,
53
53
  search_method: BitWidthSearchMethod = BitWidthSearchMethod.INTEGER_PROGRAMMING,
54
- hessian_info_service: HessianInfoService=None) -> List[int]:
54
+ hessian_info_service: HessianInfoService = None) -> List[int]:
55
55
  """
56
56
  Search for an MP configuration for a given graph. Given a search_method method (by default, it's linear
57
57
  programming), we use the sensitivity_evaluator object that provides a function to compute an
@@ -68,7 +68,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
68
68
  mp_config: Mixed-precision quantization configuration.
69
69
  representative_data_gen: Dataset to use for retrieving images for the models inputs.
70
70
  search_method: BitWidthSearchMethod to define which searching method to use.
71
- hessian_info_service: HessianInfoService to fetch Hessian traces approximations.
71
+ hessian_info_service: HessianInfoService to fetch Hessian-approximation information.
72
72
 
73
73
  Returns:
74
74
  A MP configuration for the graph (list of integers, where the index in the list, is the node's
@@ -24,8 +24,8 @@ from model_compression_toolkit.core.common.graph.functional_node import Function
24
24
  from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence
25
25
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
26
26
  from model_compression_toolkit.logger import Logger
27
- from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, \
28
- HessianInfoGranularity, HessianInfoService
27
+ from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, \
28
+ HessianScoresGranularity, HessianInfoService
29
29
 
30
30
 
31
31
  class SensitivityEvaluation:
@@ -65,7 +65,7 @@ class SensitivityEvaluation:
65
65
  set_layer_to_bitwidth: A fw-dependent function that allows to configure a configurable MP model
66
66
  with a specific bit-width configuration.
67
67
  disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
68
- hessian_info_service: HessianInfoService to fetch Hessian traces approximations.
68
+ hessian_info_service: HessianInfoService to fetch Hessian approximation information.
69
69
 
70
70
  """
71
71
  self.graph = graph
@@ -237,14 +237,14 @@ class SensitivityEvaluation:
237
237
  to be used for the distance metric weighted average computation.
238
238
 
239
239
  """
240
- # Create a request for trace Hessian approximation with specific configurations
240
+ # Create a request for Hessian approximation scores with specific configurations
241
241
  # (here we use per-tensor approximation of the Hessian's trace w.r.t the node's activations)
242
- trace_hessian_request = TraceHessianRequest(mode=HessianMode.ACTIVATION,
243
- granularity=HessianInfoGranularity.PER_TENSOR,
242
+ hessian_info_request = HessianScoresRequest(mode=HessianMode.ACTIVATION,
243
+ granularity=HessianScoresGranularity.PER_TENSOR,
244
244
  target_nodes=self.interest_points)
245
245
 
246
- # Fetch the trace Hessian approximations for the current interest point
247
- nodes_approximations = self.hessian_info_service.fetch_hessian(trace_hessian_request=trace_hessian_request,
246
+ # Fetch the Hessian approximation scores for the current interest point
247
+ nodes_approximations = self.hessian_info_service.fetch_hessian(hessian_scores_request=hessian_info_request,
248
248
  required_size=self.quant_config.num_of_images,
249
249
  batch_size=self.quant_config.hessian_batch_size)
250
250
 
@@ -19,8 +19,8 @@ from typing import Callable, List, Dict, Tuple
19
19
 
20
20
  from model_compression_toolkit.core.common import Graph, BaseNode
21
21
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
- from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianMode, HessianInfoGranularity, \
23
- TraceHessianRequest
22
+ from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianMode, HessianScoresGranularity, \
23
+ HessianScoresRequest
24
24
  from model_compression_toolkit.core.common.pruning.channels_grouping import ChannelGrouping
25
25
  from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric
26
26
  from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
@@ -127,9 +127,9 @@ class LFHImportanceMetric(BaseImportanceMetric):
127
127
  # Fetch and process Hessian scores for output channels of entry nodes.
128
128
  nodes_scores = []
129
129
  for node in entry_nodes:
130
- _request = TraceHessianRequest(mode=HessianMode.WEIGHTS,
131
- granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL,
132
- target_nodes=[node])
130
+ _request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
131
+ granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
132
+ target_nodes=[node])
133
133
  _scores_for_node = hessian_info_service.fetch_hessian(_request,
134
134
  required_size=self.pruning_config.num_score_approximations)
135
135
  nodes_scores.append(_scores_for_node)
@@ -16,7 +16,7 @@ from copy import deepcopy
16
16
  from typing import Tuple, Callable, List
17
17
  import numpy as np
18
18
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
19
- from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, HessianInfoGranularity, \
19
+ from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, HessianScoresGranularity, \
20
20
  HessianInfoService
21
21
  from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_mae, compute_lp_norm
22
22
  from model_compression_toolkit.logger import Logger
@@ -389,9 +389,9 @@ def _compute_hessian_for_hmse(node,
389
389
  Returns: A list with computed Hessian-based scores tensors for the given node.
390
390
 
391
391
  """
392
- _request = TraceHessianRequest(mode=HessianMode.WEIGHTS,
393
- granularity=HessianInfoGranularity.PER_ELEMENT,
394
- target_nodes=[node])
392
+ _request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
393
+ granularity=HessianScoresGranularity.PER_ELEMENT,
394
+ target_nodes=[node])
395
395
  _scores_for_node = hessian_info_service.fetch_hessian(_request,
396
396
  required_size=num_hessian_samples)
397
397
 
@@ -20,8 +20,8 @@ from typing import List
20
20
  from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
21
21
  from model_compression_toolkit.core import QuantizationErrorMethod
22
22
  from model_compression_toolkit.core.common import Graph, BaseNode
23
- from model_compression_toolkit.core.common.hessian import HessianInfoService, TraceHessianRequest, HessianMode, \
24
- HessianInfoGranularity
23
+ from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
24
+ HessianScoresGranularity
25
25
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
26
26
  import get_activations_qparams
27
27
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
@@ -89,9 +89,9 @@ def calculate_quantization_params(graph: Graph,
89
89
  # The Hessian scores are computed and stored in the hessian_info_service object.
90
90
  nodes_for_hmse = _collect_nodes_for_hmse(nodes_list, graph)
91
91
  if len(nodes_for_hmse) > 0:
92
- hessian_info_service.fetch_hessian(TraceHessianRequest(mode=HessianMode.WEIGHTS,
93
- granularity=HessianInfoGranularity.PER_ELEMENT,
94
- target_nodes=nodes_for_hmse),
92
+ hessian_info_service.fetch_hessian(HessianScoresRequest(mode=HessianMode.WEIGHTS,
93
+ granularity=HessianScoresGranularity.PER_ELEMENT,
94
+ target_nodes=nodes_for_hmse),
95
95
  required_size=num_hessian_samples,
96
96
  batch_size=1)
97
97
 
@@ -22,45 +22,45 @@ import numpy as np
22
22
  from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, \
23
23
  HESSIAN_NUM_ITERATIONS
24
24
  from model_compression_toolkit.core.common import Graph
25
- from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianInfoGranularity
25
+ from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
26
26
  from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
27
- from model_compression_toolkit.core.keras.hessian.trace_hessian_calculator_keras import TraceHessianCalculatorKeras
27
+ from model_compression_toolkit.core.keras.hessian.hessian_scores_calculator_keras import HessianScoresCalculatorKeras
28
28
  from model_compression_toolkit.logger import Logger
29
29
 
30
30
 
31
- class ActivationTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
31
+ class ActivationHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
32
32
  """
33
- Keras implementation of the Trace Hessian Calculator for activations.
33
+ Keras implementation of the Hessian-approximation scores Calculator for activations.
34
34
  """
35
35
  def __init__(self,
36
36
  graph: Graph,
37
37
  input_images: List[tf.Tensor],
38
38
  fw_impl,
39
- trace_hessian_request: TraceHessianRequest,
39
+ hessian_scores_request: HessianScoresRequest,
40
40
  num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
41
41
  """
42
42
  Args:
43
43
  graph: Computational graph for the float model.
44
44
  input_images: List of input images for the computation.
45
- fw_impl: Framework-specific implementation for trace Hessian approximation computation.
46
- trace_hessian_request: Configuration request for which to compute the trace Hessian approximation.
47
- num_iterations_for_approximation: Number of iterations to use when approximating the Hessian trace.
45
+ fw_impl: Framework-specific implementation for Hessian approximation scores computation.
46
+ hessian_scores_request: Configuration request for which to compute the Hessian approximation scores.
47
+ num_iterations_for_approximation: Number of iterations to use when approximating the Hessian scores.
48
48
 
49
49
  """
50
- super(ActivationTraceHessianCalculatorKeras, self).__init__(graph=graph,
51
- input_images=input_images,
52
- fw_impl=fw_impl,
53
- trace_hessian_request=trace_hessian_request,
54
- num_iterations_for_approximation=num_iterations_for_approximation)
50
+ super(ActivationHessianScoresCalculatorKeras, self).__init__(graph=graph,
51
+ input_images=input_images,
52
+ fw_impl=fw_impl,
53
+ hessian_scores_request=hessian_scores_request,
54
+ num_iterations_for_approximation=num_iterations_for_approximation)
55
55
 
56
56
  def compute(self) -> List[np.ndarray]:
57
57
  """
58
- Compute the approximation of the trace of the Hessian w.r.t the requested target nodes' activations.
58
+ Compute the Hessian-approximation based scores w.r.t the requested target nodes' activations.
59
59
 
60
60
  Returns:
61
- List[np.ndarray]: Approximated trace of the Hessian for the requested nodes.
61
+ List[np.ndarray]: Scores based on the Hessian-approximation for the requested nodes.
62
62
  """
63
- if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR:
63
+ if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
64
64
  model_output_nodes = [ot.node for ot in self.graph.get_outputs()]
65
65
 
66
66
  if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0:
@@ -98,9 +98,9 @@ class ActivationTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
98
98
  # Unfold and concatenate all outputs to form a single tensor
99
99
  output = self._concat_tensors(output_tensors)
100
100
 
101
- # List to store the approximated trace of the Hessian for each interest point
102
- ipts_hessian_trace_approx = [tf.Variable([0.0], dtype=tf.float32, trainable=True)
103
- for _ in range(len(target_activation_tensors))]
101
+ # List to store the Hessian-approximation scores for each interest point
102
+ ipts_hessian_approximations = [tf.Variable([0.0], dtype=tf.float32, trainable=True)
103
+ for _ in range(len(target_activation_tensors))]
104
104
 
105
105
  # Loop through each interest point activation tensor
106
106
  prev_mean_results = None
@@ -120,29 +120,29 @@ class ActivationTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
120
120
  continue # pragma: no cover
121
121
 
122
122
  # Mean over all dims but the batch (CXHXW for conv)
123
- hessian_trace_approx = tf.reduce_sum(hess_v ** 2.0,
124
- axis=tuple(d for d in range(1, len(hess_v.shape))))
123
+ hessian_approx = tf.reduce_sum(hess_v ** 2.0,
124
+ axis=tuple(d for d in range(1, len(hess_v.shape))))
125
125
 
126
126
  # Free gradients
127
127
  del hess_v
128
128
 
129
129
  # Update node Hessian approximation mean over random iterations
130
- ipts_hessian_trace_approx[i] = (j * ipts_hessian_trace_approx[i] + hessian_trace_approx) / (j + 1)
130
+ ipts_hessian_approximations[i] = (j * ipts_hessian_approximations[i] + hessian_approx) / (j + 1)
131
131
 
132
132
  # If the change to the mean approximation is insignificant (to all outputs)
133
133
  # we stop the calculation.
134
134
  if j > MIN_HESSIAN_ITER:
135
135
  if prev_mean_results is not None:
136
- new_mean_res = tf.reduce_mean(tf.stack(ipts_hessian_trace_approx), axis=1)
136
+ new_mean_res = tf.reduce_mean(tf.stack(ipts_hessian_approximations), axis=1)
137
137
  relative_delta_per_node = (tf.abs(new_mean_res - prev_mean_results) /
138
138
  (tf.abs(new_mean_res) + 1e-6))
139
139
  max_delta = tf.reduce_max(relative_delta_per_node)
140
140
  if max_delta < HESSIAN_COMP_TOLERANCE:
141
141
  break
142
- prev_mean_results = tf.reduce_mean(tf.stack(ipts_hessian_trace_approx), axis=1)
142
+ prev_mean_results = tf.reduce_mean(tf.stack(ipts_hessian_approximations), axis=1)
143
143
 
144
144
  # Convert results to list of numpy arrays
145
- hessian_results = [h.numpy() for h in ipts_hessian_trace_approx]
145
+ hessian_results = [h.numpy() for h in ipts_hessian_approximations]
146
146
  # Extend the Hessian tensors shape to align with expected return type
147
147
  # TODO: currently, only per-tensor Hessian is available for activation.
148
148
  # Once implementing per-channel or per-element, this alignment needs to be verified and handled separately.
@@ -152,4 +152,4 @@ class ActivationTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
152
152
 
153
153
  else: # pragma: no cover
154
154
  Logger.critical(f"{self.hessian_request.granularity} "
155
- f"is not supported for Keras activation hessian\'s trace approximation calculator.")
155
+ f"is not supported for Keras activation hessian\'s approximation scores calculator.")
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.core.common.hessian.trace_hessian_calculator import TraceHessianCalculator
16
+ from model_compression_toolkit.core.common.hessian.hessian_scores_calculator import HessianScoresCalculator
17
17
 
18
18
  from typing import List, Tuple, Dict, Any, Union
19
19
 
@@ -23,38 +23,38 @@ from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
23
23
  from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
24
24
  from model_compression_toolkit.core.common import Graph, BaseNode
25
25
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
26
- from model_compression_toolkit.core.common.hessian import TraceHessianRequest
26
+ from model_compression_toolkit.core.common.hessian import HessianScoresRequest
27
27
  from model_compression_toolkit.core.keras.back2framework.instance_builder import OperationHandler
28
28
  from tensorflow.python.util.object_identity import Reference as TFReference
29
29
 
30
30
  from model_compression_toolkit.logger import Logger
31
31
 
32
32
 
33
- class TraceHessianCalculatorKeras(TraceHessianCalculator):
33
+ class HessianScoresCalculatorKeras(HessianScoresCalculator):
34
34
  """
35
- Keras-specific implementation of the Trace Hessian approximation Calculator.
36
- This class serves as a base for other Keras-specific trace Hessian approximation calculators.
35
+ Keras-specific implementation of the Hessian approximation scores Calculator.
36
+ This class serves as a base for other Keras-specific Hessian approximation scores calculators.
37
37
  """
38
38
  def __init__(self,
39
39
  graph: Graph,
40
40
  input_images: List[tf.Tensor],
41
41
  fw_impl,
42
- trace_hessian_request: TraceHessianRequest,
42
+ hessian_scores_request: HessianScoresRequest,
43
43
  num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
44
44
  """
45
45
 
46
46
  Args:
47
47
  graph: Computational graph for the float model.
48
48
  input_images: List of input images for the computation.
49
- fw_impl: Framework-specific implementation for trace Hessian computation.
50
- trace_hessian_request: Configuration request for which to compute the trace Hessian approximation.
51
- num_iterations_for_approximation: Number of iterations to use when approximating the Hessian trace.
49
+ fw_impl: Framework-specific implementation for Hessian-approximation scores computation.
50
+ hessian_scores_request: Configuration request for which to compute the Hessian approximation scores.
51
+ num_iterations_for_approximation: Number of iterations to use when approximating the Hessian-based scores.
52
52
  """
53
- super(TraceHessianCalculatorKeras, self).__init__(graph=graph,
54
- input_images=input_images,
55
- fw_impl=fw_impl,
56
- trace_hessian_request=trace_hessian_request,
57
- num_iterations_for_approximation=num_iterations_for_approximation)
53
+ super(HessianScoresCalculatorKeras, self).__init__(graph=graph,
54
+ input_images=input_images,
55
+ fw_impl=fw_impl,
56
+ hessian_scores_request=hessian_scores_request,
57
+ num_iterations_for_approximation=num_iterations_for_approximation)
58
58
 
59
59
  def _concat_tensors(self, tensors_to_concate: Union[tf.Tensor, List[tf.Tensor]]) -> tf.Tensor:
60
60
  """
@@ -20,39 +20,39 @@ from typing import List
20
20
 
21
21
  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE
22
22
  from model_compression_toolkit.core.common import Graph
23
- from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianInfoGranularity
23
+ from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
24
24
  from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
25
25
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
26
- from model_compression_toolkit.core.keras.hessian.trace_hessian_calculator_keras import TraceHessianCalculatorKeras
26
+ from model_compression_toolkit.core.keras.hessian.hessian_scores_calculator_keras import HessianScoresCalculatorKeras
27
27
  from model_compression_toolkit.logger import Logger
28
28
 
29
29
 
30
- class WeightsTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
30
+ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
31
31
  """
32
- Keras-specific implementation of the Trace Hessian approximation computation w.r.t a node's weights.
32
+ Keras-specific implementation of the Hessian-approximation scores computation w.r.t a node's weights.
33
33
  """
34
34
 
35
35
  def __init__(self,
36
36
  graph: Graph,
37
37
  input_images: List[tf.Tensor],
38
38
  fw_impl,
39
- trace_hessian_request: TraceHessianRequest,
39
+ hessian_scores_request: HessianScoresRequest,
40
40
  num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
41
41
  """
42
42
 
43
43
  Args:
44
44
  graph: Computational graph for the float model.
45
45
  input_images: List of input images for the computation.
46
- fw_impl: Framework-specific implementation for trace Hessian computation.
47
- trace_hessian_request: Configuration request for which to compute the trace Hessian approximation.
48
- num_iterations_for_approximation: Number of iterations to use when approximating the Hessian trace.
46
+ fw_impl: Framework-specific implementation for Hessian scores computation.
47
+ hessian_scores_request: Configuration request for which to compute the Hessian-approximation scores.
48
+ num_iterations_for_approximation: Number of iterations to use when approximating the Hessian-based scores.
49
49
  """
50
50
 
51
- super(WeightsTraceHessianCalculatorKeras, self).__init__(graph=graph,
52
- input_images=input_images,
53
- fw_impl=fw_impl,
54
- trace_hessian_request=trace_hessian_request,
55
- num_iterations_for_approximation=num_iterations_for_approximation)
51
+ super(WeightsHessianScoresCalculatorKeras, self).__init__(graph=graph,
52
+ input_images=input_images,
53
+ fw_impl=fw_impl,
54
+ hessian_scores_request=hessian_scores_request,
55
+ num_iterations_for_approximation=num_iterations_for_approximation)
56
56
 
57
57
  def compute(self) -> List[np.ndarray]:
58
58
  """
@@ -83,7 +83,7 @@ class WeightsTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
83
83
  # Combine outputs if the model returns multiple output tensors
84
84
  output = self._concat_tensors(outputs)
85
85
 
86
- ipts_hessian_trace_approx = [tf.Variable([0.0], dtype=tf.float32, trainable=True)
86
+ ipts_hessian_scores_approx = [tf.Variable([0.0], dtype=tf.float32, trainable=True)
87
87
  for _ in range(len(self.hessian_request.target_nodes))]
88
88
 
89
89
  prev_mean_results = None
@@ -134,7 +134,7 @@ class WeightsTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
134
134
  approx = tf.reduce_sum(tf.pow(gradients, 2.0), axis=1)
135
135
 
136
136
  # Update node Hessian approximation mean over random iterations
137
- ipts_hessian_trace_approx[i] = (j * ipts_hessian_trace_approx[i] + approx) / (j + 1)
137
+ ipts_hessian_scores_approx[i] = (j * ipts_hessian_scores_approx[i] + approx) / (j + 1)
138
138
 
139
139
  # Free gradients
140
140
  del gradients
@@ -145,33 +145,33 @@ class WeightsTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
145
145
  if j > MIN_HESSIAN_ITER:
146
146
  if prev_mean_results is not None:
147
147
  new_mean_res = \
148
- tf.convert_to_tensor([tf.reduce_mean(res) for res in ipts_hessian_trace_approx])
148
+ tf.convert_to_tensor([tf.reduce_mean(res) for res in ipts_hessian_scores_approx])
149
149
  relative_delta_per_node = (tf.abs(new_mean_res - prev_mean_results) /
150
150
  (tf.abs(new_mean_res) + 1e-6))
151
151
  max_delta = tf.reduce_max(relative_delta_per_node)
152
152
  if max_delta < HESSIAN_COMP_TOLERANCE:
153
153
  break
154
154
 
155
- prev_mean_results = tf.convert_to_tensor([tf.reduce_mean(res) for res in ipts_hessian_trace_approx])
155
+ prev_mean_results = tf.convert_to_tensor([tf.reduce_mean(res) for res in ipts_hessian_scores_approx])
156
156
 
157
157
  # Free gradient tape
158
158
  del tape
159
159
 
160
- if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR:
161
- for final_approx in ipts_hessian_trace_approx:
160
+ if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
161
+ for final_approx in ipts_hessian_scores_approx:
162
162
  if final_approx.shape != (1,): # pragma: no cover
163
163
  Logger.critical(f"For HessianInfoGranularity.PER_TENSOR, the expected score shape is (1,), "
164
164
  f"but found {final_approx.shape}.")
165
- elif self.hessian_request.granularity == HessianInfoGranularity.PER_ELEMENT:
165
+ elif self.hessian_request.granularity == HessianScoresGranularity.PER_ELEMENT:
166
166
  # Reshaping the scores to the original weight shape
167
- ipts_hessian_trace_approx = \
167
+ ipts_hessian_scores_approx = \
168
168
  [tf.reshape(final_approx, s) for final_approx, s in
169
- zip(ipts_hessian_trace_approx, tensors_original_shape)]
169
+ zip(ipts_hessian_scores_approx, tensors_original_shape)]
170
170
 
171
171
  # Add a batch axis to the Hessian approximation tensor (to align with the expected returned shape)
172
172
  # We assume per-image computation, so the batch axis size is 1.
173
173
  final_approx = [r_final_approx[np.newaxis, ...].numpy()
174
- for r_final_approx in ipts_hessian_trace_approx]
174
+ for r_final_approx in ipts_hessian_scores_approx]
175
175
 
176
176
  return final_approx
177
177
 
@@ -193,7 +193,7 @@ class WeightsTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
193
193
  tf.Tensor: Reshaped gradient tensor based on the granularity.
194
194
  """
195
195
  # Reshape the gradients based on the granularity (whole tensor, per channel, or per element)
196
- if self.hessian_request.granularity != HessianInfoGranularity.PER_OUTPUT_CHANNEL:
196
+ if self.hessian_request.granularity != HessianScoresGranularity.PER_OUTPUT_CHANNEL:
197
197
  gradients = tf.reshape(gradients, [num_of_scores, -1])
198
198
  else:
199
199
  # Slice the gradients, vectorize them and stack them along the first axis.
@@ -216,11 +216,11 @@ class WeightsTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
216
216
  Returns:
217
217
  int: The number of scores.
218
218
  """
219
- if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR:
219
+ if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
220
220
  return 1
221
- elif self.hessian_request.granularity == HessianInfoGranularity.PER_OUTPUT_CHANNEL:
221
+ elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
222
222
  return weight_tensor.shape[output_channel_axis]
223
- elif self.hessian_request.granularity == HessianInfoGranularity.PER_ELEMENT:
223
+ elif self.hessian_request.granularity == HessianScoresGranularity.PER_ELEMENT:
224
224
  return tf.size(weight_tensor).numpy()
225
225
  else: # pragma: no cover
226
226
  Logger.critical(f"Unexpected granularity encountered: {self.hessian_request.granularity}.")