mct-nightly 2.1.0.20240613.456__py3-none-any.whl → 2.1.0.20240614.431__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: mct-nightly
-Version: 2.1.0.20240613.456
+Version: 2.1.0.20240614.431
 Summary: A Model Compression Toolkit for neural networks
 Home-page: UNKNOWN
 License: UNKNOWN
@@ -1,4 +1,4 @@
-model_compression_toolkit/__init__.py,sha256=UKIkguxChqu5eeC1hBi1AhEvt920ApyiU2_tgi4MVUg,1573
+model_compression_toolkit/__init__.py,sha256=_dCTq18O_raH0zecwk9Lgrp3yuAwu3GbwwHT-lI4tUM,1573
 model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
 model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
 model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -45,7 +45,7 @@ model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256
 model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
 model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=gw4av_rzn_3oEAPpD3B7PHZDqnxHMjIESevl6ppPnkk,7175
 model_compression_toolkit/core/common/hessian/__init__.py,sha256=bxPVbkIlHFJMiOgTdWMVCqcD9JKV5kb2bVdWUTeLpj8,1021
-model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=yG3TznPlQgRGZ0Hb8O4ViJLt-xvWrOkbpfHiOypYgqU,20722
+model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=0Ziwyzv6H5mIG5ptW6uC_w1gmxZIdffCuK8cg0STmJQ,20731
 model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=JepOjcyX1XyiC1UblqM3zdKv2xuUvU3HKWjlE1Bnq_U,1490
 model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py,sha256=EIV4NVUfvkefqMAFrrjNhQq7cvT3hljHpGz_gpVaFtY,4135
 model_compression_toolkit/core/common/hessian/trace_hessian_request.py,sha256=uvnaYtJRRmj_CfnYAO6oehnhDqdalW0NgETWJvSzCxc,3245
@@ -109,12 +109,12 @@ model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha
 model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
 model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=O4qFJw3nBYUD4cGbO8haGXZ2-piSqoRpDKDD74iXSxw,12417
 model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
-model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=w367wmtJ7iWmM4_HlpX-YVUuqtYKrsiPP1oDaICIuK8,23308
+model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=4XH-qSo-zG7XkVTx1J0DFNHEklLOhkhxXeEWnXNJ7z8,23602
 model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=t0XSwjfOxcq2Sj2PGzccntz1GGv2eqVn9oR3OI0t9wo,8533
 model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
 model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=HfnhQ4MxGpb95gOWXD1vnroTxxjFt9VFd4jIdo-rvAQ,10623
 model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=noEdvGiyyW7acgQ2OFWLedCODibTGYJifC9qo8YIU5U,4558
-model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=E_XFTpYNUZ3JgOk_2qbUbmJH6qGqBM3TDsY4WptYup0,6478
+model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=JS1nhQUMBVBtEjXbevFbbzHsXM0QLKVTG_3DRhdTAa0,8643
 model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=o2XNY_0pUUyId02TUVQBtkux_i40NCcnzuobSeQLy3E,42863
 model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=zSNda0jN8cP41m6g5TOv5WvATwIhV8z6AVM1Es6rq1s,4419
 model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=4TP41wPYC0azIzFxUt-lNlKUPIIXQeE4H1SYHkON75k,11875
@@ -186,7 +186,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/weights_a
 model_compression_toolkit/core/keras/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
 model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py,sha256=4eJKq_Fx4mm_VuBDeeti0fTcUk1lL2yjebxCugJhvrA,8871
 model_compression_toolkit/core/keras/hessian/trace_hessian_calculator_keras.py,sha256=hRfAjgZakDaIMuERmTVjJSa_Ww6FmEudYPO9R7SuYuQ,3914
-model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py,sha256=P5auDAbKaOQYLNJTFXY0sy2AofS5OeB7cIAQhG5tQzo,11384
+model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py,sha256=KBjGr9FzyZIPD4MFtsV3LDBdJtLa0VFdIXyx_KAnjTQ,12215
 model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
 model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=aW8wR13fK6P6xzbU9XGU60IO1yYzXSo_Hk4qeq486kg,5137
 model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=Ziydik2j-LvNBXP3TSfUD6rEezPAikzQGib0_IXkmGM,6729
@@ -251,7 +251,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/weights
 model_compression_toolkit/core/pytorch/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
 model_compression_toolkit/core/pytorch/hessian/activation_trace_hessian_calculator_pytorch.py,sha256=eDiTiKVvH5NBgUFV6oBe7QeowJRo6tOQbcXx9t9k2S0,8522
 model_compression_toolkit/core/pytorch/hessian/trace_hessian_calculator_pytorch.py,sha256=Gat9aobUOQEWGt02x30vVm04mdi3gchdz2Bmmw5p91w,3445
-model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py,sha256=gTrnnzhqlfQYJoPugEbnHWMaKmtPDWv-2hNRCxtv0yE,7792
+model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py,sha256=-B446KhtZHPU_5Ixtm9v_v-3qDQ05NoIj2iyq5DlgR4,8460
 model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
 model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=-6oep2WJ85-JmIxZa-e2AmBpbORoKe4Xdduz2ZidwvM,4871
 model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=KVZTKCYzJqqzF5nFEiuGMv_sNeVuBTxhmxWMFacKOxE,6337
@@ -491,8 +491,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
 model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
 model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
 model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=MxylaVFPgN7zBiRBy6WV610EA4scLgRJFbMucKvvNDU,2896
-mct_nightly-2.1.0.20240613.456.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
-mct_nightly-2.1.0.20240613.456.dist-info/METADATA,sha256=H4hLWwgd8LFtvDh_noWKbl5JqDQn8XoDtN-pQp_ezJQ,19721
-mct_nightly-2.1.0.20240613.456.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-mct_nightly-2.1.0.20240613.456.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
-mct_nightly-2.1.0.20240613.456.dist-info/RECORD,,
+mct_nightly-2.1.0.20240614.431.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+mct_nightly-2.1.0.20240614.431.dist-info/METADATA,sha256=nZ_Rmy3k1IwzbSnR7mmNBfZknT6WaEGuXl3UAwEePHQ,19721
+mct_nightly-2.1.0.20240614.431.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+mct_nightly-2.1.0.20240614.431.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+mct_nightly-2.1.0.20240614.431.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
 from model_compression_toolkit import pruning
 from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model

-__version__ = "2.1.0.20240613.000456"
+__version__ = "2.1.0.20240614.000431"
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from collections.abc import Iterable

 import numpy as np
 from functools import partial
@@ -189,6 +188,10 @@ class HessianInfoService:
             images, next_iter_remain_samples = representative_dataset_gen(num_hessian_samples=num_hessian_samples,
                                                                            last_iter_remain_samples=last_iter_remain_samples)

+            # Sort the target nodes by their topological order in the graph
+            topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()]
+            trace_hessian_request.target_nodes.sort(key=lambda x: topo_sorted_nodes_names.index(x.name))
+
             # Get the framework-specific calculator for trace Hessian approximation
             fw_hessian_calculator = self.fw_impl.get_trace_hessian_calculator(graph=self.graph,
                                                                               input_images=images,
@@ -197,12 +200,7 @@

             trace_hessian = fw_hessian_calculator.compute()

-            # Store the computed approximation in the saved info
-            topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()]
-            sorted_target_nodes = sorted(trace_hessian_request.target_nodes,
-                                         key=lambda x: topo_sorted_nodes_names.index(x.name))
-
-            for node, hessian in zip(sorted_target_nodes, trace_hessian):
+            for node, hessian in zip(trace_hessian_request.target_nodes, trace_hessian):
                 single_node_request = self._construct_single_node_request(trace_hessian_request.mode,
                                                                           trace_hessian_request.granularity,
                                                                           node)
@@ -246,6 +244,10 @@
                 The inner list length dependent on the granularity (1 for per-tensor,
                 OC for per-output-channel when the requested node has OC output-channels, etc.)
         """
+
+        if len(trace_hessian_request.target_nodes) == 0:
+            return []
+
         if required_size == 0:
             return [[] for _ in trace_hessian_request.target_nodes]

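Taken together, the three hessian_info_service.py hunks above switch the service from a single-node to a multi-node contract: a request may carry several target nodes, the service sorts them topologically before dispatching one framework calculator for all of them, zips the per-node results back in that same order, and short-circuits an empty request to an empty list. A minimal sketch of issuing such a batched request — hedged: `graph`, `fw_impl`, `representative_data_gen`, and `nodes_of_interest` are assumed to exist, and the constructor keywords are an assumption; only the fetch_hessian call shape is taken verbatim from the qparams_computation.py hunk further down:

    from model_compression_toolkit.core.common.hessian import (
        HessianInfoService, TraceHessianRequest, HessianMode, HessianInfoGranularity)

    # Assumed constructor keywords; not verified against this release.
    service = HessianInfoService(graph=graph,
                                 representative_dataset=representative_data_gen,
                                 fw_impl=fw_impl)
    request = TraceHessianRequest(mode=HessianMode.WEIGHTS,
                                  granularity=HessianInfoGranularity.PER_ELEMENT,
                                  target_nodes=nodes_of_interest)  # several nodes in one request
    # Returns one entry per target node, ordered topologically:
    scores_per_node = service.fetch_hessian(request, required_size=32, batch_size=1)
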
@@ -19,6 +19,7 @@ import model_compression_toolkit.core.common.quantization.quantization_config as
 from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, HessianInfoGranularity, \
     HessianInfoService
 from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_mae, compute_lp_norm
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.constants import FLOAT_32, NUM_QPARAM_HESSIAN_SAMPLES
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import uniform_quantize_tensor, \
@@ -376,7 +377,7 @@ def _get_sliced_histogram(bins: np.ndarray,

 def _compute_hessian_for_hmse(node,
                               hessian_info_service: HessianInfoService,
-                              num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> List[np.ndarray]:
+                              num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> List[List[np.ndarray]]:
     """
     Compute and retrieve Hessian-based scores for using during HMSE error computation.

@@ -476,7 +477,10 @@ def get_threshold_selection_tensor_error_function(quantization_method: Quantizat

     if quant_error_method == qc.QuantizationErrorMethod.HMSE:
         node_hessian_scores = _compute_hessian_for_hmse(node, hessian_info_service, num_hessian_samples)
-        node_hessian_scores = np.sqrt(np.mean(node_hessian_scores, axis=0))
+        if len(node_hessian_scores) != 1:
+            Logger.critical(f"Expecting single node Hessian score request to return a list of length 1, but got a list "
+                            f"of length {len(node_hessian_scores)}.")
+        node_hessian_scores = np.sqrt(np.mean(node_hessian_scores[0], axis=0))

     return lambda x, y, threshold: _hmse_error_function_wrapper(x, y, norm=norm, axis=axis,
                                                                 hessian_scores=node_hessian_scores)
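
The error_functions.py hunks encode a shape-contract change: `_compute_hessian_for_hmse` now returns one inner list per requested node (`List[List[np.ndarray]]`), so the single-node HMSE path indexes `[0]` before reducing and logs a critical error if the outer list does not have length 1. A small, self-contained numpy illustration of that reduction (shapes and sample counts are hypothetical):

    import numpy as np

    num_samples, kernel_shape = 4, (3, 3, 16, 32)   # made-up values
    # Outer list: one entry per node; inner list: per-sample score arrays.
    node_hessian_scores = [[np.random.rand(*kernel_shape) for _ in range(num_samples)]]

    assert len(node_hessian_scores) == 1            # single-node request, as verified above
    # Mean over that node's Hessian samples, then square root, mirroring
    # np.sqrt(np.mean(node_hessian_scores[0], axis=0)) in the hunk:
    hmse_scores = np.sqrt(np.mean(node_hessian_scores[0], axis=0))
    assert hmse_scores.shape == kernel_shape
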
@@ -20,7 +20,8 @@ from typing import List
 from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
 from model_compression_toolkit.core import QuantizationErrorMethod
 from model_compression_toolkit.core.common import Graph, BaseNode
-from model_compression_toolkit.core.common.hessian import HessianInfoService
+from model_compression_toolkit.core.common.hessian import HessianInfoService, TraceHessianRequest, HessianMode, \
+    HessianInfoGranularity
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
     import get_activations_qparams
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
@@ -28,6 +29,31 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
 from model_compression_toolkit.logger import Logger


+def _collect_nodes_for_hmse(nodes_list: List[BaseNode], graph: Graph) -> List[BaseNode]:
+    """
+    Collects nodes that are compatible with parameters selection search using HMSE,
+    that is, have a kernel attribute that is configured for the HMSE error method.
+
+    Args:
+        nodes_list: A list of nodes to search quantization parameters for.
+        graph: Graph to compute its nodes' quantization parameters.
+
+    Returns: A (possibly empty) list of nodes.
+
+    """
+    hmse_nodes = []
+    for n in nodes_list:
+        kernel_attr_name = graph.fw_info.get_kernel_op_attributes(n.type)
+        kernel_attr_name = None if kernel_attr_name is None or len(kernel_attr_name) == 0 else kernel_attr_name[0]
+
+        if kernel_attr_name is not None and n.is_weights_quantization_enabled(kernel_attr_name) and \
+                all([c.weights_quantization_cfg.get_attr_config(kernel_attr_name).weights_error_method ==
+                     QuantizationErrorMethod.HMSE for c in n.candidates_quantization_cfg]):
+            hmse_nodes.append(n)
+
+    return hmse_nodes
+
+
 def calculate_quantization_params(graph: Graph,
                                   nodes: List[BaseNode] = [],
                                   specific_nodes: bool = False,
@@ -58,6 +84,17 @@ def calculate_quantization_params(graph: Graph,
     # Create a list of nodes to compute their thresholds
     nodes_list: List[BaseNode] = nodes if specific_nodes else graph.nodes()

+    # Collecting nodes that are configured to search weights quantization parameters using HMSE optimization
+    # and computing required Hessian information to be used for HMSE parameters selection.
+    # The Hessian scores are computed and stored in the hessian_info_service object.
+    nodes_for_hmse = _collect_nodes_for_hmse(nodes_list, graph)
+    if len(nodes_for_hmse) > 0:
+        hessian_info_service.fetch_hessian(TraceHessianRequest(mode=HessianMode.WEIGHTS,
+                                                               granularity=HessianInfoGranularity.PER_ELEMENT,
+                                                               target_nodes=nodes_for_hmse),
+                                           required_size=num_hessian_samples,
+                                           batch_size=1)
+
     for n in tqdm(nodes_list, "Calculating quantization parameters"):  # iterate only nodes that we should compute their thresholds
         for candidate_qc in n.candidates_quantization_cfg:
             for attr in n.get_node_weights_attributes():
@@ -73,6 +110,8 @@ def calculate_quantization_params(graph: Graph,
                     mod_attr_cfg = attr_cfg

                 if attr_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
+                    # Although we collected nodes for HMSE before running the loop, we keep this verification to
+                    # notify the user in case HMSE is configured for a node that is not compatible with this method
                     kernel_attr_name = graph.fw_info.get_kernel_op_attributes(n.type)
                     if len(kernel_attr_name) > 0:
                         kernel_attr_name = kernel_attr_name[0]
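
The qparams_computation.py hunks move Hessian work out of the per-node loop: every HMSE-configured kernel node is collected up front and its per-element scores are requested once, in a single multi-node fetch, so each random-vector iteration of the calculators below (one forward pass, one `f_v`) is shared across all target nodes instead of being repeated per node. Restated in isolation as a hedged sketch — the names come from the hunks above, and that fetched scores are cached inside `hessian_info_service` for the later per-node `_compute_hessian_for_hmse` calls is an assumption based on the added comment:

    nodes_for_hmse = _collect_nodes_for_hmse(list(graph.nodes()), graph)
    if nodes_for_hmse:
        # One batched request instead of len(nodes_for_hmse) single-node requests.
        hessian_info_service.fetch_hessian(
            TraceHessianRequest(mode=HessianMode.WEIGHTS,
                                granularity=HessianInfoGranularity.PER_ELEMENT,
                                target_nodes=nodes_for_hmse),
            required_size=NUM_QPARAM_HESSIAN_SAMPLES,
            batch_size=1)
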
@@ -15,9 +15,10 @@

 import numpy as np
 import tensorflow as tf
+from tqdm import tqdm
 from typing import List

-from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_EPS
+from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianInfoGranularity
 from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
@@ -47,11 +48,6 @@ class WeightsTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
             num_iterations_for_approximation: Number of iterations to use when approximating the Hessian trace.
         """

-        if len(trace_hessian_request.target_nodes) > 1:  # pragma: no cover
-            Logger.critical(f"Weights Hessian approximation is currently supported only for a single target node,"
-                            f" but the provided request contains the following target nodes: "
-                            f"{trace_hessian_request.target_nodes}.")
-
         super(WeightsTraceHessianCalculatorKeras, self).__init__(graph=graph,
                                                                  input_images=input_images,
                                                                  fw_impl=fw_impl,
@@ -73,35 +69,12 @@ class WeightsTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
         The function returns a list for compatibility reasons.

         """
-        # Check if the target node's layer type is supported.
-        # We assume that weights Hessian computation is done only for a single node at each request.
-        target_node = self.hessian_request.target_nodes[0]
-        if not DEFAULT_KERAS_INFO.is_kernel_op(target_node.type):
-            Logger.critical(f"Hessian information with respect to weights is not supported for "
-                            f"{target_node.type} layers.")  # pragma: no cover

         # Construct the Keras float model for inference
         model, _ = FloatKerasModelBuilder(graph=self.graph).build_model()

-        # Get the weight attributes for the target node type
-        weight_attributes = DEFAULT_KERAS_INFO.get_kernel_op_attributes(target_node.type)
-
-        # Get the weight tensor for the target node
-        if len(weight_attributes) != 1:  # pragma: no cover
-            Logger.critical(f"Hessian-based scoring with respect to weights is currently supported only for nodes with "
-                            f"a single weight attribute. Found {len(weight_attributes)} attributes.")
-
-        weight_tensor = getattr(model.get_layer(target_node.name), weight_attributes[0])
-
-        # Get the output channel index (needed for HessianInfoGranularity.PER_OUTPUT_CHANNEL case)
-        output_channel_axis, _ = DEFAULT_KERAS_INFO.kernel_channels_mapping.get(target_node.type)
-
-        # Get number of scores that should be calculated by the granularity.
-        num_of_scores = self._get_num_scores_by_granularity(weight_tensor,
-                                                            output_channel_axis)
-
         # Initiate a gradient tape for automatic differentiation
-        with tf.GradientTape(persistent=True) as tape:
+        with (tf.GradientTape(persistent=True) as tape):
             # Perform a forward pass (inference) to get the output, while watching
             # the input tensor for gradient computation
             tape.watch(self.input_images)
@@ -110,55 +83,97 @@ class WeightsTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
             # Combine outputs if the model returns multiple output tensors
             output = self._concat_tensors(outputs)

-            approximation_per_iteration = []
-            for j in range(self.num_iterations_for_approximation):  # Approximation iterations
+            ipts_hessian_trace_approx = [tf.Variable([0.0], dtype=tf.float32, trainable=True)
+                                         for _ in range(len(self.hessian_request.target_nodes))]
+
+            prev_mean_results = None
+            tensors_original_shape = []
+            for j in tqdm(range(self.num_iterations_for_approximation)):  # Approximation iterations
                 # Getting a random vector with normal distribution and the same shape as the model output
                 v = tf.random.normal(shape=output.shape)
                 f_v = tf.reduce_sum(v * output)

-                # Stop recording operations for automatic differentiation
+                for i, ipt_node in enumerate(self.hessian_request.target_nodes):  # Per Interest point weights tensor
+
+                    # Check if the target node's layer type is supported.
+                    if not DEFAULT_KERAS_INFO.is_kernel_op(ipt_node.type):
+                        Logger.critical(f"Hessian information with respect to weights is not supported for "
+                                        f"{ipt_node.type} layers.")  # pragma: no cover
+
+                    # Get the weight attributes for the target node type
+                    weight_attributes = DEFAULT_KERAS_INFO.get_kernel_op_attributes(ipt_node.type)
+
+                    # Get the weight tensor for the target node
+                    if len(weight_attributes) != 1:  # pragma: no cover
+                        Logger.critical(
+                            f"Hessian-based scoring with respect to weights is currently supported only for nodes with "
+                            f"a single weight attribute. Found {len(weight_attributes)} attributes.")
+
+                    weight_tensor = getattr(model.get_layer(ipt_node.name), weight_attributes[0])
+
+                    if j == 0:
+                        # On the first iteration we store the weight_tensor shape for later reshaping the results
+                        # back if necessary
+                        tensors_original_shape.append(weight_tensor.shape)
+
+                    # Get the output channel index (needed for HessianInfoGranularity.PER_OUTPUT_CHANNEL case)
+                    output_channel_axis, _ = DEFAULT_KERAS_INFO.kernel_channels_mapping.get(ipt_node.type)
+
+                    # Get number of scores that should be calculated by the granularity.
+                    num_of_scores = self._get_num_scores_by_granularity(weight_tensor,
+                                                                        output_channel_axis)
+
+                    # Stop recording operations for automatic differentiation
+                    with tape.stop_recording():
+                        # Compute gradients of f_v with respect to the weights
+                        gradients = tape.gradient(f_v, weight_tensor)
+                        gradients = self._reshape_gradients(gradients,
+                                                            output_channel_axis,
+                                                            num_of_scores)
+
+                        approx = tf.reduce_sum(tf.pow(gradients, 2.0), axis=1)
+
+                        # Update node Hessian approximation mean over random iterations
+                        ipts_hessian_trace_approx[i] = (j * ipts_hessian_trace_approx[i] + approx) / (j + 1)
+
+                        # Free gradients
+                        del gradients
+
+                # If the change to the mean approximation is insignificant (to all outputs)
+                # we stop the calculation.
                 with tape.stop_recording():
-                    # Compute gradients of f_v with respect to the weights
-                    gradients = tape.gradient(f_v, weight_tensor)
-                    gradients = self._reshape_gradients(gradients,
-                                                        output_channel_axis,
-                                                        num_of_scores)
-                    approx = tf.reduce_sum(tf.pow(gradients, 2.0), axis=1)
-
-                    # Free gradients
-                    del gradients
-
-                    # If the change to the mean approximation is insignificant (to all outputs)
-                    # we stop the calculation.
                     if j > MIN_HESSIAN_ITER:
-                        # Compute new means and deltas
-                        new_mean = tf.reduce_mean(tf.stack(approximation_per_iteration + approx), axis=0)
-                        delta = new_mean - tf.reduce_mean(tf.stack(approximation_per_iteration), axis=0)
-                        is_converged = np.all(np.abs(delta) / (np.abs(new_mean) + HESSIAN_EPS) < HESSIAN_COMP_TOLERANCE)
-                        if is_converged:
-                            approximation_per_iteration.append(approx)
-                            break
-
-                    approximation_per_iteration.append(approx)
+                        if prev_mean_results is not None:
+                            new_mean_res = \
+                                tf.convert_to_tensor([tf.reduce_mean(res) for res in ipts_hessian_trace_approx])
+                            relative_delta_per_node = (tf.abs(new_mean_res - prev_mean_results) /
+                                                       (tf.abs(new_mean_res) + 1e-6))
+                            max_delta = tf.reduce_max(relative_delta_per_node)
+                            if max_delta < HESSIAN_COMP_TOLERANCE:
+                                break

-            # Compute the mean of the approximations
-            final_approx = tf.reduce_mean(tf.stack(approximation_per_iteration), axis=0)
+                prev_mean_results = tf.convert_to_tensor([tf.reduce_mean(res) for res in ipts_hessian_trace_approx])

         # Free gradient tape
         del tape

         if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR:
-            if final_approx.shape != (1,):  # pragma: no cover
-                Logger.critical(f"For HessianInfoGranularity.PER_TENSOR, the expected score shape is (1,), but found {final_approx.shape}.")
+            for final_approx in ipts_hessian_trace_approx:
+                if final_approx.shape != (1,):  # pragma: no cover
+                    Logger.critical(f"For HessianInfoGranularity.PER_TENSOR, the expected score shape is (1,), "
+                                    f"but found {final_approx.shape}.")
         elif self.hessian_request.granularity == HessianInfoGranularity.PER_ELEMENT:
             # Reshaping the scores to the original weight shape
-            final_approx = tf.reshape(final_approx, weight_tensor.shape)
+            ipts_hessian_trace_approx = \
+                [tf.reshape(final_approx, s) for final_approx, s in
+                 zip(ipts_hessian_trace_approx, tensors_original_shape)]

         # Add a batch axis to the Hessian approximation tensor (to align with the expected returned shape)
         # We assume per-image computation, so the batch axis size is 1.
-        final_approx = final_approx[np.newaxis, ...]
+        final_approx = [r_final_approx[np.newaxis, ...].numpy()
+                        for r_final_approx in ipts_hessian_trace_approx]

-        return [final_approx.numpy()]
+        return final_approx

     def _reshape_gradients(self,
                            gradients: tf.Tensor,
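
The rewritten Keras compute() drops the `approximation_per_iteration` history in favor of a per-node running mean, updated in place as m_(j+1) = (j*m_j + x)/(j+1), and replaces the element-wise convergence test with a single scalar test on the maximal relative change of the per-node means. A framework-free sketch of that loop structure (tolerance, iteration counts, and the random samples standing in for the summed squared gradients are all illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    num_nodes, tolerance, min_iters = 3, 1e-3, 10
    approx = [np.zeros(1) for _ in range(num_nodes)]          # running mean per node
    prev_means = None
    for j in range(1000):
        samples = [rng.random(1) for _ in range(num_nodes)]   # stand-ins for sum(grad ** 2)
        approx = [(j * a + s) / (j + 1) for a, s in zip(approx, samples)]
        if j > min_iters and prev_means is not None:
            new_means = np.array([a.mean() for a in approx])
            rel_delta = np.abs(new_means - prev_means) / (np.abs(new_means) + 1e-6)
            if rel_delta.max() < tolerance:                   # all nodes stabilized
                break
        prev_means = np.array([a.mean() for a in approx])
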
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+from tqdm import tqdm
 from typing import List
 import torch
 from torch import autograd
@@ -48,11 +48,6 @@ class WeightsTraceHessianCalculatorPytorch(TraceHessianCalculatorPytorch):
             num_iterations_for_approximation: Number of iterations to use when approximating the Hessian trace.
         """

-        if len(trace_hessian_request.target_nodes) > 1:  # pragma: no cover
-            Logger.critical(f"Weights Hessian approximation is currently supported only for a single target node,"
-                            f" but the provided request contains the following target nodes: "
-                            f"{trace_hessian_request.target_nodes}.")
-
         super(WeightsTraceHessianCalculatorPytorch, self).__init__(graph=graph,
                                                                    input_images=input_images,
                                                                    fw_impl=fw_impl,
@@ -74,73 +69,84 @@ class WeightsTraceHessianCalculatorPytorch(TraceHessianCalculatorPytorch):
         The function returns a list for compatibility reasons.
         """

-        # Check if the target node's layer type is supported.
-        # We assume that weights Hessian computation is done only for a single node at each request.
-        target_node = self.hessian_request.target_nodes[0]
-        if not DEFAULT_PYTORCH_INFO.is_kernel_op(target_node.type):
-            Logger.critical(f"Hessian information with respect to weights is not supported for "
-                            f"{target_node.type} layers.")  # pragma: no cover
-
         # Float model
         model, _ = FloatPyTorchModelBuilder(graph=self.graph).build_model()

-        # Get the weight attributes for the target node type
-        weights_attributes = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(target_node.type)
-
-        # Get the weight tensor for the target node
-        if len(weights_attributes) != 1:  # pragma: no cover
-            Logger.critical(f"Currently, Hessian scores with respect to weights are supported only for nodes with a "
-                            f"single weight attribute. {len(weights_attributes)} attributes found.")
-
-        weights_tensor = getattr(getattr(model, target_node.name), weights_attributes[0])
-
-        # Get the output channel index
-        output_channel_axis, _ = DEFAULT_PYTORCH_INFO.kernel_channels_mapping.get(target_node.type)
-        shape_channel_axis = [i for i in range(len(weights_tensor.shape))]
-        if self.hessian_request.granularity == HessianInfoGranularity.PER_OUTPUT_CHANNEL:
-            shape_channel_axis.remove(output_channel_axis)
-        elif self.hessian_request.granularity == HessianInfoGranularity.PER_ELEMENT:
-            shape_channel_axis = ()
-
         # Run model inference
         outputs = model(self.input_images)
         output_tensor = self.concat_tensors(outputs)
         device = output_tensor.device

-        approximation_per_iteration = []
-        for j in range(self.num_iterations_for_approximation):
+        ipts_hessian_trace_approx = [torch.tensor([0.0],
+                                                  requires_grad=True,
+                                                  device=device)
+                                     for _ in range(len(self.hessian_request.target_nodes))]
+
+        prev_mean_results = None
+        for j in tqdm(range(self.num_iterations_for_approximation)):
             # Getting a random vector with normal distribution and the same shape as the model output
             v = torch.randn_like(output_tensor, device=device)
             f_v = torch.mean(torch.sum(v * output_tensor, dim=-1))
-            # Compute gradients of f_v with respect to the weights
-            f_v_grad = autograd.grad(outputs=f_v,
-                                     inputs=weights_tensor,
-                                     retain_graph=True)[0]
-
-            # Trace{A^T * A} = sum of all squares values of A
-            approx = f_v_grad ** 2
-            if len(shape_channel_axis) > 0:
-                approx = torch.sum(approx, dim=shape_channel_axis)
-
+            for i, ipt_node in enumerate(self.hessian_request.target_nodes):  # Per Interest point weights tensor
+
+                # Check if the target node's layer type is supported.
+                if not DEFAULT_PYTORCH_INFO.is_kernel_op(ipt_node.type):
+                    Logger.critical(f"Hessian information with respect to weights is not supported for "
+                                    f"{ipt_node.type} layers.")  # pragma: no cover
+
+                # Get the weight attributes for the target node type
+                weights_attributes = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(ipt_node.type)
+
+                # Get the weight tensor for the target node
+                if len(weights_attributes) != 1:  # pragma: no cover
+                    Logger.critical(f"Currently, Hessian scores with respect to weights are supported only for nodes with a "
+                                    f"single weight attribute. {len(weights_attributes)} attributes found.")
+
+                weights_tensor = getattr(getattr(model, ipt_node.name), weights_attributes[0])
+
+                # Get the output channel index
+                output_channel_axis, _ = DEFAULT_PYTORCH_INFO.kernel_channels_mapping.get(ipt_node.type)
+                shape_channel_axis = [i for i in range(len(weights_tensor.shape))]
+                if self.hessian_request.granularity == HessianInfoGranularity.PER_OUTPUT_CHANNEL:
+                    shape_channel_axis.remove(output_channel_axis)
+                elif self.hessian_request.granularity == HessianInfoGranularity.PER_ELEMENT:
+                    shape_channel_axis = ()
+
+                # Compute gradients of f_v with respect to the weights
+                f_v_grad = autograd.grad(outputs=f_v,
+                                         inputs=weights_tensor,
+                                         retain_graph=True)[0]
+
+                # Trace{A^T * A} = sum of all squares values of A
+                approx = f_v_grad ** 2
+                if len(shape_channel_axis) > 0:
+                    approx = torch.sum(approx, dim=shape_channel_axis)
+
+                # Update node Hessian approximation mean over random iterations
+                ipts_hessian_trace_approx[i] = (j * ipts_hessian_trace_approx[i] + approx) / (j + 1)
+
+            # If the change to the maximal mean Hessian approximation is insignificant we stop the calculation
+            # Note that we do not consider granularity when computing the mean
             if j > MIN_HESSIAN_ITER:
-                new_mean = (torch.sum(torch.stack(approximation_per_iteration), dim=0) + approx)/(j+1)
-                delta = new_mean - torch.mean(torch.stack(approximation_per_iteration), dim=0)
-                converged_tensor = torch.abs(delta) / (torch.abs(new_mean) + HESSIAN_EPS) < HESSIAN_COMP_TOLERANCE
-                if torch.all(converged_tensor):
-                    break
-
-            approximation_per_iteration.append(approx)
+                if prev_mean_results is not None:
+                    new_mean_res = torch.as_tensor([torch.mean(res) for res in ipts_hessian_trace_approx],
+                                                   device=device)
+                    relative_delta_per_node = (torch.abs(new_mean_res - prev_mean_results) /
+                                               (torch.abs(new_mean_res) + 1e-6))
+                    max_delta = torch.max(relative_delta_per_node)
+                    if max_delta < HESSIAN_COMP_TOLERANCE:
+                        break

-        # Compute the mean of the approximations
-        final_approx = torch.mean(torch.stack(approximation_per_iteration), dim=0)
+            prev_mean_results = torch.as_tensor([torch.mean(res) for res in ipts_hessian_trace_approx], device=device)

         # Make sure all final shape are tensors and not scalar
         if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR:
-            final_approx = final_approx.reshape(1)
+            ipts_hessian_trace_approx = [final_approx.reshape(1) for final_approx in ipts_hessian_trace_approx]

         # Add a batch axis to the Hessian approximation tensor (to align with the expected returned shape).
         # We assume per-image computation, so the batch axis size is 1.
-        final_approx = final_approx[np.newaxis, ...]
+        final_approx = [r_final_approx[np.newaxis, ...].detach().cpu().numpy()
+                        for r_final_approx in ipts_hessian_trace_approx]

-        return [final_approx.detach().cpu().numpy()]
+        return final_approx